RNN知识+LSTM知识+encoder-decoder+ctc+基于pytorch的crnn网络结构
一.基礎知識:
下圖是一個循環神經網絡實現語言模型的示例,可以看出其是基于當前的輸入與過去的輸入序列,預測序列的下一個字符.
序列特點就是某一步的輸出不僅依賴于這一步的輸入,還依賴于其他步的輸入或輸出.
其中n為批量大小,d為詞向量大小
1.RNN:
xt不止與該時刻輸入有關還與上一時刻的輸出狀態有關,而第t層的誤差函數跟輸出Ot直接相關,而Ot依賴于前面每一層的xi和si,故存在梯度消失或梯度爆炸的問題,對于長時序很難處理.所以可以進行改造讓第t層的誤差函數只跟該層{si,xi}有關.
RNN代碼簡單實現:
def one_hot(x, n_class, dtype=torch.float32):result = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) # shape: (n, n_class)result.scatter_(1, x.long().view(-1, 1), 1) # result[i, x[i, 0]] = 1return resultdef to_onehot(X, n_class):return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]def get_parameters(num_inputs, num_hiddens,num_outputs):def init_parameter(shape):param = torch.zeros(shape, device=device,dtype=torch.float32)nn.init.normal_(param, 0, 0.01)return torch.nn.Parameter(param)#權重參數w_xh = init_parameter((num_inputs, num_hiddens))w_hh = init_parameter((num_hiddens, num_hiddens))b_h = torch.nn.Parameter(torch.zeros(num_hiddens,device=device))#輸出層參數w_hq = init_parameter((num_hiddens, num_outputs))b_q = torch.nn.Parameter(torch.zeros(num_outputs,device=device))return (w_xh, w_hh, b_h, w_hq, b_q)def rnn(inputs,state,params):w_xh, w_hh, b_h, w_hq, b_q = paramsH = stateoutputs = []for x in inputs:print('===x:', x) #(batch_size,vocab_size) (vocab_size, num_hiddens)H = torch.tanh(torch.matmul(x, w_xh)+torch.matmul(H, w_hh)+b_h)# (batch_size,num_hiddens) (num_hiddens, num_hiddens)Y = torch.matmul(H, w_hq)+b_q# (batch_size,num_hiddens) (num_hiddens, num_outputs)outputs.append(Y)return outputs, Hdef init_rnn_state(batch_size, num_hiddens,device):return torch.zeros((batch_size, num_hiddens),device=device)def test_one_hot():X = torch.arange(10).view(2, 5)print('==X:', X)inputs = to_onehot(X, 10)print(len(inputs))print('==inputs:', inputs)# print('==inputs:', inputs[-1].shape)def test_rnn():X = torch.arange(5).view(1, 5)print('===X:', X)num_hiddens = 256vocab_size = 10#詞典長度num_inputs, num_hiddens, num_outputs = vocab_size, num_hiddens, vocab_sizestate = init_rnn_state(X.shape[0], num_hiddens, device)inputs = to_onehot(X.to(device), vocab_size)print('===len(inputs), inputs', len(inputs), inputs)params = get_parameters(num_inputs, num_hiddens, num_outputs)outputs, state_new = rnn(inputs, state, params)print('==len(outputs), outputs[0].shape:', len(outputs), outputs[0].shape)print('==state.shape:', state.shape)print('==state_new.shape:', state_new.shape)if __name__ == '__main__':# test_one_hot()test_rnn()2.LSTM:
傳統RNN每個模塊內只是一個簡單的tanh層:
遺忘門:控制上一時間步的記憶細胞;
輸入門:控制當前時間步的輸入;
輸出門:控制從記憶細胞到隱藏狀態;
記憶細胞:?種特殊的隱藏狀態的信息的流動,表示的是長期記憶;
h 是隱藏狀態,表示的是短期記憶;
LSTM每個循環的模塊內又有4層結構:3個sigmoid層,1個tanh層
細胞狀態Ct,類似short cut信息流通暢順,故可以解決梯度消失或爆炸的問題.
遺忘層,決定信息保留多少
更新層,這里要注意的是用了tanh,值域在-1,1,起到信息加強和減弱的作用.
輸出層,上述兩層的信息相加流通到這里以后,經過tanh函數得到輸出值候選項,而候選項中的哪些部分最終會被輸出由一個sigmoid層來決定.這時就得到了輸出狀態和輸出值,下一時刻也是如此.
?
LSTM簡單實現代碼:
def one_hot(x, n_class, dtype=torch.float32):result = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) # shape: (n, n_class)result.scatter_(1, x.long().view(-1, 1), 1) # result[i, x[i, 0]] = 1return resultdef to_onehot(X, n_class):return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]def get_parameters(num_inputs, num_hiddens,num_outputs):def init_parameter(shape):param = torch.zeros(shape, device=device,dtype=torch.float32)nn.init.normal_(param, 0, 0.01)return torch.nn.Parameter(param)def final_init_parameter():return (init_parameter((num_inputs, num_hiddens)),init_parameter((num_hiddens, num_hiddens)),torch.nn.Parameter(torch.zeros(num_hiddens,device=device,dtype=torch.float32,requires_grad=True)))w_xf, w_hf, b_f = final_init_parameter()#遺忘門參數w_xi, w_hi, b_i = final_init_parameter()#輸入門參數w_xo, w_ho, b_o = final_init_parameter()#輸出門參數w_xc, w_hc, b_c = final_init_parameter()#記憶門參數w_hq = init_parameter((num_hiddens, num_outputs))#輸出層參數b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32, requires_grad=True))return nn.ParameterList([w_xi, w_hi, b_i, w_xf, w_hf, b_f, w_xo, w_ho, b_o, w_xc, w_hc, b_c, w_hq, b_q])def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))def lstm(inputs, states, params):[w_xi, w_hi, b_i, w_xf, w_hf, b_f, w_xo, w_ho, b_o, w_xc, w_hc, b_c, w_hq, b_q] = params[H, C] = statesoutputs = []for x in inputs:print('===x:',x)I = torch.sigmoid(torch.matmul(x, w_xi) + torch.matmul(H, w_hi) + b_i)#輸入門數據F = torch.sigmoid(torch.matmul(x, w_xf) + torch.matmul(H, w_hf) + b_f)#遺忘門數據O = torch.sigmoid(torch.matmul(x, w_xo) + torch.matmul(H, w_ho) + b_o)#輸出門數據C_tila = torch.tanh(torch.matmul(x, w_xc) + torch.matmul(H, w_hc) + b_c)#C冒數據C = F*C + I*C_tilaH = torch.tanh(C)*O# print('H.shape', H.shape)# print('w_hq.shape', w_hq.shape)# print('b_q.shape:', b_q.shape)Y = torch.matmul(H, w_hq)+b_qoutputs.append(Y)return outputs, (H,C)def test_lstm():batch_size = 1X = torch.arange(5).view(batch_size, 5)print('===X:', X)num_hiddens = 256vocab_size = 10 # 詞典長度inputs = to_onehot(X.to(device), vocab_size)print('===len(inputs), inputs', len(inputs), inputs)num_inputs, num_hiddens, num_outputs = vocab_size, num_hiddens, vocab_sizestates = init_lstm_state(batch_size, num_hiddens, device='cpu')params = get_parameters(num_inputs, num_hiddens, num_outputs)outputs, new_states = lstm(inputs, states, params)H, C = new_statesprint('===H.shape', H.shape)print('===C.shape', C.shape)print('===len(outputs), outputs[0].shape:', len(outputs), outputs[0].shape) if __name__ == '__main__':# test_one_hot()# test_rnn()test_lstm()?
3.Seq2seq模型在于,encoder層,由雙層lstm實現隱藏狀態編碼信息,decoder層由雙層lstm將encode層隱藏狀態編碼信息解碼出來,這樣也造成了decoder依賴最終時間步的隱藏狀態,且RNN機制實際中存在長程梯度消失的問題,對于較長的句子,所以隨著所需翻譯句子的長度的增加,這種結構的效果會顯著下降,也就引入后面的attention。與此同時,解碼的目標詞語可能只與原輸入的部分詞語有關,而并不是與所有的輸入有關。 ?例如,當把“Hello world”翻譯成“Bonjour le monde”時,“Hello”映射成“Bonjour”,“world”映射成“monde”。 # 在seq2seq模型中, 解碼器只能隱式地從編碼器的最終狀態中選擇相應的信息。然而,注意力機制可以將這種選擇過程顯式地建模。
Seq2seq代碼案例,batch為4,單詞長度為7,每個單詞對應的embedding向量為8,lstm為兩層
import torch.nn as nn import d2l import torch import math#由于依賴最終時間步的隱藏狀態,RNN機制實際中存在長程梯度消失的問題,對于較長的句子, # 我們很難寄希望于將輸入的序列轉化為定長的向量而保存所有的有效信息, # 所以隨著所需翻譯句子的長度的增加,這種結構的效果會顯著下降。 #與此同時,解碼的目標詞語可能只與原輸入的部分詞語有關,而并不是與所有的輸入有關。 # 例如,當把“Hello world”翻譯成“Bonjour le monde”時,“Hello”映射成“Bonjour”,“world”映射成“monde”。 # 在seq2seq模型中, # 解碼器只能隱式地從編碼器的最終狀態中選擇相應的信息。然而,注意力機制可以將這種選擇過程顯式地建模。#雙層lstm實現隱藏層編碼信息encode class Seq2SeqEncoder(d2l.Encoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqEncoder, self).__init__(**kwargs)self.num_hiddens = num_hiddensself.num_layers = num_layersself.embedding = nn.Embedding(vocab_size, embed_size)#每個字符編碼成一個向量self.rnn = nn.LSTM(embed_size, num_hiddens, num_layers, dropout=dropout, batch_first=False)def begin_state(self, batch_size, device):#(H, C)return [torch.zeros(size=(self.num_layers, batch_size, self.num_hiddens), device=device),torch.zeros(size=(self.num_layers, batch_size, self.num_hiddens), device=device)]def forward(self, X, *args):X = self.embedding(X) # X shape: (batch_size, seq_len, embed_size)print('===encode X.shape', X.shape)X = X.transpose(0, 1) # (seq_len, batch_size, embed_size)print('===encode X.shape', X.shape)state = self.begin_state(X.shape[1], device=X.device)out, state = self.rnn(X,state)print('===encode out.shape:', out.shape)#(seq_len, batch_size, num_hiddens)H, C = stateprint('===encode H.shape:', H.shape)#(num_layers, batch_size, num_hiddens)print('===encode C.shape:', C.shape)#(num_layers, batch_size, num_hiddens)return out, state#雙層lstm將encode層隱藏層信息解碼出來 class Seq2SeqDecoder(d2l.Decoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqDecoder, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.LSTM(embed_size, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, *args):return enc_outputs[1]def forward(self, X, state):X = self.embedding(X).transpose(0, 1)print('==decode X.shape', X.shape)# (seq_len, batch_size, embed_size)out, state = self.rnn(X, state)print('==decode out.shape:', out.shape)# (seq_len, batch_size, num_hiddens)H, C = stateprint('==decode H.shape:', H.shape) # (num_layers, batch_size, num_hiddens)print('==decode C.shape:', C.shape) # (num_layers, batch_size, num_hiddens)# Make the batch to be the first dimension to simplify loss computation.out = self.dense(out).transpose(0, 1)# (batch_size, seq_len, vocab_size)print('==decode final out.shape', out.shape)return out, statedef SequenceMask(X, X_len,value=0):print(X)print(X_len)print(X_len.device)maxlen = X.size(1)print('==torch.arange(maxlen)[None, :]:', torch.arange(maxlen)[None, :])print('==X_len[:, None]:', X_len[:, None])mask = torch.arange(maxlen)[None, :] < X_len[:, None]print(mask)X[~mask] = valueprint('X:', X)return Xdef masked_softmax(X, valid_length):# X: 3-D tensor, valid_length: 1-D or 2-D tensorsoftmax = nn.Softmax(dim=-1)if valid_length is None:return softmax(X)else:shape = X.shapeif valid_length.dim() == 1:try:valid_length = torch.FloatTensor(valid_length.numpy().repeat(shape[1], axis=0)) # [2,2,3,3]except:valid_length = torch.FloatTensor(valid_length.cpu().numpy().repeat(shape[1], axis=0)) # [2,2,3,3]else:valid_length = valid_length.reshape((-1,))# fill masked elements with a large negative, whose exp is 0X = SequenceMask(X.reshape((-1, shape[-1])), valid_length)return softmax(X).reshape(shape) class MLPAttention(nn.Module):def __init__(self, ipt_dim, units, dropout, **kwargs):super(MLPAttention, self).__init__(**kwargs)# Use flatten=True to keep query's and key's 3-D shapes.self.W_k = nn.Linear(ipt_dim, units, bias=False)self.W_q = nn.Linear(ipt_dim, units, bias=False)self.v = nn.Linear(units, 1, bias=False)self.dropout = nn.Dropout(dropout)def forward(self, query, key, value, valid_length):query, key = self.W_k(query), self.W_q(key)print("==query.size, key.size::", query.size(), key.size())# expand query to (batch_size, #querys, 1, units), and key to# (batch_size, 1, #kv_pairs, units). Then plus them with broadcast.print('query.unsqueeze(2).shape', query.unsqueeze(2).shape)print('key.unsqueeze(1).shape', key.unsqueeze(1).shape)features = query.unsqueeze(2) + key.unsqueeze(1)#print("features:",features.size()) #--------------開啟scores = self.v(features).squeeze(-1)print('===scores:', scores.shape)attention_weights = self.dropout(masked_softmax(scores, valid_length))return torch.bmm(attention_weights, value)def test_encoder():encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)X = torch.zeros((4, 7), dtype=torch.long) # (batch_size, seq_len)output, state = encoder(X)def test_decoder():X = torch.zeros((4, 7), dtype=torch.long) # (batch_size, seq_len)encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)state = decoder.init_state(encoder(X))out, state = decoder(X, state)def test_loss():X = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])SequenceMask(X, torch.FloatTensor([2, 3]))def test_dot():keys = torch.ones((2, 10, 2), dtype=torch.float)values = torch.arange((40), dtype=torch.float).view(1, 10, 4).repeat(2, 1, 1)print('==values.shape:', values.shape)# print(values)atten = MLPAttention(ipt_dim=2, units=8, dropout=0)atten(torch.ones((2, 1, 2), dtype=torch.float), keys, values, torch.FloatTensor([2, 6]))if __name__ == '__main__':test_encoder()# test_decoder()encode輸出:?
decode輸出:
二.基于pytorch的crnn網絡結構
地址:https://github.com/zonghaofan/crnn_pytorch
1.網絡圖:
首先卷積提取特征以后再用兩層雙向lstm提取時序特征
2.代碼實現
import torch.nn as nn import torch.nn.functional as F import torchclass BiLSTM(nn.Module):def __init__(self,nIn,nHidden,nOut):super(BiLSTM,self).__init__()self.lstm=nn.LSTM(input_size=nIn,hidden_size=nHidden,bidirectional=True)self.embdding=nn.Linear(nHidden*2,nOut)#Sequence batch channels (W,b,c)def forward(self, input):recurrent,_=self.lstm(input)S,b,h=recurrent.size()S_line = recurrent.view(S*b,h)output=self.embdding(S_line)#[S*b,nout]output=output.view(S,b,-1)return outputclass CRNN(nn.Module):def __init__(self,imgH,imgC,nclass,nhidden):assert imgH==32super(CRNN,self).__init__()cnn = nn.Sequential()cnn.add_module('conv{}'.format(0), nn.Conv2d(imgC, 64, 3, 1, 1))cnn.add_module('relu{}'.format(0), nn.ReLU(True))cnn.add_module('pooling{}'.format(0),nn.MaxPool2d(2,2))cnn.add_module('conv{}'.format(1), nn.Conv2d(64, 128, 3, 1, 1))cnn.add_module('relu{}'.format(1), nn.ReLU(True))cnn.add_module('pooling{}'.format(1), nn.MaxPool2d(2, 2))cnn.add_module('conv{}'.format(2), nn.Conv2d(128, 256, 3, 1, 1))cnn.add_module('relu{}'.format(2), nn.ReLU(True))cnn.add_module('conv{}'.format(3), nn.Conv2d(256, 256, 3, 1, 1))cnn.add_module('relu{}'.format(3), nn.ReLU(True))cnn.add_module('pooling{}'.format(3), nn.MaxPool2d((1,2), 2))cnn.add_module('conv{}'.format(4), nn.Conv2d(256, 512, 3, 1, 1))cnn.add_module('relu{}'.format(4), nn.ReLU(True))cnn.add_module('BN{}'.format(4), nn.BatchNorm2d(512))cnn.add_module('conv{}'.format(5), nn.Conv2d(512, 512, 3, 1, 1))cnn.add_module('relu{}'.format(5), nn.ReLU(True))cnn.add_module('BN{}'.format(5), nn.BatchNorm2d(512))cnn.add_module('pooling{}'.format(5), nn.MaxPool2d((1, 2), 2))cnn.add_module('conv{}'.format(6), nn.Conv2d(512, 512, 2, 1, 0))cnn.add_module('relu{}'.format(6), nn.ReLU(True))self.cnn=cnnself.rnn=nn.Sequential(BiLSTM(512,nhidden,nhidden),BiLSTM(nhidden, nhidden, nclass))def forward(self,input):conv = self.cnn(input)print('conv.size():',conv.size())b,c,h,w=conv.size()assert h==1conv=conv.squeeze(2)#b ,c wconv=conv.permute(2,0,1) #w,b,crnn_out=self.rnn(conv)print('rnn_out.size():',rnn_out.size())out=F.log_softmax(rnn_out,dim=2)print('out.size():',out.size())return out def lstm_test():print('===================LSTM===========================')model = BiLSTM(512, 256, 5600)print(model)x = torch.rand((41, 32, 512))print('input:', x.size())out = model(x)print(out.size()) def crnn_test():print('===================CRNN===========================')model = CRNN(32, 1, 3600, 256)print(model)x = torch.rand((32, 1, 32, 200)) # b c h wprint('input:', x.size())out = model(x)print(out.size()) if __name__ == '__main__':lstm_test()crnn_test()lstm輸出:
#crnn輸出
3.提特征輸入ctc過程
上面代碼可看成,輸入為(32,1,32,200)cnn提取特征過后,每張圖片11個特征向量,每個特征向量長度為512,在LSTM中一個時間步就傳入一個特征向量進行分類。一個特征向量就相當于原圖中的一個小矩形區域,RNN的目標就是預測這個矩形區域為哪個字符,即根據輸入的特征向量,進行預測,得到所有字符的softmax概率分布,這是一個長度為字符類別數的向量,作為CTC層的輸入。如下圖所示就是輸入ctc的示例圖
?
4.ctc loss
首先思考ctc解決什么問題,一般分類就是一張圖片對應一類,那樣拉一個全連接進行softmax即可分類,對于這種一張圖片有好幾個字符,上述就解決不了,故有一種思路是將輸入圖片的字符切割出來在進行分類,那這樣的問題是分割不準怎么辦?所以面臨這種輸入類別不定長的時候,就可以利用ctc進行解決,ctc的思想就是將輸入圖片提取特征變成時序步長,給出輸入時序步長X的所有可能結果Y的輸出分布。那么根據這個分布,我們可以輸出最可能的結果。大膽猜測對于有一張個頭差不多的貓貓與狗狗水平挨著的圖片,ctc可能也能解決分類問題.
ctc的損失函數可以對CNN和RNN進行端到端的聯合訓練。
RNN這里有去冗余操作,例如,上圖中RNN中有5個時間步,但最終輸出兩個字符,理想情況下 t0, t1, t2時刻都應映射為“a”,t3, t4 時刻都應映射為“b”,然后連接起來得到“aaabb”,那么合并結果為“ab”。但是在識別book這類字符會有問題。最后以“-”符號代表blank,RNN 輸出序列時,在文本標簽中的重復的字符之間插入一個“-”,比如輸出序列為“bbooo-ookk”,則最后將被映射為“book”,即有blank字符隔開的話,連續相同字符就不進行合并。即對字符序列先刪除連續重復字符,然后從路徑中刪除所有“-”字符,這個稱為解碼過程,而編碼則是由神經網絡來實現。引入blank機制,我們就可以很好地解決重復字符的問題。
4.1訓練過程
其中t0,t1代表兩個時間步長,黑色線代表a字符的路徑,虛線代表空文本路徑。
例如:對于時序步長為2的識別,有兩個時間步長(t0,t1)和三個可能的字符為“a”,“b”和“-”,我們得到兩個概率分布向量,如果采取最大概率路徑解碼的方法,則“--”的概率最大,即真實字符為空的概率為0.6*0.6=0.36。
但是為字符“a”的情況有多種組合,“aa”, “a-“和“-a”都是代表“a”,所以,輸出“a”的概率應該為三種之和:
0.4*0.4+0.4*0.6+0.4*0.6=0.64 ,故a的概率最高。如果標簽文本為“a”,則通過計算圖像中為“a”的所有可能的對齊組合(或者路徑)的分數之和來計算損失函數。
對于RNN給定輸入概率分布矩陣為y={y1,y2,...,yT},T是序列長度,最后映射為標簽文本l的總概率為:
其中B(π)代表從序列到序列的映射函數B變換后是文本l的所有路徑集合,而π則是其中的一條路徑。每條路徑的概率為各個時間步中對應字符的分數的乘積。然后訓練網絡使得這個概率值最大化,類似于普通的分類,CTC的損失函數定義為概率的負最大似然函數,為了計算方便,對似然函數取對數。
然后通過對損失函數的計算,就可以對之前的神經網絡進行反向傳播,神經網絡的參數根據所使用的優化器進行更新,從而找到最可能的像素區域對應的字符。這種通過映射變換和所有可能路徑概率之和的方式使得CTC不需要對原始的輸入字符序列進行準確的切分。
4.2推理過程
推理階段,過程與訓練階段有所不同,我們用訓練好的神經網絡來識別新的文本圖像。如果我們像上面一樣將每種可能文本的所有路徑計算出來,對于很長的時間步和很長的字符序列來說,計算量是非常龐大。
由于RNN在每一個時間步的輸出為所有字符類別的概率分布,所以,我們取其中最大概率的字符作為該時間步的輸出字符,然后將所有時間步得到一個字符進行拼接得到一個序列路徑,即最大概率路徑,再根據上面介紹的合并序列方法得到最終的預測文本結果。
如上圖5個時間步長,輸出結果為a->a->a->blank>b,合并去重結果就為ab。要注意的是字符之間有間距,需要添加blank。
4.3ctc loss代碼示例
import torch from torch import nnT = 50 #時序步長 C = 20 #類別數 排除blank N = 2 # Batch size S = 30 #一個batch中的label的最大時序步長 S_min = 10 #一個batch中label的最小字符個數# Initialize random batch of input vectors, for *size = (T,N,C) #rnn輸出結果 input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() print('==input.shape:', input.shape) #字符對應的label idx target = torch.randint(low=1, high=C+1, size=(N, S), dtype=torch.long) print('==target:', target) #序列長度的值 input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) print('==input_lengths.shape:', input_lengths.shape) print('=input_lengths:', input_lengths) #字符的長度 target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long) print('==target_lengths.shape:', target_lengths.shape) print('==target_lengths:', target_lengths)ctc_loss = nn.CTCLoss() loss = ctc_loss(input, target, input_lengths, target_lengths)5.一些可能改動的點
最后兩層pooling設置為h=1,w=2的矩形,是因為文本大多數是高小而寬長,這樣就可以不丟失寬度信息,利于區分i和L.?
如果數字過小,那就可以讓橫向長度不變,pool可以換成如下,這樣橫向長度基本不變,縱向減少兩倍。
pool2 = nn.MaxPool2d((2, 2), (2, 1), (0, 1)) x=torch.rand((32,1,32,100)) print('=========input========') print(x.shape) print('=========output========') pool = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2)) y = pool(x) print(y.shape) # (h-2)/2+1 (w-1)/1+1 pool = nn.MaxPool2d(kernel_size=(2,1),stride=(2,1)) y = pool(x) print(y.shape)# (h-2)+2*p/2+1 (w-2)+2*p/1+1 pool = nn.MaxPool2d(kernel_size=(2,2),stride=(2,1),padding=(1,0)) y = pool(x) print(y.shape)6.finetune新加字符
由于原先數據集不一定找得到,對于新加的字符,對除了最后一層的全連接進行凍結,例如原先最后一層是(512,5000),現在新加10個字符,變為(512,5010),則將原先的那一層權重矩陣平移過來.只需要訓練(512,10)的矩陣.
參考:
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
https://www.cnblogs.com/zhangchaoyang/articles/6684906.html
https://aijishu.com/a/1060000000135614
https://www.cnblogs.com/ydcode/p/11038064.html
總結
以上是生活随笔為你收集整理的RNN知识+LSTM知识+encoder-decoder+ctc+基于pytorch的crnn网络结构的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: JavaSE——异常处理(异常简介、tr
- 下一篇: 应用程序利用ADO对象访问数据库