Python3实现机器学习经典算法(二)KNN实现简单OCR
一、前言
1、ocr概述
OCR (Optical Character Recognition,光學字符識別)是指電子設備(例如掃描儀或數碼相機)檢查紙上打印的字符,通過檢測暗、亮的模式確定其形狀,然后用字符識別方法將形狀翻譯成計算機文字的過程;即,針對印刷體字符,采用光學的方式將紙質文檔中的文字轉換成為黑白點陣的圖像文件,并通過識別軟件將圖像中的文字轉換成文本格式,供文字處理軟件進一步編輯加工的技術(摘自百度百科:光學字符識別)。
KNN在OCR的識別過程中能發揮作用的地方在于將圖像中的文字轉換為文本格式,而OCR的其他部分,比如圖像預處理、二值化等操作將其丟給OpenCV去操作。
2、訓練集簡介
由于我們采用的是KNN來轉換圖像中的文字為文本格式,需要一個龐大的手寫字符訓練集來支撐我們的算法。這里我使用的是《機器學習實戰》2.3實例:手寫識別系統中使用的數據集,其下載地址為:https://www.manning.com/books/machine-learning-in-action,在Source CodeCh02digits rainingDigits中的兩千多個手寫字符既是我所使用的訓練集。
這個訓練集配合上它所提供的測試集,提供了一個準確度非常高的分類器:
訓練集是由0~9十個數字組成的,每個數字有兩百個左右的訓練樣本。所有的訓練樣本統一被處理為一個32*32的0/1矩陣,其中所有值為1的連通區域構成了形象上的數字,如下所示:
所以,在構造我們的測試集的時候,所有的手寫數字圖片必須被處理為這樣的格式才能夠使得分類算法正確地進行,這也是KNN的局限所在。
二、算法實現
1、構建測試集
上面已經提到,要想算法正確地進行,測試集的樣式應該和訓練集相同,也就是說我們要把一張包含有手寫數字的圖像,轉換為一個32*32的0/1點陣。
測試集使用我自己手寫的10個數字:
這里存在一個非常大的問題:這個數據集的作者是土耳其人,他們書寫數字的習慣和我們有諸多不同,比如上面的數字4和數字8,下面這樣子的數字就無法識別:4/8。哈哈,也就是說它連印刷體都無法識別,這是這個訓練集的一大缺陷之一。
1)圖像預處理
圖像預處理的過程是一個數字圖像處理(DIP)的過程,觀察上面的10個數字,可以發現每張圖像的大小/對比度的差距都非常大,所以圖像預處理應該消除這些差距。
第一步是進行圖像的放大/縮小。由于我們很難產生一個小于32*32像素的手寫數字圖像,所以這里主要是縮小圖像:
1 import cv2 2 def readImage(imagePath): 3 image = cv2.imread(imagePath,cv2.IMREAD_GRAYSCALE) 4 image = cv2.resize(image,(32,32),interpolation = cv2.INTER_AREA) 5 return image
這里我沒有去實現圖像重采樣的方法(實現在后面的博客會寫),而是采用的OpenCV,通過area來確定取樣點的灰度值(推薦用bicubic interpolation,對應的插入函數應該是INTER_CUBIC),在讀入圖像的時候讀入方式位IMRAD_GRAYSCALE,因為我們需要的是識別手寫字符,灰度圖對比彩色圖能更好的突出重點。
進行圖像的縮放是不夠的,因為觀察上面的圖片可以發現:拍攝環境對于對比度的影響非常大,所以我們應該突出深色區域(數字部分),來保證后面的工作順利進行,這里采用的是伽馬變換(也可以采用對數變換):
1 def imageGamma(image): 2 for i in range(32): 3 for j in range(32): 4 image[i][j]=3*pow(image[i][j],0.8) 5 return image
2)圖像二值化
縮小/放大后的圖像已經是一個32*32的圖像了,下一步則是將非數字區域填充0,數字區域填充1,這里我采用的是閾值二值化處理:
def imageThreshold(image):
ret,image = cv2.threshold(image,150,255,cv2.THRESH_BINARY)
return image
經過二值化處理,數字部分的灰度值應該為0,而非數字部分的連通區域的灰度值應該為255,如下所示:
3)去噪
圖像去噪的方式有很多種,這里建立使用自適應中值濾波器進行降噪,因為我們的圖像在傳輸過程中可能出現若干的椒鹽噪聲,這個噪聲在上述的二值化處理中有時候是非常棘手的。
到目前為止,一副手機攝像的手寫數字圖像就可以轉換為一個32*32的二值圖像。
4)生成訓練樣本
如何將這個32*32的二值圖像轉換為0/1圖像,這個處理非常簡單:
1 def imageProcess(image):
2 with open(r'F:UsersyangPycharmProjectsOCR_KNN estDigits6_0.txt','w+') as file:
3 for i in range(32):
4 for j in range(32):
5 if image[i][j] == 255:
6 file.write('0')
7 else:
8 file.write('1')
9 file.writelines('
')
這里我的代碼在掃描這個圖像的同時,將其保存為一個訓練樣本,命名和訓練集的明明要求一樣為N_M.txt,其中N代表這個訓練樣本的實際分類是什么數字,M代表這是這個數字的第幾個樣本。這里對圖像進行灰度變換已經是多此一舉了,我所需要的是0/1矩陣而非一個0/1圖像,所以在掃描過程中一并生成訓練樣本更加省時直觀。
5)形成訓練集
上面的示例只是生成一個圖像的訓練樣本的,而實際上我們往往需要一次性生成一個訓練集,這就要求這個圖像預處理、二值化并且生成0/1矩陣的過程是自動的:
1 from os import listdir
2 def imProcess(imagePath):
3 testDigits = listdir(imagePath)
4 for i in range(len(testDigits)):
5 imageName = testDigits[i]#圖像命名格式為N_M.png,NM含義見4)生成訓練樣本
6 #imageClass = int((imageName.split('.')[0]).split('_')[0])#這個圖像的數字是多少
7 image = cv2.imread(imageName,cv2.IMREAD_GRAYSCALE)
8 image = cv2.resize(image, (32, 32), interpolation=cv2.INTER_AREA)
9 ret, image = cv2.threshold(image, 150, 255, cv2.THRESH_BINARY)
10 with open(r'F:UsersyangPycharmProjectsOCR_KNN estDigits\'+imageName.split('.')[0]+'.txt','w+') as file:
11 for i in range(32):
12 for j in range(32):
13 if image[i][j] == 255:
14 file.write('0')
15 else:
16 file.write('1')
17 file.writelines('
')
這個函數將imagePath文件夾中所有的N_M命名的手寫數字圖像讀取并經過預處理、二值化、最后保存為對應的0/1矩陣,命名為N_M.txt,這就構成一個訓練集了。
2、構建分類器
分類器使用上一節的分類器(classify):
1 def classify(vector,dataSet,labels,k):
2 distance = sqrt(abs(((tile(vec,(dataSet.shape[0],1)) - dataSet) ** 2).sum(axis = 1))); #計算距離
3 sortedDistance = distance.argsort()
4 dict={}
5 for i in range(k):
6 label = labels[sortedDistance[i]]
7 if not label in dict:
8 dict[label] = 1
9 else:
10 dict[label]+=1
11 sortedDict = sorted(dict,key = operator.itemgetter(1),reverse = True)
12 return sortedDict[0][0]
13
14 def dict2list(dic:dict):#將字典轉換為list類型
15 keys=dic.keys()
16 values=dic.values()
17 lst=[(key, value)for key,value in zip(keys,values)]
18 return lst
distance的計算和dict2list函數的詳解在上一節,戳上面的classify既可以跳轉過去。
分類器已經構建完成,下一步是提取每一個測試樣本,提取訓練集,提取label的過程:(這個過程大部分用的是《機器學習實戰》中的代碼,對于難以理解的代碼在下文中做了解釋:)
1)讀取0/1矩陣文件:
1 def img2vector(filename): 2 returnvec = numpy.zeros((1,1024)) 3 file = open(filename) 4 for i in range(32): 5 line = file.readline() 6 for j in range(32): 7 returnvec[0,32*i+j] = int(line[j]) 8 return returnvec
這里要注意:構造一個32*32的全零矩陣的時候,應該是numpy.zeros((1,1024)),雙層括號!雙層括號!雙層括號!代表構造的是一個二維矩陣!
2)讀取訓練集和測試集并求解準確率:
1 def handWritingClassifyTest():
2 labels=[]
3 trainingFile = listdir(r'F:UsersyangPycharmProjectsOCR_KNN rainingDigits')
4 m = len(trainingFile)
5 trainingMat = numpy.zeros((m,1024))
6 for i in range(m):
7 file = trainingFile[i]
8 filestr = file.strip('.')[0]
9 classnum = int(filestr.strip('_')[0])
10 labels.append(classnum)
11 trainingMat[i,:] = img2vector('trainingDigits/%s' % file)
12 testFileList = listdir(r'F:UsersyangPycharmProjectsOCR_KNN estDigits')
13 error = 0.0
14 testnum = len(testFileList)
15 for i in range(testnum):
16 file_test = testFileList[i]
17 filestr_test = file_test.strip('.')[0]
18 classnum_test = int(filestr_test.strip('_')[0])
19 vector_test = img2vector('testDigits/%s'%file_test)
20 result = classify(vector_test,trainingMat,labels,1)
21 if(result!=classnum_test):error+=1.0
22 print("準確率:%f"%(1.0-(error/float(testnum))))
代碼其實沒有很難懂的地方,主要任務就是讀取文件,通過img2vctor函數轉換為矩陣,還有切割文件名獲取該測試樣本的類別和該訓練樣本的類別,通過對比獲得準確率。
3、使用分類器
現在為止,我們的分類器已經構建完成,下面就是測試和使用階段:
1)測試《機器學習實戰》中給出的訓練集:
2)測試手寫訓練集:
emmm果然學不出來大佬寫字,附上幾張無法識別的0/1數字矩陣:(0,4,6無法識別的原因是比劃太細哈哈,8無法識別的原因……太端正了吧)
4、完整代碼:
1 from os import listdir
2 import numpy
3 import operator
4 import cv2
5
6 def imProcess(imagePath):
7 testDigits = listdir(imagePath)
8 for i in range(len(testDigits)):
9 imageName = testDigits[i]#圖像命名格式為N_M.png,NM含義見4)生成訓練樣本
10 #imageClass = int((imageName.split('.')[0]).split('_')[0])#這個圖像的數字是多少
11 image = cv2.imread(imageName,cv2.IMREAD_GRAYSCALE)
12 image = cv2.resize(image, (32, 32), interpolation=cv2.INTER_AREA)
13 ret, image = cv2.threshold(image, 150, 255, cv2.THRESH_BINARY)
14 with open(r'F:UsersyangPycharmProjectsOCR_KNN estDigits\'+imageName.split('.')[0]+'.txt','w+') as file:
15 for i in range(32):
16 for j in range(32):
17 if image[i][j] == 255:
18 file.write('0')
19 else:
20 file.write('1')
21 file.writelines('
')
22
23 def img2vector(filename):
24 returnvec = numpy.zeros((1,1024))
25 file = open(filename)
26 for i in range(32):
27 line = file.readline()
28 for j in range(32):
29 returnvec[0,32*i+j] = int(line[j])
30 return returnvec
31
32 def handWritingClassifyTest():
33 labels=[]
34 trainingFile = listdir(r'F:UsersyangPycharmProjectsOCR_KNN rainingDigits')
35 m = len(trainingFile)
36 trainingMat = numpy.zeros((m,1024))
37 for i in range(m):
38 file = trainingFile[i]
39 filestr = file.strip('.')[0]
40 classnum = int(filestr.strip('_')[0])
41 labels.append(classnum)
42 trainingMat[i,:] = img2vector('trainingDigits/%s' % file)
43 testFileList = listdir(r'F:UsersyangPycharmProjectsOCR_KNN estDigits')
44 error = 0.0
45 testnum = len(testFileList)
46 for i in range(testnum):
47 file_test = testFileList[i]
48 filestr_test = file_test.strip('.')[0]
49 classnum_test = int(filestr_test.strip('_')[0])
50 vector_test = img2vector('testDigits/%s'%file_test)
51 result = classify(vector_test,trainingMat,labels,1)
52 if(result!=classnum_test):error+=1.0
53 print("準確率:%f"%(1.0-(error/float(testnum))))
54
55 def classify(inX,dataSet,labels,k):
56 size = dataSet.shape[0]
57 distance = (((numpy.tile(inX,(size,1))-dataSet)**2).sum(axis=1))**0.5
58 sortedDistance = distance.argsort()
59 count = {}
60 for i in range(k):
61 label = labels[sortedDistance[i]]
62 count[label]=count.get(label,0)+1
63 sortedcount = sorted(dict2list(count),key=operator.itemgetter(1),reverse=True)
64 return sortedcount[0][0]
65
66 def dict2list(dic:dict):#將字典轉換為list類型
67 keys=dic.keys()
68 values=dic.values()
69 lst=[(key, value)for key,value in zip(keys,values)]
70 return lst
71
72 # def imProcess(image):
73 # image = cv2.resize(image, (32, 32), interpolation=cv2.INTER_AREA)
74 # ret, image = cv2.threshold(image, 150, 255, cv2.THRESH_BINARY)
75 # cv2.imshow('result',image)
76 # cv2.waitKey(0)
77 # with open(r'F:UsersyangPycharmProjectsOCR_KNN estDigits6_0.txt','w+') as file:
78 # for i in range(32):
79 # for j in range(32):
80 # if image[i][j] == 255:
81 # file.write('0')
82 # else:
83 # file.write('1')
84 # file.writelines('
')
85
86
87
88 # iamge = cv2.imread(r'C:UsersyangDesktop6.png',cv2.IMREAD_GRAYSCALE)
89 # image = imProcess(iamge)
90 imProcess(r'F:UsersyangPycharmProjectsOCR_KNN estDigits')
91 handWritingClassifyTest()
5、github:https://github.com/hahahaha1997/OCR
三、總結
KNN還是不適合用來做OCR的識別過程的,雖然《機器學習實戰》的作者提到這個系統是美國的郵件分揀系統實際運行的一個系統,但是它肯定無法高準確率地識別中國人寫的手寫文字就對了,畢竟中國有些地方的“9”還會寫成“p”的樣子的。這一節主要是將KNN拓展到實際運用中的,結合上一節的理論,KNN的執行效率還是太低了,比如這個系統,要識別一個手寫數字,它需要和所有的訓練樣本做距離計算,每個距離計算又有1024個(a-b)2,還有運行效率特別低下的sqrt(),如果是一個非常大的測試集,需要的時間就更加龐大,如果訓練集非常龐大,在將0/1矩陣讀入內存中的時候,內存開銷是非常巨大的,所以整個程序可能會非常耗時費力。不過KNN仍舊是一個精度非常高的算法,并且也是機器學習分類算法中最簡單的算法之一。下一節將帶來機器學習經典算法——ID3決策樹。轉載注明出處哦:https://www.cnblogs.com/DawnSwallow/p/9440516.html
總結
以上是生活随笔為你收集整理的Python3实现机器学习经典算法(二)KNN实现简单OCR的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: SAP ABAP bcset激活时,关联
- 下一篇: 微信朋友圈怎么发1分钟的视频