模块化函数(3) utils
1. 打开 txt 文件,构建词典
# Load the OCR alphabet: the file is expected to hold one character per line.
# The characters are non-ASCII (e.g. Chinese), so decode explicitly instead of
# relying on the platform's locale-dependent default encoding.
# NOTE(review): assumes the file is UTF-8 encoded — confirm for this dataset.
with open('../data/OCR-data/test.txt', encoding='utf-8') as f:
    data = f.readlines()
# rstrip() removes the trailing newline (and any other trailing whitespace);
# NOTE: a line holding only ' ' therefore becomes '', so a space character
# cannot be part of this alphabet.
alphabet = [x.rstrip() for x in data]
# Concatenate into a single string, e.g. ['我', '爱'] -> '我爱'.
alphabet = ''.join(alphabet)
第一步(f.readlines() 读入):['我\n', '爱\n', ...]
第二步(rstrip() 去掉行尾换行符):['我', '爱', ...]
第三步(''.join 拼接成一个字符串):'我爱...'
2. 构建映射encode-decode
首先弄清 .encode() 和 .decode() 前后的类型变化:str.encode() 返回 bytes,bytes.decode() 返回 str(Python 3 中默认编码均为 UTF-8)。
# Text (str) -> raw bytes; UTF-8 is the default codec of str.encode().
a = "我爱你啊".encode("utf-8")   # type(a) is <class 'bytes'>
# Raw bytes -> text (str) again.
b = a.decode("utf-8")            # type(b) is <class 'str'>
构建映射表
class strLabelConverterForAttention(object):
    """Map characters to integer ids for an attention-based decoder.

    Ids 0, 1 and 2 are reserved for the 'SOS' (start), 'EOS' (end) and '$'
    (blank) tokens; the k-th character of ``alphabet`` (0-based) gets id
    ``k + 3``. If a character occurs more than once, its last position wins.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet
        # token -> id; the reserved tokens come first so dict order matches ids.
        self.dict = {'SOS': 0, 'EOS': 1, '$': 2}
        self.dict.update({char: pos for pos, char in enumerate(alphabet, start=3)})
映射结果:{'SOS': 0, 'EOS': 1, '$': 2, '我': 3, '爱': 4, ...}(编号为整数;字母表中第 i 个字符(从 0 开始)的编号为 i+3)
2.1 Attention_encode
def encode(self, text):
    """Encode a string, or a batch of strings, into label ids for attention decoding.

    Args:
        text: a ``str``, or an iterable of ``str``.

    Returns:
        For a single ``str``: a 1-D ``torch.LongTensor`` with one id per
        character, looked up in ``self.dict`` (no SOS/EOS added).
        For a batch: a ``torch.LongTensor`` of shape
        ``(max_length + 2, batch_size)``. Column ``i`` holds
        ``[0 (SOS), ids of text[i]..., 1 (EOS), 2 ('$' padding)...]``.
        An empty batch gives shape ``(2, 0)``.

    Raises:
        KeyError: a character is not in ``self.dict``.
        TypeError: ``text`` is neither a ``str`` nor an iterable.
    """
    # collections.Iterable was removed in Python 3.10; the ABC lives in
    # collections.abc.
    from collections.abc import Iterable

    if isinstance(text, str):
        return torch.tensor([self.dict[char] for char in text], dtype=torch.long)
    if not isinstance(text, Iterable):
        # Without this check a bare int would reach torch.LongTensor(int) and
        # silently produce an uninitialised tensor of that size.
        raise TypeError(
            'encode() expects a str or an iterable of str, got %s' % type(text).__name__
        )

    encoded = [self.encode(s) for s in text]             # encode each sample
    max_length = max((len(ids) for ids in encoded), default=0)  # align lengths
    batch_size = len(encoded)
    # Pre-fill with id 2 ('$') so everything after EOS is padding.
    targets = torch.full((batch_size, max_length + 2), 2, dtype=torch.long)
    for i, ids in enumerate(encoded):
        n = len(ids)
        targets[i, 0] = 0            # SOS
        targets[i, 1:n + 1] = ids
        targets[i, n + 1] = 1        # EOS
    # Time-major layout expected by the decoder: (max_length + 2, batch_size).
    return targets.transpose(0, 1).contiguous()
2.2 Attention_decode
def decode(self, t):
    """Return the token whose id equals ``t`` (reverse lookup in ``self.dict``).

    ``t`` is compared with ``==`` against every id in order, so a
    single-element tensor also works. Raises ``ValueError`` when no token
    has that id.
    """
    ids = list(self.dict.values())
    tokens = list(self.dict)
    position = ids.index(t)
    return tokens[position]
# Decode a single id: reverse-look-up the token whose id equals t.
# Uses a converter built from the alphabet loaded above (the original snippet
# referred to `self`, which does not exist at module level).
# Id 1 is the reserved 'EOS' token ('SOS' is 0), so this prints EOS.
converter = strLabelConverterForAttention(alphabet)
t = torch.FloatTensor([1])
print(list(converter.dict.keys())[list(converter.dict.values()).index(t)])
# EOS
2.3 CTC_encode
2.4 CTC_decode
3. 计算average
class average(object):
    """Running mean of every element passed to :meth:`add`.

    ``sum`` accumulates the sum of the elements and ``n_count`` their number;
    :meth:`val` returns the ratio (``0`` before anything has been counted).
    """

    def __init__(self):
        self.reset()

    def add(self, v):
        """Accumulate ``v``.

        Args:
            v: a tensor (each element counts as one sample) or a plain Python
                ``int``/``float`` (counts as a single sample).

        Raises:
            TypeError: ``v`` is neither a tensor nor a number.
        """
        if isinstance(v, torch.Tensor):
            # Since PyTorch 0.4 ``Variable`` is ``Tensor``, so this covers both.
            # detach() keeps the running sum out of the autograd graph, so
            # accumulating losses does not keep their graphs alive.
            v = v.detach()
            count = v.numel()
            v = v.sum()
        elif isinstance(v, (int, float)):
            count = 1
        else:
            # Any other type would otherwise leave `count` undefined.
            raise TypeError(
                'average.add() expects a tensor or a number, got %s' % type(v).__name__
            )
        self.n_count += count
        self.sum += v

    def reset(self):
        """Forget everything accumulated so far."""
        self.n_count = 0
        self.sum = 0

    def val(self):
        """Return ``sum / n_count``, or ``0`` if nothing has been counted."""
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res
4. one_hot编码
def onehot(v, v_length, nc):
    """Turn a flat, concatenated label sequence into a padded one-hot batch.

    Args:
        v: 1-D tensor holding the labels of all samples back to back
            (``sum(v_length)`` entries); each label must be in ``[0, nc)``.
        v_length: 1-D tensor of shape ``(batch_size,)``; ``v_length[i]`` is
            how many labels of ``v`` belong to sample ``i``.
        nc: number of classes (size of the one-hot dimension).

    Returns:
        float32 tensor of shape ``(batch_size, max(v_length), nc)`` on
        ``v``'s device; rows past a sample's length are all zeros.

    Example:
        v = [1, 2, 1, 2, 0], v_length = [2, 3], nc = 4 gives
        [[[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
         [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]]
    """
    batch_size = v_length.size(0)
    # Convert to plain ints: size arguments must not be 0-dim tensors, and
    # max() of an empty tensor raises.
    max_length = int(v_length.max()) if batch_size > 0 else 0
    # Allocate on v's device so scatter_ works for CUDA labels too.
    v_onehot = torch.zeros(batch_size, max_length, nc, dtype=torch.float32, device=v.device)
    offset = 0
    for i in range(batch_size):
        length = int(v_length[i])
        label = v[offset:offset + length].view(-1, 1).long()
        # v_onehot[i, :length] is a view, so the in-place scatter_ writes
        # the 1s straight into v_onehot.
        v_onehot[i, :length].scatter_(1, label, 1.0)
        offset += length
    return v_onehot
"""tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.]])
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])"""
"""tensor([[[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 0.]],
[[0., 1., 0., 0.],
[0., 0., 1., 0.],
[1., 0., 0., 0.]]])"""
"""def _OneHot():
v=torch.LongTensor([1,2,1,2,0])
v_length=torch.LongTensor([2,3])
v_onehot=onehot(v,v_length,4)
print(v_onehot)
"""
5. 上采样函数
def assureRatio(img):
    """Ensure imgH <= imgW by stretching tall or square images horizontally.

    Args:
        img: 4-D tensor ``(batch, channels, height, width)``.

    Returns:
        When ``height >= width``, the batch bilinearly resized to
        ``(height, height)`` (height unchanged, width stretched); otherwise
        ``img`` itself, unchanged.
    """
    _, _, h, w = img.size()
    if h >= w:
        # Same result as nn.UpsamplingBilinear2d(size=(h, h)), which PyTorch
        # deprecates in favour of interpolate(..., align_corners=True).
        img = nn.functional.interpolate(
            img, size=(h, h), mode='bilinear', align_corners=True
        )
    return img
6. 初始化权重
def weight_init(model):
    """Initialise the Conv2d, BatchNorm2d and Linear layers of ``model`` in place.

    - ``nn.Conv2d``: Kaiming-normal weights (bias left untouched).
    - ``nn.BatchNorm2d``: weight = 1, bias = 0.
    - ``nn.Linear``: bias = 0 (weight left untouched).

    Parameters that are ``None`` (BatchNorm with ``affine=False``, Linear
    with ``bias=False``) are skipped.

    NOTE(review): a second ``def weight_init(m)`` further down this file
    shadows this function if both live in the same module.
    """
    # modules() (not module()) walks model and every submodule recursively.
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
def weight_init(m):
    """Name-based initialiser for a single module, meant for ``model.apply``.

    - class name containing ``'Conv'``: weight ~ N(0, 0.02)
    - class name containing ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0

    Matching is by class-name substring, so a custom container such as
    ``ConvBlock``, or a BatchNorm with ``affine=False``, may match without
    owning a ``weight``/``bias`` tensor; those are skipped instead of raising.

    NOTE(review): this redefines ``weight_init`` and shadows the
    ``weight_init(model)`` defined earlier in this file if both share a module.
    """
    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if isinstance(weight, torch.Tensor):
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if isinstance(weight, torch.Tensor):
            nn.init.normal_(weight, 1.0, 0.02)
        if isinstance(bias, torch.Tensor):
            nn.init.constant_(bias, 0)
# Re-initialise the model with the name-based initialiser: nn.Module.apply
# calls `fn` on every submodule and then on the model itself.
# NOTE(review): `crnn` is presumably the CRNN model built elsewhere — confirm.
crnn.apply(fn=weight_init)