世界人工智能大赛 Top1 方案!手写体 OCR 识别
?Datawhale干貨?
作者:王浩,結行科技算法工程師
參加了“世界人工智能創新大賽”——手寫體 OCR 識別競賽(任務一),取得了Top1的成績。隊伍隨機組的,有人找我我就加了進來,這是我第一次做OCR相關的項目,所以隨意起了個名字。下面通過這篇文章來詳細介紹我們的方案。
實踐背景
賽題背景
銀行日常業務中涉及到各類憑證的識別錄入,例如身份證錄入、支票錄入、對賬單錄入等。以往的錄入方式主要是以人工錄入為主,效率較低,人力成本較高。近幾年來,OCR相關技術以其自動執行、人為干預較少等特點正逐步替代傳統的人工錄入方式。但OCR技術在實際應用中也存在一些問題,在各類憑證字段的識別中,手寫體由于其字體差異性大、字數不固定、語義關聯性較低、憑證背景干擾等原因,導致OCR識別率準確率不高,需要大量人工校正,對日常的銀行錄入業務造成了一定的影響。
賽題地址:http://ailab.aiwin.org.cn/competitions/65
賽題任務
本次賽題將提供手寫體圖像切片數據集,數據集從真實業務場景中,經過切片脫敏得到,參賽隊伍通過識別技術,獲得對應的識別結果。即:
輸入:手寫體圖像切片數據集
輸出:對應的識別結果
本任務提供開放可下載的訓練集及測試集,允許線下建模或線上提供 Notebook 環境及 Terminal 容器環境(脫網)建模,輸出識別結果完成賽題。
賽題數據
A. 數據規模和內容覆蓋
B.數據示例
原始手寫體圖像共分為三類,分別涉及銀行名稱、年月日、金額三大類,分別示意如下:
相應圖片切片中可能混雜有一定量的干擾信息,分別示例如下:
識別結果 JSON 在訓練集中的格式如下:
json 文件內容規范: {"image1":?"陸萬捌千零貳拾伍元整","image2":?"付經管院工資","image3":?"",... }實踐方案
通過在網上查閱資料,得知OCR比賽最常用的模型是CRNN+CTC。所以我最開始也是采用這個方案。
上圖是我找到的資料,有好多個版本。因為是第一次做OCR的項目,所以我優先選擇有數據集的項目,這樣可以快速的了解模型的輸入輸出。
所以我選擇的第一個Attention_ocr.pytorch-master.zip,從名字上可以看出這個是加入注意力機制,感覺效果會好一些。
構建數據集
下圖是Attention_ocr.pytorch-master.zip自帶的數據集截圖,從截圖上可以看出,數據的格式:“圖片路徑+空格+標簽”。我們也需要按照這樣的格式構建數據集。
新建makedata.py文件,插入下面的代碼。
import?os import?json #官方給的數據集 image_path_amount?=?"./data/train/amount/images"? image_path_date?=?"./data/train/date/images" #增強數據集 image_path_test='./data/gan_test_15000/images/0' image_path_train='./data/gan_train_15500_0/images/0' amount_list?=?os.listdir(image_path_amount) amount_list?=?os.listdir(image_path_amount)new_amount_list?=?[] for?filename?in?amount_list:new_amount_list.append(image_path_amount?+?"/"?+?filename)date_list?=?os.listdir(image_path_date) new_date_list?=?[] for?filename?in?date_list:new_date_list.append(image_path_date?+?"/"?+?filename) new_test_list?=?[] for?filename?in?amount_list:new_test_list.append(image_path_amount?+?"/"?+?filename)new_train_list?=?[] for?filename?in?amount_list:new_train_list.append(image_path_amount?+?"/"?+?filename)image_path_amount和image_path_date是官方給定的數據集路徑。
image_path_test和image_path_train是增強的數據集(在后面會講如何做增強)
創建建立list,保存圖片的路徑。
amount_json?=?"./data/train/amount/gt.json" date_json?=?"./data/train/date/gt.json" train_json?=?"train_data.json" test_json?=?"test_data.json" with?open(amount_json,?"r",?encoding='utf-8')?as?f:load_dict_amount?=?json.load(f) with?open(date_json,?"r",?encoding='utf-8')?as?f:load_dict_date?=?json.load(f) with?open(train_json,?"r",?encoding='utf-8')?as?f:load_dict_train?=?json.load(f) with?open(test_json,?"r",?encoding='utf-8')?as?f:load_dict_test?=?json.load(f)四個json文件對應上面的四個list,json文件存儲的是圖片的名字和圖片的標簽,把json解析出來存到字典中。
#聚合list all_list?=?new_amount_list?+?new_date_list+new_test_list+new_train_list from?sklearn.model_selection?import?train_test_split #切分訓練集合和驗證集 train_list,?test_list?=?train_test_split(all_list,?test_size=0.15,?random_state=42) #聚合字典 all_dic?=?{} all_dic.update(load_dict_amount) all_dic.update(load_dict_date) all_dic.update(load_dict_train) all_dic.update(load_dict_test) with?open('train.txt',?'w')?as?f:for?line?in?train_list:f.write(line?+?"?"?+?all_dic[line.split('/')[-1]]+"\n") with?open('val.txt',?'w')?as?f:for?line?in?test_list:f.write(line?+?"?"?+?all_dic[line.split('/')[-1]]+"\n")將四個list聚合為一個list。
使用train_test_split切分訓練集和驗證集。
聚合字典。
然后分別遍歷trainlist和testlist,將其寫入train.txt和val.txt。
到這里數據集就制作完成了。得到train.txt和val.txt
查看train.txt
數據集和自帶的數據集格式一樣了,然后我們就可以開始訓練了。
獲取class
新建getclass.py文件夾,加入以下代碼:
import?jsonamount_json?=?"./data/train/amount/gt.json" date_json?=?"./data/train/date/gt.json" with?open(amount_json,?"r",?encoding='utf-8')?as?f:load_dict_amount?=?json.load(f) with?open(date_json,?"r",?encoding='utf-8')?as?f:load_dict_date?=?json.load(f) all_dic?=?{} all_dic.update(load_dict_amount) all_dic.update(load_dict_date) list_key=[] for?keyline?in?all_dic.values():for?key?in?keyline:if?key?not?in?list_key:list_key.append(key) with?open('data/char_std_5990.txt',?'w')?as?f:for?line?in?list_key:f.write(line+"\n")執行完就可以得到存儲class的txt文件。打開char_std_5990.txt,看到有21個類。
模型改進
crnn的卷積部分類似VGG,我對模型的改進主要有一下幾個方面:
1、加入激活函數Swish。
2、加入BatchNorm。
3、加入SE注意力機制。
4、適當加深模型。
代碼如下:
self.cnn?=?nn.Sequential(nn.Conv2d(nc,?64,?3,?1,?1),?Swish(),?nn.BatchNorm2d(64),nn.MaxPool2d(2,?2),??#?64x16x50nn.Conv2d(64,?128,?3,?1,?1),?Swish(),?nn.BatchNorm2d(128),nn.MaxPool2d(2,?2),??#?128x8x25nn.Conv2d(128,?256,?3,?1,?1),?nn.BatchNorm2d(256),?Swish(),??#?256x8x25nn.Conv2d(256,?256,?3,?1,?1),?nn.BatchNorm2d(256),?Swish(),??#?256x8x25SELayer(256,?16),nn.MaxPool2d((2,?2),?(2,?1),?(0,?1)),??#?256x4x25nn.Conv2d(256,?512,?3,?1,?1),?nn.BatchNorm2d(512),?Swish(),??#?512x4x25nn.Conv2d(512,?512,?1),?nn.BatchNorm2d(512),?Swish(),nn.Conv2d(512,?512,?3,?1,?1),?nn.BatchNorm2d(512),?Swish(),??#?512x4x25SELayer(512,?16),nn.MaxPool2d((2,?2),?(2,?1),?(0,?1)),??#?512x2x25nn.Conv2d(512,?512,?2,?1,?0),?nn.BatchNorm2d(512),?Swish())??#?512x1x25SE和Swish
class?SELayer(nn.Module):def?__init__(self,?channel,?reduction=16):super(SELayer,?self).__init__()self.avg_pool?=?nn.AdaptiveAvgPool2d(1)self.fc?=?nn.Sequential(nn.Linear(channel,?channel?//?reduction,?bias=True),nn.LeakyReLU(inplace=True),nn.Linear(channel?//?reduction,?channel,?bias=True),nn.Sigmoid())def?forward(self,?x):b,?c,?_,?_?=?x.size()y?=?self.avg_pool(x).view(b,?c)y?=?self.fc(y).view(b,?c,?1,?1)return?x?*?y.expand_as(x)class?Swish(nn.Module):def?forward(self,?x):return?x?*?torch.sigmoid(x)模型訓練
打開train.py ,在訓練之前,我們還要調節一下參數。
parser?=?argparse.ArgumentParser() parser.add_argument('--trainlist',??default='train.txt') parser.add_argument('--vallist',??default='val.txt') parser.add_argument('--workers',?type=int,?help='number?of?data?loading?workers',?default=0) parser.add_argument('--batchSize',?type=int,?default=4,?help='input?batch?size') parser.add_argument('--imgH',?type=int,?default=32,?help='the?height?of?the?input?image?to?network') parser.add_argument('--imgW',?type=int,?default=512,?help='the?width?of?the?input?image?to?network') parser.add_argument('--nh',?type=int,?default=512,?help='size?of?the?lstm?hidden?state') parser.add_argument('--niter',?type=int,?default=300,?help='number?of?epochs?to?train?for') parser.add_argument('--lr',?type=float,?default=0.00005,?help='learning?rate?for?Critic,?default=0.00005') parser.add_argument('--beta1',?type=float,?default=0.5,?help='beta1?for?adam.?default=0.5') parser.add_argument('--cuda',?action='store_true',?help='enables?cuda',?default=True) parser.add_argument('--ngpu',?type=int,?default=1,?help='number?of?GPUs?to?use') parser.add_argument('--encoder',?type=str,?default='',?help="path?to?encoder?(to?continue?training)") parser.add_argument('--decoder',?type=str,?default='',?help='path?to?decoder?(to?continue?training)') parser.add_argument('--experiment',?default='./expr/attentioncnn',?help='Where?to?store?samples?and?models') parser.add_argument('--displayInterval',?type=int,?default=100,?help='Interval?to?be?displayed') parser.add_argument('--valInterval',?type=int,?default=1,?help='Interval?to?be?displayed') parser.add_argument('--saveInterval',?type=int,?default=1,?help='Interval?to?be?displayed') parser.add_argument('--adam',?default=True,?action='store_true',?help='Whether?to?use?adam?(default?is?rmsprop)') parser.add_argument('--adadelta',?action='store_true',?help='Whether?to?use?adadelta?(default?is?rmsprop)') parser.add_argument('--keep_ratio',default=True,?action='store_true',?help='whether?to?keep?ratio?for?image?resize') parser.add_argument('--random_sample',?default=True,?action='store_true',?help='whether?to?sample?the?dataset?with?random?sampler') parser.add_argument('--teaching_forcing_prob',?type=float,?default=0.5,?help='where?to?use?teach?forcing') parser.add_argument('--max_width',?type=int,?default=129,?help='the?width?of?the?featuremap?out?from?cnn') parser.add_argument("--output_file",?default='deep_model.log',?type=str,?required=False) opt?=?parser.parse_args()trainlist:訓練集,默認是train.txt。
vallist:驗證集路徑,默認是val.txt。
batchSize:批大小,根據顯存大小設置。
imgH:圖片的高度,crnn模型默認為32,這里不需要修改。
imgW:圖片寬度,我在這里設置為512。
keep_ratio:設置為True,設置為True后,程序會保持圖片的比率,然后在一個batch內統一尺寸,這樣訓練的模型精度更高。
lr:學習率,設置為0.00005,這里要注意,不要太大,否則不收斂。
其他的參數就不一一介紹了,大家可以自行嘗試。
運行結果:
運行結果訓練完成后,可以在expr文件夾下面找到模型。
訓練的模型結果預測
在推理之前,我們還需要確認最長的字符串,新建getmax.py,添加如下代碼:
import?os import?jsonimage_path_amount?=?"./data/train/amount/images" image_path_date?=?"./data/train/date/images" amount_list?=?os.listdir(image_path_amount) new_amount_list?=?[] for?filename?in?amount_list:new_amount_list.append(image_path_amount?+?"/"?+?filename) date_list?=?os.listdir(image_path_date) new_date_list?=?[] for?filename?in?date_list:new_date_list.append(image_path_date?+?"/"?+?filename) amount_json?=?"./data/train/amount/gt.json" date_json?=?"./data/train/date/gt.json" with?open(amount_json,?"r",?encoding='utf-8')?as?f:load_dict_amount?=?json.load(f) with?open(date_json,?"r",?encoding='utf-8')?as?f:load_dict_date?=?json.load(f) all_list?=?new_amount_list?+?new_date_list from?sklearn.model_selection?import?train_test_splitall_dic?=?{} all_dic.update(load_dict_amount) all_dic.update(load_dict_date)maxLen?=?0 for?i?in?all_dic.values():if?(len(i)?>?maxLen):maxLen?=?len(i) print(maxLen)運行結果:28
將test.py中的max_length設置為28。
修改模型的路徑,包括encoder_path和decoder_path。
encoder_path?=?'./expr/attentioncnn/encoder_22.pth'decoder_path?=?'./expr/attentioncnn/decoder_22.pth'修改測試集的路徑:
for?path?in?tqdm(glob.glob('./data/測試集/date/images/*.jpg')):text,?prob?=?test(path)if?prob<0.8:count+=1result_dict[os.path.basename(path)]?=?{'result':?text,'confidence':?prob}for?path?in?tqdm(glob.glob('./data/測試集/amount/images/*.jpg')):text,?prob?=?test(path)if?prob<0.8:count+=1result_dict[os.path.basename(path)]?=?{'result':?text,'confidence':?prob}寫到最后
作者第一次參加OCR相關的賽事,在任務一中取得Top1的好成績,背后的付出和努力通過方案分享也能看到。近期接觸到很多在比賽中拿到不錯成績的小伙伴,不少是第一次嘗試。所以,努力后還是可以得到自己滿意的結果的。
整理不易,點贊三連↓
總結
以上是生活随笔為你收集整理的世界人工智能大赛 Top1 方案!手写体 OCR 识别的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 用于主题检测的临时日志(c48534c5
- 下一篇: 速达5000出现计算成本数据溢出的问题