LAD线性判别分析鸢尾花预测
生活随笔
收集整理的這篇文章主要介紹了
LAD线性判别分析鸢尾花预测
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
LAD線性判別分析鳶尾花預測
文章目錄
- LAD線性判別分析鳶尾花預測
- 數學原理
- 代碼實現
數學原理
代碼實現
數據集下載鏈接:https://www.kaggle.com/uciml/iris/download
#!/usr/bin/env python # -*- coding: utf-8 -*- # @File : LDA.py # @Author: Gowi # @Date : 2021/3/25 # @Desc :import pandas as pd import numpy as np# 計算協方差矩陣 def Sigma(Iris, u):s = np.zeros((4, 4))for i in range(30):a = Iris[i, :] - ua = np.array([a])s = s + np.dot(a.T, a)return sdef predict(Iris_test):num1, num2, num3 = 0, 0, 0for i in range(20):acc1, acc2, acc3 = 0, 0, 0U12_test = np.dot(W12.T, Iris_test[i])U13_test = np.dot(W13.T, Iris_test[i])U23_test = np.dot(W23.T, Iris_test[i])if np.abs(U12_test - U12_1) < np.abs(U12_test - U12_2):acc1 += 1else:acc2 += 1if np.abs(U13_test - U13_1) < np.abs(U13_test - U13_2):acc1 += 1else:acc3 += 1if np.abs(U23_test - U23_1) < np.abs(U23_test - U23_2):acc2 += 1else:acc3 += 1acc = max(acc1, acc2, acc3)if acc == acc1:num1 += 1elif acc == acc2:num2 += 1else:num3 += 1return num1, num2, num3# 讀取數據集 df = pd.read_csv(r"Iris.csv", header=None) # 拆分數據集 Iris1_train = df.values[1:31, 1:5] Iris2_train = df.values[51:81, 1:5] Iris3_train = df.values[101:131, 1:5] Iris1_test = df.values[31:51, 1:5] Iris2_test = df.values[81:101, 1:5] Iris3_test = df.values[131:151, 1:5] # 鳶尾花的類別 Iris1_class = 'Iris-setosa' Iris2_class = 'Iris-versicolor' Iris3_class = 'Iris-virginica' # 轉換為float Iris1_train = Iris1_train.astype(np.float) Iris2_train = Iris2_train.astype(np.float) Iris3_train = Iris3_train.astype(np.float) Iris1_test = Iris1_test.astype(np.float) Iris2_test = Iris2_test.astype(np.float) Iris3_test = Iris3_test.astype(np.float) # 均值向量 u1 = np.mean(Iris1_train, axis=0) u2 = np.mean(Iris2_train, axis=0) u3 = np.mean(Iris3_train, axis=0) print("均值向量u1") print(u1) # 協方差矩陣 sigma1 = Sigma(Iris1_train, u1) sigma2 = Sigma(Iris2_train, u2) sigma3 = Sigma(Iris3_train, u3) print("類內散度矩陣sigma1") print(sigma1) # 類內散度矩陣 Sw12 = sigma1 + sigma2 Sw13 = sigma1 + sigma3 Sw23 = sigma2 + sigma2 print("類內散度矩陣Sw12") print(Sw12) # 類間散度矩陣 Sb12 = np.dot(np.array(u1 - u2), np.array(u1 - u2).T) Sb13 = np.dot(np.array(u1 - u3), np.array(u1 - u3).T) Sb23 = np.dot(np.array(u2 - u3), np.array(u2 - u3).T) # 斜率 W12 = np.dot(np.linalg.inv(Sw12), (u1 - u2)) W13 = np.dot(np.linalg.inv(Sw13), (u1 - u3)) W23 = np.dot(np.linalg.inv(Sw23), (u2 - u3)) print("斜率W12") print(W12) # 投影后的均值點 U12_1 = np.dot(W12.T, u1) U12_2 = np.dot(W12.T, u2) U13_1 = np.dot(W13.T, u1) U13_2 = np.dot(W13.T, u2) U23_1 = np.dot(W23.T, u2) U23_2 = np.dot(W23.T, u3) print("投影后的均值點U12_1") print(U12_1) # 預測Iris predict_1, _, _ = predict(Iris1_test) print("判斷為" + Iris1_class + "的個數") print(predict_1) _, predict_2, _ = predict(Iris2_test) print("判斷為" + Iris2_class + "的個數") print(predict_2) _, _, predict_3 = predict(Iris3_test) print("判斷為" + Iris3_class + "的個數") print(predict_3) print("準確率為") print((predict_1 + predict_2 + predict_3) / 60 * 100, "%")總結
以上是生活随笔為你收集整理的LAD线性判别分析鸢尾花预测的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 损失函数MSELoss和CELoss
- 下一篇: python mse_python 计算