Python,OpenCV基于支持向量机SVM的手写数字OCR
Python,OpenCV基于支持向量機SVM的手寫數字OCR
- 1. 效果圖
- 2. SVM及原理
- 2. 源碼
- 2.1 SVM的手寫數字OCR
- 2.2 非線性SVM
- 參考
上一節介紹了基于KNN的手寫數字OCR+字母OCR,這一節將介紹基于支持向量機SVM的手寫數字OCR。
1. 效果圖
簡單線性向量機訓練效果圖如下:
圖中有4個點,3個趨于白色點,一個灰黑色點,可以看到分割線的決策邊界很明顯。
非線性向量機訓練數據效果圖如下:
下圖中綠色、藍色點雜糅在一起,中間的決策邊界是非線性的,但可以近似為線性。邊界有灰色圓圈的點是 支持向量,依賴這些少量的數據就可以找到 決策邊界。
2. SVM及原理
支持向量機SVM(Supported Vector Machines)
要了解SVM,首先需要了解線性可分數據及線性不可分數據,簡單來說,就是在平面或多維有一堆點進行分類,能否用一根線分隔以分類彼此。
-
線性可分數據
KNN需要計算測試數據到所有點的距離,當數據量比較大的時候,需要較大的內存來存儲。 可以有另一種思路:找到一條線 f(x)=ax_1+bx_2+c ,它將數據分為兩個區域。當得到一個新的 test_data X 時,只需將其替換為 f(x)。如果 f(X) > 0,則屬于藍色組,否則屬于紅色組。
稱這條線為 決策邊界,它非常簡單且節省內存。 這種可以用直線(或更高維的超平面)一分為二的數據稱為 線性可分數據。
-
低維空間中的非線性可分離數據在高維空間中變為線性可分離的可能性更大。
在上圖中可以看到很多這樣的線條是可能的。要拿哪一個?非常直觀,可以說這條線應該盡可能遠離所有點。
走最遠的線路將提供更多的抗噪能力。所以SVM所做的就是找到一條到訓練樣本最小距離最大的直線(或超平面)。
- 要找到這個決策邊界,并不需要所有數據,只需要那些靠近相反群體的數據。
在該圖像中,它們是一個藍色實心圓圈和兩個紅色實心方塊。我們可以稱它們為支持向量,穿過它們的線稱為支持平面。它們足以找到決策邊界。
- 權重向量決定決策邊界的方向,而偏置點決定其位置。
2. 源碼
2.1 SVM的手寫數字OCR
# 使用SVM進行手寫數據OCR# 在KNN中直接使用像素強度作為特征向量。
# 在SVM中使用方向梯度直方圖(HOG Histogram of Oriented Gradients)作為特征向量。
# 在這里,使用二階矩對圖像進行反扭曲。
import cv2
import numpy as npSZ = 20
bin_n = 16 # Number of binssvm_params = dict(kernel_type=cv2.ml.SVM_LINEAR,svm_type=cv2.ml.SVM_C_SVC,C=2.67, gamma=5.383)affine_flags = cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR# 左圖像是原始圖像,右圖像是傾斜圖像。
def deskew(img):m = cv2.moments(img)if abs(m['mu02']) < 1e-2:return img.copy()skew = m['mu11'] / m['mu02']M = np.float32([[1, skew, -0.5 * SZ * skew], [0, 1, 0]])img = cv2.warpAffine(img, M, (SZ, SZ), flags=affine_flags)return img# (HOG Histogram of Oriented Gradients)方向梯度直方圖
def hog(img):gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)mag, ang = cv2.cartToPolar(gx, gy)# 量化 (0...16)的binvaluesbins = np.int32(bin_n * ang / (2 * np.pi))# 分成四個子塊bin_cells = bins[:10, :10], bins[10:, :10], bins[:10, 10:], bins[10:, 10:]mag_cells = mag[:10, :10], mag[10:, :10], mag[:10, 10:], mag[10:, 10:]hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]hist = np.hstack(hists)return histimg = cv2.imread('images/digits.png', 0)
print(img.shape) # (1000,2000)cells = [np.hsplit(row, 100) for row in np.vsplit(img, 50)]
print(len(cells)) # 50*100# 一半數據用于訓練,一半用于測試(前50列,后50列)
train_cells = [i[:50] for i in cells]
test_cells = [i[50:] for i in cells]# cv2.imshow("img", train_cells[0][0])
# cv2.imshow("deskew", deskew(train_cells[0][0]))
# cv2.waitKey(0)# 訓練數據
deskewed = [list(map(deskew, row)) for row in train_cells]
hogdata = [list(map(hog, row)) for row in deskewed]trainData = np.float32(hogdata).reshape(-1, 64)
responses = np.repeat(np.arange(10), 250)[:, np.newaxis]
print('trainData: ', type(trainData), len(trainData))
print('responses: ', type(responses), responses.shape, len(responses))print(responses[0])svm = cv2.ml.SVM_create()
svm.setGamma(svm_params['gamma'])
svm.setC(svm_params['C'])
svm.setKernel(cv2.ml.SVM_LINEAR)
svm.setType(cv2.ml.SVM_C_SVC)
svm.train(trainData, cv2.ml.ROW_SAMPLE, responses)# 把訓練的數據及模型保存下來
svm.save('images/svm_data.dat')# 測試數據
deskewed = [list(map(deskew, row)) for row in test_cells]
hogdata = [list(map(hog, row)) for row in deskewed]
testData = np.float32(hogdata).reshape(-1, bin_n * 4)
result = svm.predict(testData)[1]print('result: ', type(result))
print('responses: ', type(responses))# 檢查準確度
mask = result == responses
correct = np.count_nonzero(mask)
print('correct: ', correct)# SVM得到93.8%的準確度,比KNN的91.76%要高一些
print(correct * 100.0 / len(list(result)))
2.2 非線性SVM
from __future__ import print_functionimport random as rngimport cv2 as cv
import numpy as npNTRAINING_SAMPLES = 100 # Number of training samples per class
FRAC_LINEAR_SEP = 0.9 # Fraction of samples which compose the linear separable part# 可視化窗口大小
WIDTH = 512
HEIGHT = 512
I = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)# 隨機生成訓練數據
trainData = np.empty((2 * NTRAINING_SAMPLES, 2), dtype=np.float32)
labels = np.empty((2 * NTRAINING_SAMPLES, 1), dtype=np.int32)rng.seed(100) # 隨機生成分類標簽# 為訓練數據設置線性分離區
# Set up the linearly separable part of the training data
nLinearSamples = int(FRAC_LINEAR_SEP * NTRAINING_SAMPLES)# 為分類1生成隨機點
trainClass = trainData[0:nLinearSamples, :]
# x在[0,0.4]
c = trainClass[:, 0:1]
c[:] = np.random.uniform(0.0, 0.4 * WIDTH, c.shape)
# y在[0, 1)
c = trainClass[:, 1:2]
c[:] = np.random.uniform(0.0, HEIGHT, c.shape)# 為分類2生成隨機點
trainClass = trainData[2 * NTRAINING_SAMPLES - nLinearSamples:2 * NTRAINING_SAMPLES, :]
# x在 [0.6, 1]
c = trainClass[:, 0:1]
c[:] = np.random.uniform(0.6 * WIDTH, WIDTH, c.shape)
# y在 [0, 1)
c = trainClass[:, 1:2]
c[:] = np.random.uniform(0.0, HEIGHT, c.shape)# 為測試數據集的分類1,2分別生成隨機點
trainClass = trainData[nLinearSamples:2 * NTRAINING_SAMPLES - nLinearSamples, :]
# x在[0.4,0.6]
c = trainClass[:, 0:1]
c[:] = np.random.uniform(0.4 * WIDTH, 0.6 * WIDTH, c.shape)
# y在[0,1]
c = trainClass[:, 1:2]
c[:] = np.random.uniform(0.0, HEIGHT, c.shape)# 設置分類標簽1及2
labels[0:NTRAINING_SAMPLES, :] = 1 # 分類1
labels[NTRAINING_SAMPLES:2 * NTRAINING_SAMPLES, :] = 2 # 分類2# 開始訓練,首先設置支持向量機SVM參數
print('Starting training process')
# 初始化
svm = cv.ml.SVM_create()
svm.setType(cv.ml.SVM_C_SVC)
svm.setC(0.1)
svm.setKernel(cv.ml.SVM_LINEAR)
svm.setTermCriteria((cv.TERM_CRITERIA_MAX_ITER, int(1e7), 1e-6))# 訓練SVM
svm.train(trainData, cv.ml.ROW_SAMPLE, labels)# 結束訓練
print('Finished training process')# 展示決策區域(繪制藍色,綠色) 分類1為綠色,分類2為藍色
green = (0, 100, 0)
blue = (100, 0, 0)
for i in range(I.shape[0]):for j in range(I.shape[1]):sampleMat = np.matrix([[j, i]], dtype=np.float32)response = svm.predict(sampleMat)[1]if response == 1:I[i, j] = greenelif response == 2:I[i, j] = blue# 展示測試數據
thick = -1# 分類1 綠色
for i in range(NTRAINING_SAMPLES):px = trainData[i, 0]py = trainData[i, 1]cv.circle(I, (px, py), 3, (0, 255, 0), thick)# 分類2 藍色
for i in range(NTRAINING_SAMPLES, 2 * NTRAINING_SAMPLES):px = trainData[i, 0]py = trainData[i, 1]cv.circle(I, (px, py), 3, (255, 0, 0), thick)# 展示支持向量
thick = 2
sv = svm.getUncompressedSupportVectors()for i in range(sv.shape[0]):cv.circle(I, (sv[i, 0], sv[i, 1]), 6, (128, 128, 128), thick)cv.imwrite('non_linear_svms_result.png', I) # 保存圖片
cv.imshow('SVM for Non-Linear Training Data', I) # 展示圖片結果
cv.waitKey()
參考
- https://docs.opencv.org/3.0-beta/doc/py_tutorials/py_ml/py_svm/py_svm_basics/py_svm_basics.html#svm-understanding
總結
以上是生活随笔為你收集整理的Python,OpenCV基于支持向量机SVM的手写数字OCR的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 《春雪》是谁的作品?
- 下一篇: 鼋头渚乘几路车可以到