模块化函数(1) CRNN_model
1. 双向LSTM
net: [nIn, nHidden, nOut] = [256, 256, 10]
input: [T, b, nIn] = [16, 16, 256]（nn.LSTM 默认 batch_first=False：T 为序列长度，b 为 batch 大小）
output: [T, b, nOut] = [16, 16, 10]
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    Args:
        nIn: feature size of each input timestep.
        nHidden: hidden size of each LSTM direction; the concatenated
            forward/backward output therefore has ``2 * nHidden`` features.
        nOut: feature size produced by the linear projection.

    Example sizes used in the text: ``[nIn, nHidden, nOut] = [256, 256, 10]``.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        # Created with the default batch_first=False, so inputs are laid out
        # as (seq_len, batch, nIn).
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        """Map a (seq_len, batch, nIn) tensor to (seq_len, batch, nOut)."""
        # The LSTM also returns (h_n, c_n); only the per-step outputs are used.
        seq_out, _ = self.rnn(input)
        steps, batch, feats = seq_out.size()  # feats == 2 * nHidden
        # Flatten the (3-D) time and batch axes into rows of a 2-D matrix so
        # the linear layer sees one row per timestep: (steps * batch, feats).
        flat = seq_out.view(steps * batch, feats)
        projected = self.embedding(flat)
        # Restore the (steps, batch, nOut) layout.
        return projected.view(steps, batch, -1)
