#數(shù)據(jù)預(yù)處理
#將CakeType的值映射到0、1,方便后續(xù)模型運(yùn)算
import numpy as np
label = np.where(data['CakeType']=='muffin',0,1)
print(label)#[0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1]
x = data[['Sugar','Butter']]
#print(x)#SVM實(shí)例化
from sklearn.svm import SVC
#SVC指Support Vector Classifier
svc = SVC(kernel='linear',C=1)
'''
SVC參數(shù)說明:
C:懲罰系數(shù),即當(dāng)分類器錯(cuò)誤地將A類樣本劃分為B類了,我們將給予分類器多大的懲罰。當(dāng)我們給與非常大的懲罰,即C的值設(shè)置的很大,那么分類器會(huì)變得非常精準(zhǔn),但是,會(huì)產(chǎn)生過擬合問題。
kernel:核函數(shù),如果使用一條直線就可以將屬于不同類別的樣本點(diǎn)全部劃分開,那么我們使用kernel='linear',
如果不能線性劃分開,尤其是當(dāng)數(shù)據(jù)維度很多時(shí),一般很難找到一條合適的線將不同的類別的樣本劃分開,那么就嘗試使用高斯核函數(shù)(也稱為徑向基核函數(shù)-rbf)、多項(xiàng)式核函數(shù)(poly)
'''
svc.fit(X=x,y=label)#根據(jù)擬合結(jié)果,找出超平面
w = svc.coef_[0]
a = -w[0]/w[1]#超平面的斜率,也是邊界線的斜率
xx = np.linspace(5,30)#生成5~30之間的50個(gè)數(shù)
#print(xx)
'''
[ 5. 5.51020408 6.02040816 6.53061224 7.04081633 7.551020418.06122449 8.57142857 9.08163265 9.59183673 10.10204082 10.612244911.12244898 11.63265306 12.14285714 12.65306122 13.16326531 13.6734693914.18367347 14.69387755 15.20408163 15.71428571 16.2244898 16.7346938817.24489796 17.75510204 18.26530612 18.7755102 19.28571429 19.7959183720.30612245 20.81632653 21.32653061 21.83673469 22.34693878 22.8571428623.36734694 23.87755102 24.3877551 24.89795918 25.40816327 25.9183673526.42857143 26.93877551 27.44897959 27.95918367 28.46938776 28.9795918429.48979592 30. ]'''
yy = a * xx - (svc.intercept_[0])/w[1]#根據(jù)超平面,找到超平面的兩條邊界線
b = svc.support_vectors_[0]
yy_down = a * xx + (b[1]-a*b[0])
b = svc.support_vectors_[-1]
yy_up = a * xx + (b[1]-a*b[0])#繪制超平面和邊界線
#(1)繪制樣本點(diǎn)的散點(diǎn)圖
sns.lmplot(data=data,x='Sugar',y='Butter',hue='CakeType',palette='Set1',fit_reg=False,scatter_kws={'s':150})
#(2)向散點(diǎn)圖添加超平面
from matplotlib import pyplot as plt
plt.plot(xx,yy,linewidth=4,color='black')#(3)向散點(diǎn)圖添加邊界線
plt.plot(xx,yy_down,linewidth=2,color='blue',linestyle='--')
plt.plot(xx,yy_up,linewidth=2,color='blue',linestyle='--')
效果如下:
# 調(diào)整參數(shù)C,看看會(huì)有什么不同?
svc = SVC(kernel='linear',C=0.001)
svc.fit(X=x,y=label)#根據(jù)擬合結(jié)果,找出超平面
w = svc.coef_[0]
a = -w[0]/w[1]
xx = np.linspace(5,30)
yy = a * xx - (svc.intercept_[0])/w[1]#根據(jù)超平面,找到超平面的兩條邊界線
b = svc.support_vectors_[0]
yy_down = a * xx + (b[1]-a*b[0])
b = svc.support_vectors_[-1]
yy_up = a * xx + (b[1]-a*b[0])#繪制超平面和邊界線
#(1)繪制樣本點(diǎn)的散點(diǎn)圖
sns.lmplot(data=data,x='Sugar',y='Butter',hue='CakeType',palette='Set1',fit_reg=False,scatter_kws={'s':150})
#(2)向散點(diǎn)圖添加超平面
from matplotlib import pyplot as plt
plt.plot(xx,yy,linewidth=4,color='black')
#(3)向散點(diǎn)圖添加邊界線
plt.plot(xx,yy_down,linewidth=2,color='blue',linestyle='--')
plt.plot(xx,yy_up,linewidth=2,color='blue',linestyle='--')