YOLOv5 Model Training Tutorial with Quantization
YOLOv5 Model Training Tutorial
Data Annotation
For data annotation we use labelImg, which can be installed with pip:

pip install labelimg
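Once installed, the annotation tool is started from the command line. A minimal sketch, assuming the standard labelImg CLI; the predefined_classes.txt file here is a hypothetical plain-text list with one class name per line:

labelImg ./VOCData/images ./predefined_classes.txt

Keep the save format set to PascalVOC so each image gets an .xml annotation file, which is what the conversion script later in this tutorial expects.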
Scraping Training Images with a Baidu Crawler
import os
import re
import sys
import urllib
import json
import socket
import urllib.request
import urllib.parse
import urllib.error
from random import randint
import time

# set a global socket timeout
timeout = 5
socket.setdefaulttimeout(timeout)


class Crawler:
    # sleep interval between requests
    __time_sleep = 0.1
    __amount = 0
    __start_amount = 0
    __counter = 0
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64; rv:23.0) Gecko/20100101 Firefox/23.0'}
    __per_page = 30

    # t: interval between image downloads
    def __init__(self, t=0.1):
        self.time_sleep = t

    # get the file extension
    @staticmethod
    def get_suffix(name):
        m = re.search(r'\.[^\.]*$', name)
        if m.group(0) and len(m.group(0)) <= 5:
            return m.group(0)
        else:
            return '.jpeg'

    # save images from one page of results
    def save_image(self, rsp_data, word):
        if not os.path.exists("./" + word):
            os.mkdir("./" + word)
        # count existing images to avoid name clashes
        self.__counter = len(os.listdir('./' + word)) + 1
        for image_info in rsp_data['data']:
            try:
                if 'replaceUrl' not in image_info or len(image_info['replaceUrl']) < 1:
                    continue
                obj_url = image_info['replaceUrl'][0]['ObjUrl']
                thumb_url = image_info['thumbURL']
                url = 'https://image.baidu.com/search/down?tn=download&ipn=dwnl&word=download&ie=utf8&fr=result&url=%s&thumburl=%s' % (
                    urllib.parse.quote(obj_url), urllib.parse.quote(thumb_url))
                time.sleep(self.time_sleep)
                suffix = self.get_suffix(obj_url)
                # set a browser UA to reduce 403 responses
                opener = urllib.request.build_opener()
                opener.addheaders = [
                    ('User-agent', 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.116 Safari/537.36'),
                ]
                urllib.request.install_opener(opener)
                # save the image, retrying up to 5 times on empty downloads
                filepath = './{}/PME_{}_A{}'.format(word, randint(1000000, 500000000), str(self.__counter) + str(suffix))
                for _ in range(5):
                    urllib.request.urlretrieve(url, filepath)
                    if os.path.getsize(filepath) >= 5:
                        break
                if os.path.getsize(filepath) < 5:
                    print("Downloaded an empty file, skipping!")
                    os.unlink(filepath)
                    continue
            except urllib.error.HTTPError as urllib_err:
                print(urllib_err)
                continue
            except Exception as err:
                time.sleep(1)
                print(err)
                print("Unknown error, giving up on this image")
                continue
            else:
                print("Saved image, %d images so far" % self.__counter)
                self.__counter += 1
        return

    # fetch result pages
    def get_images(self, word):
        search = urllib.parse.quote(word)
        # pn: current image offset
        pn = self.__start_amount
        while pn < self.__amount:
            url = 'https://image.baidu.com/search/acjson?tn=resultjson_com&ipn=rj&ct=201326592&is=&fp=result&queryWord=%s&cl=2&lm=-1&ie=utf-8&oe=utf-8&adpicid=&st=-1&z=&ic=&hd=&latest=&copyright=&word=%s&s=&se=&tab=&width=&height=&face=0&istype=2&qc=&nc=1&fr=&expermode=&force=&pn=%s&rn=%d&gsm=1e&1594447993172=' % (
                search, search, str(pn), self.__per_page)
            # send the request with headers to avoid 403
            try:
                time.sleep(self.time_sleep)
                req = urllib.request.Request(url=url, headers=self.headers)
                page = urllib.request.urlopen(req)
                rsp = page.read()
            except UnicodeDecodeError as e:
                print(e)
                print('-----UnicodeDecodeError url:', url)
            except urllib.error.URLError as e:
                print(e)
                print("-----urlError url:", url)
            except socket.timeout as e:
                print(e)
                print("-----socket timeout:", url)
            else:
                # parse the json response
                try:
                    rsp_data = json.loads(rsp)
                    self.save_image(rsp_data, word)
                    # move on to the next page
                    print("Downloading next page")
                    pn += 60
                except Exception as e:
                    continue
                finally:
                    page.close()
        print("Download task finished")
        return

    def start(self, word, total_page=2, start_page=1, per_page=30):
        """
        Crawler entry point
        :param word: keyword to search for
        :param total_page: number of pages to fetch; total images = pages x per_page
        :param start_page: starting page number
        :param per_page: images per page
        :return:
        """
        self.__per_page = per_page
        self.__start_amount = (start_page - 1) * self.__per_page
        self.__amount = total_page * self.__per_page + self.__start_amount
        self.get_images(word)


if __name__ == '__main__':
    crawler = Crawler(0.05)  # download delay of 0.05 s
    crawler.start('玩手機')  # search keyword: "playing with a phone"

Once annotation is complete, each image has a corresponding XML annotation file.
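For reference, labelImg saves annotations in the Pascal VOC XML layout. A trimmed, hypothetical example showing only the fields the conversion script below actually reads (image size, class name, difficult flag, and bounding box) might look like this:

<annotation>
    <size>
        <width>640</width>
        <height>480</height>
        <depth>3</depth>
    </size>
    <object>
        <name>phone</name>
        <difficult>0</difficult>
        <bndbox>
            <xmin>120</xmin>
            <ymin>80</ymin>
            <xmax>320</xmax>
            <ymax>260</ymax>
        </bndbox>
    </object>
</annotation>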
Data Preprocessing
Create a file named convert_data.py with the following content:
# -*- coding: utf-8 -*-
import xml.etree.ElementTree as ET
from tqdm import tqdm
import os
from os import getcwd


def convert(size, box):
    # convert a VOC box (xmin, xmax, ymin, ymax) to normalized YOLO format
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = (box[0] + box[1]) / 2.0 - 1
    y = (box[2] + box[3]) / 2.0 - 1
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return x, y, w, h


def convert_annotation(image_id):
    in_file = open('VOCData/images/{}.xml'.format(image_id), encoding='utf-8')
    out_file = open('VOCData/labels/{}.txt'.format(image_id), 'w', encoding='utf-8')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
             float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        b1, b2, b3, b4 = b
        # clip out-of-bounds annotations
        if b2 > w:
            b2 = w
        if b4 > h:
            b4 = h
        b = (b1, b2, b3, b4)
        bb = convert((w, h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')


if __name__ == '__main__':
    sets = ['train', 'val']
    image_ids = [v.split('.')[0] for v in os.listdir('VOCData/images/') if v.endswith('.xml')]
    split_num = int(0.95 * len(image_ids))
    classes = ['face', 'normal', 'phone', 'write', 'smoke', 'eat', 'computer', 'sleep']
    if not os.path.exists('VOCData/labels/'):
        os.makedirs('VOCData/labels/')
    list_file = open('train.txt', 'w')
    for image_id in tqdm(image_ids[:split_num]):
        list_file.write('VOCData/images/{}.jpg\n'.format(image_id))
        convert_annotation(image_id)
    list_file.close()
    list_file = open('val.txt', 'w')
    for image_id in tqdm(image_ids[split_num:]):
        list_file.write('VOCData/images/{}.jpg\n'.format(image_id))
        convert_annotation(image_id)
    list_file.close()

After the script finishes, the corresponding txt label files appear under VOCData/labels, along with train.txt and val.txt image lists in the current directory.
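Each line of a label txt file is class_id x_center y_center width height, with the box center and size normalized by the image width and height, which is exactly what the convert() function above computes. Using the hypothetical 640×480 image with the 120–320 / 80–260 phone box from the XML sketch earlier (phone has class index 2), the line would be approximately:

2 0.342 0.352 0.3125 0.375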
Create a myvoc.yaml file under the data folder with the following content:
train: train.txt
val: val.txt

# number of classes
nc: 8

# class names
names: ["face", "normal", "phone", "write", "smoke", "eat", "computer", "sleep"]

Download the Pretrained Model
I am training the yolov5m model, so download its pretrained weights into the weights folder:
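If you want to script the download, the weights are published on the ultralytics/yolov5 GitHub releases page. A sketch assuming the v5.0 release tag (pick the tag that matches the version of the repository you cloned):

mkdir -p weights
wget -P weights https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5m.pt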
Model Training
Modify the number of classes in models/yolov5m.yaml:
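Only the nc entry at the top of that file has to change for our 8 classes; everything else stays as shipped. A sketch of the head of models/yolov5m.yaml after the edit (the depth/width multipliers shown are the yolov5m defaults in the copy I used, so check them against your own file):

# parameters
nc: 8  # number of classes (originally 80)
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple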
Then launch training:

python train.py --img 640 --batch 4 --epoch 300 --data ./data/myvoc.yaml --cfg ./models/yolov5m.yaml --weights weights/yolov5m.pt --workers 0

Model Inference Test
After training finishes, two trained model files (last.pt and best.pt) are generated under runs/train/exp/weights. Copy last.pt to the project root directory and run:
python detect.py --source data/images --weights last.pt --conf 0.25

Model Quantization
At this point we notice that the trained last.pt is 172 MB, while the official yolov5m.pt is only about 40 MB. To shrink it, we export the model at half precision (FP16) and save it again. Create a slim.py file (its full content is given below) and run:
python slim.py --in_weights last.pt --out_weights slim_model.pt --device 0

slim.py:
import os
import torch
import torch.nn as nn
from tqdm import tqdm


def autopad(k, p=None):  # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    # Standard convolution
    # ch_in, ch_out, kernel, stride, padding, groups
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.Hardswish() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        return self.act(self.conv(x))


class Ensemble(nn.ModuleList):
    # Ensemble of models
    def __init__(self):
        super(Ensemble, self).__init__()

    def forward(self, x, augment=False):
        y = []
        for module in self:
            y.append(module(x, augment)[0])
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.cat(y, 1)  # nms ensemble
        y = torch.stack(y).mean(0)  # mean ensemble
        return y, None  # inference, train output


def attempt_load(weights, map_location=None):
    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        # load FP32 model
        model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())

    # Compatibility updates
    for m in tqdm(model.modules()):
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
            m.inplace = True  # pytorch 1.7.0 compatibility
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1]  # return model
    else:
        print('Ensemble created with %s\n' % weights)
        for k in ['names', 'stride']:
            setattr(model, k, getattr(model[-1], k))
        return model  # return ensemble


def select_device(device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    cpu_request = device.lower() == 'cpu'
    if device and not cpu_request:  # if device requested other than 'cpu'
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device  # check availability

    cuda = False if cpu_request else torch.cuda.is_available()
    if cuda:
        c = 1024 ** 2  # bytes to MB
        ng = torch.cuda.device_count()
        if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
        x = [torch.cuda.get_device_properties(i) for i in range(ng)]
        s = f'Using torch {torch.__version__} '
        for i in range(0, ng):
            if i == 1:
                s = ' ' * len(s)

    return torch.device('cuda:0' if cuda else 'cpu')


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--in_weights', type=str, default='last.pt', help='initial weights path')
    parser.add_argument('--out_weights', type=str, default='slim_model.pt', help='output weights path')
    parser.add_argument('--device', type=str, default='0', help='device')
    opt = parser.parse_args()

    device = select_device(opt.device)
    model = attempt_load(opt.in_weights, map_location=device)
    model.to(device).eval()
    model.half()
    torch.save(model, opt.out_weights)
    print('done.')
    print('-[INFO] before: {} KB, after: {} KB'.format(
        os.path.getsize(opt.in_weights) // 1024, os.path.getsize(opt.out_weights) // 1024))
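Because slim.py saves the entire nn.Module rather than a state_dict, the exported file can be loaded back with a plain torch.load. A minimal usage sketch, assuming it is run from inside the yolov5 repository (so the pickled model classes can be imported) and that a CUDA device is available:

import torch

device = torch.device('cuda:0')
model = torch.load('slim_model.pt', map_location=device)  # full FP16 module saved by slim.py
model.eval()

# dummy 640x640 input; a real image would be letterboxed and normalized first,
# and must be cast to half precision to match the model weights
img = torch.zeros(1, 3, 640, 640, device=device).half()
with torch.no_grad():
    pred = model(img)[0]  # raw predictions; confidence filtering and NMS still needed, as in detect.py
print(pred.shape)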
Summary

That is the whole pipeline: collect images with the Baidu crawler, annotate them with labelImg, convert the VOC XML annotations to YOLO txt labels, train yolov5m on the custom dataset, run inference with detect.py, and finally export a half-precision copy of the model so the saved weight file is much smaller.