2. CRNN
net:[imgh,nc,nclass,nh] [32,1,37,256]
input:[b,c,h,w] [10,1,32,100]
output:[w,b,num_class] [26,10,37]
- 首先定义各卷积层的 kernel_size、padding、stride 和输出通道数（output），并规定输入图像的高和宽为 [32,100]（高度须为 16 的倍数）
# Per-layer hyperparameters for the seven convolution layers conv0..conv6.
# (The "image height is a multiple of 16" check lives in CRNN.__init__,
# not in these lists.)
ks = [3, 3, 3, 3, 3, 3, 2]  # kernel_size: 3x3 everywhere except the final 2x2
ps = [1, 1, 1, 1, 1, 1, 0]  # padding: size-preserving for 3x3, none for the 2x2
ss = [1] * 7  # stride: 1 for every layer
nm = [64, 128, 256, 256, 512, 512, 512]  # output channels per layer
- 建立网络：创建用于依次堆叠卷积块的 nn.Sequential 容器
# Empty container; conv / batch-norm / activation / pooling layers are
# appended to it in order via add_module.
cnn = nn.Sequential()
- 单个 Conv(+BatchNorm)+ReLU 层
def convRelu(i, bathchNormlization=False):
    """Append conv layer ``i`` (plus optional BatchNorm and an activation) to ``cnn``.

    Args:
        i: index into the per-layer lists ``ks``, ``ss``, ``ps`` and ``nm``.
        bathchNormlization: if true, insert ``nn.BatchNorm2d`` after the conv.
            (The misspelled name is kept so existing calls keep working.)

    Reads ``nc``, ``nm``, ``ks``, ``ss``, ``ps``, ``leakyRelu`` and ``cnn`` from
    the enclosing scope (in the full code these are ``CRNN.__init__`` locals).
    """
    # Layer 0 consumes the image channels; later layers consume the previous
    # layer's output channels.
    nIn = nc if i == 0 else nm[i - 1]
    nOut = nm[i]
    cnn.add_module('conv{0}'.format(i), nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
    if bathchNormlization:
        cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
    if leakyRelu:
        cnn.add_module('relu{0}'.format(i), nn.LeakyReLU(0.2, inplace=True))
    else:
        # The class is nn.ReLU; nn.Relu does not exist and raised AttributeError.
        cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
- 搭建 CNN 与 RNN 网络
# Excerpt from CRNN.__init__: self, nh, nclass, cnn and the convRelu helper
# come from that constructor's scope. The order of these calls defines the
# order of layers inside the nn.Sequential.
# Spatial sizes (height x width) are noted for a [b, 1, 32, 100] input.
convRelu(0)  # conv0: nc -> 64 channels, 32x100
cnn.add_module('pooling{0}'.format(0),nn.MaxPool2d(2,2))  # -> 16x50
convRelu(1)  # conv1: 64 -> 128
cnn.add_module('pooling{0}'.format(1),nn.MaxPool2d(2,2))  # -> 8x25
convRelu(2,True)  # conv2 + BatchNorm: 128 -> 256
convRelu(3)  # conv3: 256 -> 256
# kernel (2,2), stride (2,1), padding (0,1): halves the height while the width
# goes 25 -> 26, keeping horizontal resolution for the sequence axis.
cnn.add_module('pooling{0}'.format(2),nn.MaxPool2d((2,2),(2,1),(0,1)))  # -> 4x26
convRelu(4,True)  # conv4 + BatchNorm: 256 -> 512
convRelu(5)  # conv5: 512 -> 512
cnn.add_module('pooling{0}'.format(3),nn.MaxPool2d((2,2),(2,1),(0,1)))  # -> 2x27
convRelu(6,True)  # conv6 (2x2 kernel, no padding) + BatchNorm -> 1x26
self.cnn=cnn
# Two stacked bidirectional LSTMs: 512 CNN channels -> nh -> nclass.
self.rnn=nn.Sequential(
BidirectionalLSTM(512,nh,nh),  # 512 == nm[6], channels of the last conv
BidirectionalLSTM(nh,nh,nclass)
)
- 最后一部分是前向计算：先将 CNN 输出的 [b,512,h,w]（要求 h=1）去掉 h 维并转置为 [w,b,512]，再经过 RNN 得到 [w,b,37]
def forward(self, input):
    """Run the CNN, then the RNN, returning per-column class scores.

    ``input`` is an image batch ``[b, c, h, w]`` (``[10, 1, 32, 100]`` in the
    example). The CNN output height must be 1; its width becomes the sequence
    axis, so the result is ``[w', b, nclass]`` (``[26, 10, 37]`` in the example).
    Raises AssertionError if the CNN output height is not 1.
    """
    features = self.cnn(input)  # e.g. [10, 512, 1, 26]
    _, _, height, _ = features.size()
    assert height == 1, "the height of conv must be 1"
    # Drop the height axis, then move width to the front so it acts as the
    # time axis for the LSTMs: [b, c, w'] -> [w', b, c].
    sequence = features.squeeze(2).permute(2, 0, 1)
    return self.rnn(sequence)  # e.g. [26, 10, 37]
完整代码
#crnn.model
import torch
import torch.nn as nn
import torch.nn.functional as F
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    Args:
        nIn: feature size of each input timestep.
        nHidden: hidden size of each LSTM direction; the concatenated
            forward/backward output therefore has ``2 * nHidden`` features.
        nOut: feature size produced by the linear projection.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        # Default batch_first=False: inputs are (seq_len, batch, nIn).
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        """Map a (seq_len, batch, nIn) tensor to (seq_len, batch, nOut)."""
        # The LSTM also returns (h_n, c_n); only the per-step outputs are used.
        seq_out, _ = self.rnn(input)
        steps, batch, feats = seq_out.size()  # feats == 2 * nHidden
        # Flatten time and batch so the linear layer sees one row per timestep.
        projected = self.embedding(seq_out.view(steps * batch, feats))
        return projected.view(steps, batch, -1)


class CRNN(nn.Module):
    """Convolutional recurrent network: a CNN feature extractor + 2 BiLSTMs.

    Args:
        imgh: input image height; must be a multiple of 16. With the layer
            stack below, only ``imgh == 32`` reduces the feature-map height to
            exactly 1, which ``forward`` requires.
        nc: number of input image channels.
        nclass: number of output classes per timestep.
        nh: hidden size of each LSTM direction.
        leakyRelu: use ``LeakyReLU(0.2)`` instead of ``ReLU`` after each conv.

    Example: ``CRNN(32, 1, 37, 256)`` maps a ``[10, 1, 32, 100]`` batch to
    ``[26, 10, 37]``.
    """

    def __init__(self, imgh, nc, nclass, nh, leakyRelu=False):
        super().__init__()
        assert imgh % 16 == 0, "imgh has to be a multiple of 16"
        # Per-layer hyperparameters for conv0..conv6.
        ks = [3, 3, 3, 3, 3, 3, 2]  # kernel sizes
        ps = [1, 1, 1, 1, 1, 1, 0]  # paddings
        ss = [1, 1, 1, 1, 1, 1, 1]  # strides
        nm = [64, 128, 256, 256, 512, 512, 512]  # output channels

        cnn = nn.Sequential()

        def convRelu(i, bathchNormlization=False):
            """Append conv ``i`` (+ optional BatchNorm) and its activation to ``cnn``."""
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i), nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if bathchNormlization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i), nn.LeakyReLU(0.2, inplace=True))
            else:
                # nn.ReLU -- the original nn.Relu does not exist, so the default
                # leakyRelu=False path raised AttributeError.
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Spatial sizes (height x width) for a 32x100 input are noted per stage.
        convRelu(0)                                                   # 32x100
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))    # 16x50
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))    # 8x25
        convRelu(2, True)
        convRelu(3)
        # Stride (2, 1) with width padding 1 halves the height but keeps the
        # width (the future sequence axis) from shrinking.
        cnn.add_module('pooling{0}'.format(2), nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 4x26
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 2x27
        convRelu(6, True)                                             # 1x26
        self.cnn = cnn

        # The first LSTM consumes the channel count of the last conv layer.
        self.rnn = nn.Sequential(
            BidirectionalLSTM(nm[-1], nh, nh),
            BidirectionalLSTM(nh, nh, nclass),
        )

    def forward(self, input):
        """Map images ``[b, nc, imgh, w]`` to per-column scores ``[T, b, nclass]``.

        ``T`` is the feature-map width after the CNN (26 for ``w == 100``).
        Raises AssertionError if the CNN output height is not 1.
        """
        conv = self.cnn(input)            # [b, 512, h, T], e.g. [10, 512, 1, 26]
        _, _, h, _ = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)            # drop the height axis -> [b, 512, T]
        conv = conv.permute(2, 0, 1)      # -> [T, b, 512]: width is the time axis
        return self.rnn(conv)             # -> [T, b, nclass]
测试函数
if __name__ == '__main__':
    # Smoke test: a batch of 10 single-channel 32x100 images.
    dummy_batch = torch.randn(10, 1, 32, 100)
    model = CRNN(imgh=32, nc=1, nclass=37, nh=256)
    print(model(dummy_batch).shape)
    # Expected: torch.Size([26, 10, 37])

浙公网安备 33010602011771号