模块化函数(3) utils

1. 打开 txt 文件,构建词典

    # Build the alphabet string from the dictionary file (read line by line).
    with open('../data/OCR-data/test.txt') as f:
        data = f.readlines()
    # Drop trailing whitespace (including the newline) from every line, then
    # concatenate the pieces into one string.
    alphabet = ''.join(line.rstrip() for line in data)
第一步:'我\n', '爱\n'
第二步:'我', '爱'
第三步:'我爱'

2. 构建映射encode-decode

首先明白.encode()和.decode()的type变化

    # str -> bytes: encode() with no argument uses UTF-8, spelled out here.
    a = "我爱你啊"
    a = a.encode("utf-8")
    # type(a) is now <class 'bytes'>
    # bytes -> str: decode() also defaults to UTF-8.
    b = a.decode("utf-8")
    # type(b) is <class 'str'>

构建映射表

class strLabelConverterForAttention(object):
    """Map characters to integer ids for an attention-based recogniser.

    Ids 0, 1 and 2 are reserved for 'SOS' (start), 'EOS' (end) and '$'
    (blank); the characters of ``alphabet`` are numbered from 3 upwards in
    the order they appear.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet
        # Reserved symbols are inserted first so they keep ids 0..2.
        self.dict = {'SOS': 0, 'EOS': 1, '$': 2}
        # Alphabet characters follow; a repeated character keeps the id of
        # its last occurrence.
        self.dict.update(
            (char, index)
            for index, char in enumerate(self.alphabet, start=3)
        )
映射结果:{'SOS': 0, 'EOS': 1, '$': 2, ...},字母表中的字符从 3 开始依次编号(值是整数,不是字符串)。

2.1 Attention_encode

    def encode(self,text):
        if isinstance(text,str):
            text=[self.dict[item] for item in text]
        elif isinstance(text,collections.Iterable):
            text=[self.encode(s) for s in text]    # 编码
            max_length=max([len(x) for x in text]) # 对齐
            nb=len(text)
            #在max_length和nb之间use ‘blank’ for padding
            targets=torch.ones(nb,max_length+2)*2
            for i in range(nb):
                targets[i][0]=0
                targets[i][1:len(text[i])+1]=text[i]
                targets[i][len(text[i])+1]=1
            text=targets.transpose(0,1).contiguous()
            text=text.long()
        return torch.LongTensor(text)

2.2 Attention_decode

    def decode(self,t):
        texts=list(self.dict.keys())[list(self.dict.values()).index(t)]
        return texts
# Decode one id at a time.
# NOTE(review): this fragment uses `self`, so it only runs inside a method of
# the converter (or with `self` bound to a converter instance).
t=torch.FloatTensor([1])
print(list(self.dict.keys())[list(self.dict.values()).index(t)])
# Prints EOS: __init__ registers 'EOS' as id 1 ('SOS' is id 0).

2.3 CTC_encode

2.4 CTC_decode

3. 计算average

class average(object):
    """Running mean over every element added since the last ``reset``."""

    def __init__(self):
        self.reset()

    def add(self, v):
        """Accumulate ``v`` into the running sum.

        A tensor contributes the sum of its elements and ``numel()`` to the
        element count. A plain Python number contributes itself and counts
        as one element.

        Raises:
            TypeError: ``v`` is neither a tensor nor a number. Previously
                this case failed with an ``UnboundLocalError`` on ``count``.
        """
        # torch.autograd.Variable was merged into Tensor in PyTorch 0.4, so a
        # single Tensor check covers both of the original branches.
        if isinstance(v, torch.Tensor):
            count = v.numel()
            # Detach so the accumulator does not hold on to the autograd graph.
            v = v.detach().sum()
        elif isinstance(v, (int, float)):
            count = 1
        else:
            raise TypeError(
                "average.add() expects a tensor or a number, got "
                + type(v).__name__
            )
        self.n_count += count
        self.sum += v

    def reset(self):
        """Clear the element count and the running sum."""
        self.n_count = 0
        self.sum = 0

    def val(self):
        """Return the mean of everything added so far, or 0 if nothing was."""
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res

4. one_hot编码

def onehot(v,v_length,nc):
    """One-hot encode a flat run of labels into a padded batch.

    ``v`` holds the labels of all samples back to back; ``v_length`` holds
    the number of labels per sample (e.g. ``[2, 3]``); ``nc`` is the size of
    the one-hot axis. Returns a float tensor of shape
    ``(batch_size, max(v_length), nc)``; rows beyond a sample's own length
    stay all zero.
    """
    batch_size = v_length.size(0)
    maxlength = v_length.max()
    v_onehot = torch.FloatTensor(batch_size, maxlength, nc).zero_()
    start = 0  # offset of the current sample's first label inside v
    for row, length in enumerate(v_length):
        stop = start + length
        # Column vector of class ids: one row per time step of this sample.
        label = v[start:stop].view(-1, 1).long()
        # Write a 1.0 at each step's class index along the class axis.
        v_onehot[row, :length].scatter_(1, label, 1.0)
        start = stop
    return v_onehot
"""tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]])
   tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])"""

"""tensor([[[0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 0.]],

        [[0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [1., 0., 0., 0.]]])"""

"""def _OneHot():
    v=torch.LongTensor([1,2,1,2,0])
    v_length=torch.LongTensor([2,3])
    v_onehot=onehot(v,v_length,4)
    print(v_onehot)
"""

5. 上采样函数

def assureRatio(img):
    """Ensure imgH <= imgW.

    ``img`` is a 4-D tensor unpacked as ``(b, c, h, w)``. When the height is
    at least the width, the image is bilinearly resampled to ``(h, h)``;
    otherwise it is returned untouched.
    """
    _, _, height, width = img.size()
    if height < width:
        return img
    resize = nn.UpsamplingBilinear2d(size=(height, height), scale_factor=None)
    return resize(img)

6. 初始化权重

def weight_init(model):
    """Initialise the layers of ``model`` in place, by layer type.

    * ``nn.Conv2d``: Kaiming-normal weights.
    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: bias set to 0.

    Parameters that are absent (``bias=False``, ``affine=False``) are
    skipped. Returns ``None``.
    """
    # ``modules()`` (plural) walks the model itself and every submodule;
    # ``nn.Module`` has no ``module()`` method.
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            # The initialisers live in ``nn.init``, not on ``nn.Linear``.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
def weight_init(m):
    """Initialise one module in place, chosen by its class name.

    Classes whose name contains 'Conv' get weights drawn from N(0, 0.02).
    Classes whose name contains 'BatchNorm' get weights drawn from
    N(1, 0.02) and a zero bias. Any other module is left untouched.
    """
    layer_kind = m.__class__.__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# Run weight_init over the network.
# NOTE(review): `crnn` is not defined in this snippet; presumably it is the
# model (an nn.Module), in which case apply() calls weight_init on every
# submodule and on the model itself. Confirm against the training script.
crnn.apply(weight_init)

7. 各模块测试函数

posted @ 2021-11-11 21:23  Tsukinousag1  阅读(83)  评论(1)    收藏  举报