【scikit-learn基础】--『监督学习』之 支持向量机分类
支持向量機也是一種既可以處理分類問題,也可以處理回歸問題的算法。
關于支持向量機在回歸問題上的應用,請參考:TODO
支持向量機分類廣泛應用于圖像識別、文本分類、生物信息學(例如基因分類)、手寫數字識別等領域。
1. 算法概述
支持向量機的主要思想是找到一個超平面,將不同類別的樣本最大化地分隔開。
超平面的位置由支持向量決定,它們是離分隔邊界最近的數據點。
對于二分類問題,SVM尋找一個超平面,使得正例和支持向量到超平面的距離之和等于反例和支持向量到超平面的距離之和。
如果這個等式不成立,SVM將尋找一個更遠離等式中不利樣本的超平面。
下面的示例,演示了支持向量機分類算法在圖像識別上的應用。
2. 創建樣本數據
這次的樣本使用的是scikit-learn自帶的手寫數字數據集。
import matplotlib.pyplot as plt
from sklearn import datasets
# 加載手寫數據集
data = datasets.load_digits()
_, axes = plt.subplots(nrows=2, ncols=4, figsize=(10, 6))
for ax, image, label in zip(np.append(axes[0], axes[1]), data.images, data.target):
ax.set_axis_off()
ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
ax.set_title("目標值: {}".format(label))
這里顯示了其中的幾個手寫數字,這個數據集總共有大約1700多個手寫數字。
3. 模型訓練
樣本數據中,手寫數字的圖片存儲為一個 8x8 的二維數組。
比如:
data.images[0]
# 運行結果
array([[ 0., 0., 5., 13., 9., 1., 0., 0.],
[ 0., 0., 13., 15., 10., 15., 5., 0.],
[ 0., 3., 15., 2., 0., 11., 8., 0.],
[ 0., 4., 12., 0., 0., 8., 8., 0.],
[ 0., 5., 8., 0., 0., 9., 8., 0.],
[ 0., 4., 11., 0., 1., 12., 7., 0.],
[ 0., 2., 14., 5., 10., 12., 0., 0.],
[ 0., 0., 6., 13., 10., 0., 0., 0.]])
所以,在分割訓練集和測試集之前,我們需要先將手寫數字的的存儲格式從 8x8 的二維數組轉換為 64x1 的一維數組。
from sklearn.model_selection import train_test_split
n_samples = len(data.images)
X = data.images.reshape((n_samples, -1))
y = data.target
# 分割訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)
按照9:1的比例來劃分訓練集和測試集。
然后用scikit-learn中的SVC模型來訓練樣本:
from sklearn.svm import SVC
# 定義
reg = SVC()
# 訓練模型
reg.fit(X_train, y_train)
模型的訓練效果:
# 在測試集上進行預測
y_pred = reg.predict(X_test)
correct_pred = np.sum(y_pred == y_test)
print("預測正確率:{:.2f}%".format(correct_pred / len(y_pred) * 100))
# 運行效果
預測正確率:98.89%
正確率非常高,下面我們看看沒識別出來的手寫數字是哪些。
wrong_pred = []
for i in range(len(y_pred)):
if y_pred[i] != y_test[i]:
wrong_pred.append(i)
print(wrong_pred)
# 運行效果
[156, 158]
在測試集中,只有兩個手寫數字識別錯了。
我面看看識別錯的2個手寫數字是什么樣的:
_, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
for i in range(2):
idx = wrong_pred[i]
image = X_test[idx].reshape(8, 8)
axes[i].set_axis_off()
axes[i].imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
axes[i].set_title("預測值({}) 目標值({})".format(y_pred[idx], y_test[idx]))
可以看出,即使人眼去識別,這兩個手寫數字也不太容易識別。
4. 總結
支持向量機分類算法的優勢有:
- 有效處理高維數據:對高維數據非常有效,即使在數據維度超過樣本數量的情況下也能工作得很好。
- 高效:只使用一部分訓練數據(即支持向量)來做決策,這使得算法更加內存高效。
- 穩定性較好:由于其決策邊界取決于支持向量而不是所有的數據點,因此模型的穩定性較好,對噪聲和異常值的敏感度較低。
它的劣勢主要有:
- 對參數和核函數敏感:性能高度依賴于參數設置(如懲罰參數C和核函數的選擇)。如果參數選擇不當,可能會導致過擬合或欠擬合。
- 難以解釋:不像決策樹那樣直觀,難以理解和解釋。
- 處理大規模數據時速度較慢:訓練過程涉及到二次規劃問題,需要使用復雜的優化算法,因此在處理大規模數據時可能較慢。
總結
以上是生活随笔為你收集整理的【scikit-learn基础】--『监督学习』之 支持向量机分类的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Terraform 的开源替代:Open
- 下一篇: 聊聊ChatGLM-6B源码分析(二)