动手画混淆矩阵(Confusion Matrix)(含代码)
生活随笔
收集整理的這篇文章主要介紹了
动手画混淆矩阵(Confusion Matrix)(含代码)
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
- 1、混淆矩陣:Confusion Matrix
- 2、怎么畫(新)?
- 3、怎么用?
網上關于混淆矩陣的代碼參差不齊,沒找到可用的線程的代碼,所以自己嘗試寫了下
1、混淆矩陣:Confusion Matrix
首先它長這樣:
怎么看?
Confusion Matrix最廣泛的應用應該是分類,比如圖中是7分類的真實標簽和預測標簽的效果。
首先圖中表明了縱軸是truth label,橫軸是predicted label,那么對于第一行第一個0.60的含義是:本來是angry標簽的圖,我的模型正確分類成angry的比例是60%,也即是angry這一類模型分類正確的精度只有60%。同時模型將angry分類成了happy的圖占比0.04%,其他的以此類推。
注意:因為本身是angry,模型預測成7種類的數量占比。所以每一行的和為100%。
同時對于fear標簽,模型分類成fear的占比41%,分類成sad的占比為20%,我們可以認為模型不能很好區分fear和sad兩種類別。
2、怎么畫(新)?
這里直接給出代碼,在下一節中直接使用即可
import numpy as np import matplotlib.pyplot as plt from sklearn.metrics import confusion_matrixdef draw_confusion_matrix(label_true, label_pred, label_name, normlize, title="Confusion Matrix", pdf_save_path=None, dpi=100):"""@param label_true: 真實標簽,比如[0,1,2,7,4,5,...]@param label_pred: 預測標簽,比如[0,5,4,2,1,4,...]@param label_name: 標簽名字,比如['cat','dog','flower',...]@param normlize: 是否設元素為百分比形式@param title: 圖標題@param pdf_save_path: 是否保存,是則為保存路徑pdf_save_path=xxx.png | xxx.pdf | ...等其他plt.savefig支持的保存格式@param dpi: 保存到文件的分辨率,論文一般要求至少300dpi@return:example:draw_confusion_matrix(label_true=y_gt,label_pred=y_pred,label_name=["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"],normlize=True,title="Confusion Matrix on Fer2013",pdf_save_path="Confusion_Matrix_on_Fer2013.png",dpi=300)"""cm = confusion_matrix(label_true, label_pred)if normlize:row_sums = np.sum(cm, axis=1) # 計算每行的和cm = cm / row_sums[:, np.newaxis] # 廣播計算每個元素占比plt.imshow(cm, cmap='Blues')plt.title(title)plt.xlabel("Predict label")plt.ylabel("Truth label")plt.yticks(range(label_name.__len__()), label_name)plt.xticks(range(label_name.__len__()), label_name, rotation=45)plt.tight_layout()plt.colorbar()for i in range(label_name.__len__()):for j in range(label_name.__len__()):color = (1, 1, 1) if i == j else (0, 0, 0) # 對角線字體白色,其他黑色value = float(format('%.2f' % cm[i, j]))plt.text(i, j, value, verticalalignment='center', horizontalalignment='center', color=color)# plt.show()if not pdf_save_path is None:plt.savefig(pdf_save_path, bbox_inches='tight',dpi=dpi)3、怎么用?
給出一個簡單的實例:
labels_name=['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']y_gt=[] y_pred=[] for index, (labels, imgs) in enumerate(test_loader):labels_pd = model(imgs)predict_np = np.argmax(labels_pd.cpu().detach().numpy(), axis=-1) # array([0,5,1,6,3,...],dtype=int64)labels_np = labels.numpy() # array([0,5,0,6,2,...],dtype=int64)y_pred.append(labels_np)y_gt.append(labels_np)draw_confusion_matrix(label_true=y_gt, # y_gt=[0,5,1,6,3,...]label_pred=y_pred, # y_pred=[0,5,1,6,3,...]label_name=["An", "Di", "Fe", "Ha", "Sa", "Su", "Ne"],normlize=True,title="Confusion Matrix on Fer2013",pdf_save_path="Confusion_Matrix_on_Fer2013.jpg",dpi=300)- cpu().detach():從device上獲取數據
- .numpy():將tensor類型轉換為numpy類型
在我的模型上的結果:
總結
以上是生活随笔為你收集整理的动手画混淆矩阵(Confusion Matrix)(含代码)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 项目3抽象类与纯虚函数
- 下一篇: # D - Staircase Sequ