pytorch选出数据中的前k个最大(最小)值及其索引
生活随笔
收集整理的這篇文章主要介紹了
pytorch选出数据中的前k个最大(最小)值及其索引
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
選擇最大值及其索引,大家都知道使用max(),argmax()函數。
那么如何返回前k個最大值呢,這在我們計算topK準確率的時候很有必要:
在torch中,我們可以使用sort函數來實現:
a, idx1 = torch.sort(data, descending=True)#descending為alse,升序,為True,降序 idx = idx1[:k]Return:
a:排好序的數據
idx1:對應排序數據的索引
因此只需設置k的大小,就可以截取到前k個最大值的索引。這里若數據是tensor則用torch,若是list或ndarray,可以用numpy。
總結
以上是生活随笔為你收集整理的pytorch选出数据中的前k个最大(最小)值及其索引的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 记录之使用3080ti运行tensorf
- 下一篇: AssertionError: Inva