TensorRTx-YOLOv5工程解读(一)
TensorRTx-YOLOv5工程解讀(一)
權(quán)重生成:gen_wts.py
作者先是使用了gen_wts.py這個腳本去生成wts文件。顧名思義,這個.wts文件里面存放的就是.pt文件的權(quán)重。腳本內(nèi)容如下:
import sys import argparse import os import struct import torch from utils.torch_utils import select_devicedef parse_args():parser = argparse.ArgumentParser(description='Convert .pt file to .wts')parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')parser.add_argument('-o', '--output', help='Output (.wts) file path (optional)')args = parser.parse_args()if not os.path.isfile(args.weights):raise SystemExit('Invalid input file')if not args.output:args.output = os.path.splitext(args.weights)[0] + '.wts'elif os.path.isdir(args.output):args.output = os.path.join(args.output,os.path.splitext(os.path.basename(args.weights))[0] + '.wts')return args.weights, args.outputpt_file, wts_file = parse_args()# Initialize device = select_device('cpu') # Load model model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32 model.to(device).eval()with open(wts_file, 'w') as f:f.write('{}\n'.format(len(model.state_dict().keys())))for k, v in model.state_dict().items():vr = v.reshape(-1).cpu().numpy()f.write('{} {} '.format(k, len(vr)))for vv in vr:f.write(' ')f.write(struct.pack('>f' ,float(vv)).hex())f.write('\n')第一個函數(shù)parse_args()就是正常處理輸入的命令行參數(shù),不多做贅述。
主函數(shù)內(nèi),先是設(shè)置設(shè)備為CPU,再load進pt文件獲得model并轉(zhuǎn)成FP32格式。并設(shè)置模型的device和eval模式。
設(shè)置完畢后,作者保存權(quán)重文件,其中權(quán)重文件的內(nèi)容是作者自定義的。第一行存入的是model的keys的個數(shù),再分別遍歷pt文件內(nèi)的每一個權(quán)重,保存為該層名稱 該層參數(shù)量 16進制權(quán)重。
權(quán)重讀取:common.cpp
首先順著之前的思路,看看作者是如何load權(quán)重的。
// TensorRT weight files have a simple space delimited format: // [type] [size] <data x size in hex> std::map<std::string, Weights> loadWeights(const std::string file) {std::cout << "Loading weights: " << file << std::endl;std::map<std::string, Weights> weightMap;// Open weights filestd::ifstream input(file);assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!!!!!!");// Read number of weight blobsint32_t count;input >> count;assert(count > 0 && "Invalid weight map file.");while (count--){Weights wt{ DataType::kFLOAT, nullptr, 0 };uint32_t size;// Read name and type of blobstd::string name;input >> name >> std::dec >> size;wt.type = DataType::kFLOAT;// Load blobuint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));for (uint32_t x = 0, y = size; x < y; ++x){input >> std::hex >> val[x];}wt.values = val;wt.count = size;weightMap[name] = wt;}return weightMap; }此為loadWeight()函數(shù)。作者此處使用了std::map容器。map容器在OpenCV和OpenVINO中本身就是大量使用的,所以除了vector之外,也需要掌握map的使用。后面需要往這個<std::string, Weights>型的map中添加權(quán)重信息。
同時應(yīng)該注意,此處的Weights類型在TensorRT的NvInferRuntime.h頭文件中有定義:
class Weights { public:DataType type; //!< The type of the weights.const void* values; //!< The weight values, in a contiguous array.int64_t count; //!< The number of weights in the array. };作者使用了std::ifstream進行輸入流變量的定義,并設(shè)置了一些變量。代碼中的input >> count就是將.wts文件中的第一行的算子數(shù)傳遞給count這個變量,從而構(gòu)建while循環(huán)。
在While循環(huán)中,作者先定義了Weights型的wt變量,其類型為DataType::kFLOAT,values直接初始化為nullptr,count初始化一個0在上面即可。
這一句input >> name >> std::dec >> size是將input中的第一部分:權(quán)重的名稱,賦值給name變量,再將緊跟著name后的size推入給size變量。具體的形式可以參考之前分析gen_wts.py腳本中的權(quán)重生成的部分。作者之所以要存入這一算子的權(quán)重的size,就是為了方便分配空間大小。聲明指針val指向一個大小為sizeof(val) * size的uint32_t的數(shù)組,并且將input中這一行的權(quán)重全部推入給val這個數(shù)組即可。
這一步完成后,設(shè)置Weights的values成員為val,count成員為size,并將name作為weightMap的keys,wt作為其values即可。
至此,模型權(quán)重加載完畢。
總結(jié)
以上是生活随笔為你收集整理的TensorRTx-YOLOv5工程解读(一)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 蓝桥杯算法训练KAc给糖果贪心-pyth
- 下一篇: 永远不说喜欢你