sklearn实战之构建SVM多分类器
生活随笔
收集整理的這篇文章主要介紹了
sklearn实战之构建SVM多分类器
小編覺得挺不錯的，現在分享給大家，幫大家做個參考。
利用sklearn庫構建SVM分類器十分簡單，因為這個庫已經封裝好了，只需調用相應的函數即可。
# -*- coding: utf-8 -*-
"""
Train and evaluate a one-vs-rest linear SVM on pre-extracted CNN features.

Created on Fri Nov 23 18:44:37 2018
@author: 13260
"""
import os
from itertools import cycle

import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
import matplotlib.pyplot as plt
import numpy as np
from numpy import interp  # scipy.interp was only a deprecated alias of numpy.interp
from sklearn import metrics, preprocessing, svm
from sklearn.metrics import auc, classification_report, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import LabelBinarizer

# NOTE(review): plt, cycle, interp, roc_curve, auc, preprocessing and
# LabelBinarizer are not used below; kept from the original script
# (presumably for ROC-curve plotting experiments).

# Directory whose entries name the classes (used as report target names).
DEFAULT_LABEL_DIR = "F:/shiyan/TensorFlow/retrain/data/train"
# Where the trained model is written with joblib.
DEFAULT_MODEL_PATH = "F:/python/model/conv_19_80%.pkl"
SEPARATOR = "==========================================="


def load_features_and_labels(features_path, labels_path):
    """Load a feature matrix and its label vector from two ``.npy`` files.

    Args:
        features_path: Path of the ``.npy`` file holding the features.
        labels_path: Path of the ``.npy`` file holding the labels.

    Returns:
        A ``(features, labels)`` tuple of the arrays exactly as stored.
    """
    features = np.load(features_path)
    labels = np.load(labels_path)
    print("[INFO] Feature and label file have been loaded !")
    return features, labels


def _report_metrics(y_test, y_test_pred, classes, label_list):
    """Print overall/per-class/average accuracy, a report and a confusion matrix."""
    ov_acc = metrics.accuracy_score(y_test, y_test_pred)
    print("overall accuracy: %f" % (ov_acc))
    print(SEPARATOR)
    # Per-class accuracy = correctly classified samples of a class divided by
    # all samples of that class, i.e. recall (the original used precision).
    acc_for_each_class = metrics.recall_score(
        y_test, y_test_pred, labels=classes, average=None
    )
    print("acc_for_each_class:\n", acc_for_each_class)
    print(SEPARATOR)
    avg_acc = np.mean(acc_for_each_class)
    print("average accuracy:%f" % (avg_acc))
    print(SEPARATOR)
    classification_rep = classification_report(
        y_test, y_test_pred, labels=classes, target_names=label_list
    )
    print("classification report: \n", classification_rep)
    print(SEPARATOR)
    confusion_matrix = metrics.confusion_matrix(y_test, y_test_pred, labels=classes)
    print("confusion matrix:\n", confusion_matrix)
    print(SEPARATOR)


def train_and_test_model(
    feature,
    label,
    label_dir=DEFAULT_LABEL_DIR,
    model_path=DEFAULT_MODEL_PATH,
    test_size=0.2,
):
    """Train a one-vs-rest linear SVM, save it, and print test-set metrics.

    The data are split with a fixed seed (``random_state=0``). The fitted
    model is written to ``model_path`` with joblib, and the evaluation on the
    held-out split is printed.

    Args:
        feature: Sample-by-feature matrix.
        label: One label per sample (any hashable values; they are not
            binarized, so each prediction is exactly one class).
        label_dir: Directory whose listing supplies the class names for the
            classification report.
        model_path: File the trained model is dumped to.
        test_size: Fraction of samples held out for testing.

    Returns:
        The fitted ``OneVsRestClassifier``. Reload it later with
        ``joblib.load(model_path)``.
    """
    # classification_report pairs target_names with the sorted unique label
    # values (the same column order LabelBinarizer used previously).
    # NOTE(review): os.listdir order is OS-dependent; confirm it matches how
    # the label values were encoded.
    label_list = os.listdir(label_dir)
    classes = np.unique(label)

    random_state = np.random.RandomState(0)
    # Shuffle and split into train/test sets.
    X_train, X_test, y_train, y_test = train_test_split(
        feature, label, test_size=test_size, random_state=0
    )

    # Training raw labels (not a binarized indicator matrix) keeps this a true
    # multiclass problem: predict() returns the single best-scoring class
    # instead of a multilabel row that may contain no class at all.
    model = OneVsRestClassifier(
        svm.SVC(kernel="linear", probability=True, random_state=random_state)
    )
    print("[INFO] Successfully initialize a new model !")
    print("[INFO] Training the model…… ")
    clt = model.fit(X_train, y_train)
    print("[INFO] Model training completed !")

    # Persist the model so later runs can joblib.load() it instead of retraining.
    joblib.dump(clt, model_path)
    print("[INFO] Model has been saved !")

    y_test_pred = clt.predict(X_test)
    _report_metrics(y_test, y_test_pred, classes, label_list)
    print("[INFO] Successfully get SVM's classification overall accuracy ! ")
    return clt


if __name__ == "__main__":
    features_path = "F:/python/features/DenseNet/densenet_fv_flatten.npy"
    # NOTE(review): features come from DenseNet but labels from the VGG19
    # folder; this assumes both were extracted over the same images in the
    # same order — confirm.
    labels_path = "F:/python/features/VGG19/VGG19_labels.npy"
    feature, label = load_features_and_labels(features_path, labels_path)
    train_and_test_model(feature, label)
以上是生活随笔為你收集整理的sklearn实战之构建SVM多分类器的全部內容，希望文章能夠幫你解決所遇到的問題。
- 上一篇: python快速检测视频跳过帧_使用Py
- 下一篇: python线性回归算法简介_Pytho