python遥感影像分类代码_python,sklearn,svm,遥感数据分类,代码实例
python,sklearn,svm,遙感數據分類,代碼實例,數據,函數,精度,遙感,路徑
python,sklearn,svm,遙感數據分類,代碼實例
易采站長站,站長之家為您整理了python,sklearn,svm,遙感數據分類,代碼實例的相關內容。
@python,sklearn,svm,遙感數據分類,代碼實例
python_sklearn_svm遙感數據分類代碼實例
(1)svm原理簡述
支持向量機(Support Vector Machine,即SVM)是包括分類(Classification)、回歸(Regression)和異常檢測(Outlier Detection)等一系列監督學習算法的總稱。對于分類,SVM最初用于解決二分類問題,多分類問題可通過構建多個SVM分類器解決。SVM具有兩大特點:1.尋求最優分類邊界,即求解出能夠正確劃分訓練數據集并且幾何間隔最大的分離超平面,這是SVM的基本思想;2.基于核函數的擴維變換,即通過核函數的特征變換對線性不可分的原始數據集進行升維變換,使其線性可分。因此SVM最核心的過程是核函數和參數選擇。
(2)svm實現環境解析
設置中文輸出代碼兼容格式及引用的庫函數,用于精度評估的庫函數,以及svm參數尋優等。
下面展示一些內聯代碼片。
-*- coding: utf-8 -*-
#用于精度評價
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
#numpy引用
import numpy as np
#記錄運行時間
import datetime
#文件路徑操作
import os
#svm and best parameter select using grid search method
from sklearn import svm
from sklearn.model_selection import GridSearchCV
#scale the data to 0-1 用于數據歸一化
from sklearn import preprocessing
(3)svm函數參數尋優
SVM參數尋優的實現,有兩種常用方法,一種是網格搜索法(本文中的),另一種是使用libsvm工具通過交叉驗證實現(后面再寫,有興趣的可以留言)。
def grid_find(train_data_x,train_data_y):
# 10 is often helpful. Using a basis of 2, a finer.tuning can be achieved but at a much higher cost.
# logspace(a,b,N),base默認=10,把10的a次方到10的b次方區間分成N份。
C_range = np.logspace(-5, 9, 8, base=2)
# 如:C_range = 1/64,1/8,1/2,2,8,32,128,512
gamma_range = np.logspace(-15, 3, 10, base=2)
# 選擇linear線性核函數和rbf核函數
parameters = {'kernel': ('linear', 'rbf'), 'C': C_range, 'gamma': gamma_range}
svr = svm.SVC()
# n_jobs表示并行運算量,可加快程序運行結果。
# 此處選擇5折交叉驗證,10折交叉驗證也是常用的。
clf = GridSearchCV(svr, parameters, cv=5, n_jobs=4)
# 進行模型訓練
clf.fit(train_data_x, train_data_y)
print('最優c,g參數為:{0}'.format(clf.best_params_))
# 返回最優模型結果
svm_model = clf.best_estimator_
return svm_model
更多關于網格搜索法:
(4)數據讀取函數編寫(讀取txt格式的訓練與測試文件)
首先是讀取txt格式的訓練數據和測試數據的函數。
數據截圖如下,其中,前6列數據代表通過遙感影像感興趣區(roi)提取出的6個波段的灰度值,最后一列代表數據類別的標簽。
代碼如下,僅需輸入文件路徑即可:
def open_txt_film(filepath):
# open the film
if os.path.exists(filepath):
with open(filepath, mode='r') as f:
train_data_str = np.loadtxt(f, delimiter=' ')
print('訓練(以及測試)數據的行列數為{}'.format(train_data_str.shape))
return train_data_str
else:
print('輸入txt文件路徑錯誤,請重新輸入文件路徑')
(5)svm模型預測函數編寫
輸入模型與測試數據,輸出精度評估(包括混淆矩陣,制圖精度等等)。
def model_process(svm_model, test_data_x, test_data_y):
p_lable = svm_model.predict(test_data_x)
# 精確度為 生產者精度 召回率為 用戶精度
print('總體精度為 : {}'.format(accuracy_score(test_data_y, p_lable)))
print('混淆矩陣為 :\n {}'.format(confusion_matrix(test_data_y, p_lable)))
print('kappa系數為 :\n {}'.format(cohen_kappa_score(test_data_y, p_lable)))
matric = confusion_matrix(test_data_y, p_lable)
# output the accuracy of each category。由于類別標簽是從1開始的,因此明確數據中最大值,即可知道有多少類
for category in range(np.max(test_data_y)):
# add 0.0 to keep the float type of output
precise = (matric[category, category] + 0.0) / np.sum(matric[category, :])
recall = (matric[category, category] + 0.0) / np.sum(matric[:, category])
f1_score = 2 * (precise * recall) / (recall + precise)
print(
'類別{}的生產者、制圖(recall)精度為{:.4} 用戶(precision)精度為{:.4} F1 score 為{:.4} '.format(category + 1, precise, recall, f1_score))
(6)主函數編寫
主函數主要負責:讀取數據,預處理數據,以及參數尋優、模型訓練和模型預測。
針對不同的數據集,每次使用,僅僅需要修改訓練與測試數據的路徑即可。
def main():
# read the train data from txt film
train_file_path = r'E:\CSDN\data1\train.txt'
train_data = open_txt_film(train_file_path)
# read the predict data from txt film
test_file_path = r'E:\CSDN\data1\test.txt'
test_data = open_txt_film(test_file_path)
# data normalization for svm training and testing dataset
scaler = preprocessing.MinMaxScaler().fit(train_data[:, :-1])
train_data[:, :-1] = scaler.transform(train_data[:, :-1])
# keep the same scale of the train data
test_data[:, :-1] = scaler.transform(test_data[:, :-1])
# conversion the type of data,and the label's dimension to 1-d
train_data_y = train_data[:, -1:].astype('int')
train_data_y = train_data_y.reshape(len(train_data_y))
train_data_x = train_data[:, :-1] # 取出測試數據灰度值和標簽值,并將2維標簽轉為1維
test_data_x = test_data[:, :-1] test_data_y = test_data[:, -1:].astype('int')
test_data_y = test_data_y.reshape(len(test_data_y))
model = grid_find(train_data_x,train_data_y)
# 模型預測
model_process(model, test_data_x, test_data_y)
(7)調用主函數
這里新增了幾行代碼用于記錄程序運行時間。
if __name__ == "__main__":
# remember the beginning time of the program
start_time = datetime.datetime.now()
print("start...%s" % start_time)
main()
# record the running time of program with the unit of minutes
end_time = datetime.datetime.now()
last_time = (end_time - start_time).seconds / 60
print("The program is last %s" % last_time + " minutes")
# print("The program is last {} seconds".format(last_time))
(8)訓練數據與測試數據實例下載地址
數據在作者的github倉庫下,共兩個文件(train.txt 和 test.txt) 。
[下載鏈接]: (https://github.com/sunzhihu123/sunzhihu123.github.io)
倉庫下點擊下載即可,如圖:
==本文優點:僅有兩個輸入,一個是訓練數據的路徑,一個是測試數據的路徑,輕松上手;并且以遙感圖像數據為例。另外github將會整體上傳源碼哦~
作者:huhu_xq以上就是關于對python,sklearn,svm,遙感數據分類,代碼實例的詳細介紹。歡迎大家對python,sklearn,svm,遙感數據分類,代碼實例內容提出寶貴意見
總結
以上是生活随笔為你收集整理的python遥感影像分类代码_python,sklearn,svm,遥感数据分类,代码实例的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: win8专业版如何激活
- 下一篇: 怎么利用flash绘制一只漂亮的玻璃蝴蝶