PyTorch TSM (Temporal Shift Module): A Detailed Walkthrough of the Testing Code
This post tries to explain the details of the testing code clearly. If some detail is not covered, my apologies; feel free to leave a comment and discuss.
1. Arguments:
parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
parser.add_argument('dataset', type=str)  # may contain splits
parser.add_argument('--weights', type=str, default=None)
parser.add_argument('--test_segments', type=str, default=25)
parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D')
parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
parser.add_argument('--full_res', default=False, action="store_true", help='use full resolution 256x256 for test as in Non-local I3D')
parser.add_argument('--test_crops', type=int, default=1)
parser.add_argument('--coeff', type=str, default=None)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers (default: 8)')

# for true test
parser.add_argument('--test_list', type=str, default=None)
parser.add_argument('--csv_file', type=str, default=None)
parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')
parser.add_argument('--max_num', type=int, default=-1)
parser.add_argument('--input_size', type=int, default=224)
parser.add_argument('--crop_fusion_type', type=str, default='avg')
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--img_feature_dim', type=int, default=256)
parser.add_argument('--num_set_segments', type=int, default=1, help='TODO: select multiply set of n-frames from a video')
parser.add_argument('--pretrain', type=str, default='imagenet')
args = parser.parse_args()

The argument to focus on here is test_segments. Since the paper uses strided (segment-based) sampling, test_segments is the number of equal parts the video is divided into, with one frame drawn from each part (a small sketch of this follows). Also note dense_sample, twice_sample, and test_crops, which we will come back to below; the remaining arguments do not affect reading the code and are not discussed further.
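To make the strided sampling concrete, here is a minimal sketch of the idea (my own illustration, not the repo's TSNDataSet code; the frame counts are made up): the video is split into num_segments equal chunks, and one frame index is taken from each chunk, at a random offset during training and at the chunk center at test time.

import numpy as np

def sample_segment_indices(num_frames, num_segments, random_shift=True):
    # split the video into num_segments equal chunks and pick one frame per chunk
    seg_len = num_frames / num_segments
    if random_shift:  # training-style: a random offset inside each chunk
        offsets = np.random.randint(0, max(int(seg_len), 1), size=num_segments)
    else:             # test-style: the center of each chunk
        offsets = np.full(num_segments, int(seg_len / 2))
    return (np.arange(num_segments) * seg_len).astype(int) + offsets

print(sample_segment_indices(num_frames=120, num_segments=8, random_shift=False))
# [  7  22  37  52  67  82  97 112]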
2. Data processing
Before reading this part, the two functions you must know are zip() and enumerate(); otherwise the loops will make your head spin:
1. for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
# this pulls one element from each of the three lists in lockstep
2. enumerate(data_loader) returns an iterator whose elements are (index, value) tuples
A quick demo of both is shown below.
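The following demo uses made-up lists purely for illustration:

weights  = ['w_rgb.pth', 'w_flow.pth']
segments = [8, 16]
files    = [None, None]

for w, seg, f in zip(weights, segments, files):
    print(w, seg, f)    # one element from each list per iteration

for i, pair in enumerate(zip(weights, segments)):
    print(i, pair)      # 0 ('w_rgb.pth', 8)  then  1 ('w_flow.pth', 16)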
weights_list = args.weights.split(',')
test_segments_list = [int(s) for s in args.test_segments.split(',')]
assert len(weights_list) == len(test_segments_list)  # both have length 1 here

if args.coeff is None:
    coeff_list = [1] * len(weights_list)
else:
    coeff_list = [float(c) for c in args.coeff.split(',')]

if args.test_list is not None:
    test_file_list = args.test_list.split(',')
else:
    test_file_list = [None] * len(weights_list)

data_iter_list = []
net_list = []
modality_list = []

total_num = None
for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
    is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)  # parse the model's shift settings
    if 'RGB' in this_weights:
        modality = 'RGB'
    else:
        modality = 'Flow'
    this_arch = this_weights.split('TSM_')[1].split('_')[2]  # e.g. resnet50
    modality_list.append(modality)
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset, modality)  # number of classes, train list, val list, etc.
    print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))

    # Build the net and load its weights.
    net = TSN(num_class, this_test_segments if is_shift else 1, modality,
              base_model=this_arch,
              consensus_type=args.crop_fusion_type,
              img_feature_dim=args.img_feature_dim,
              pretrain=args.pretrain,
              is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
              non_local='_nl' in this_weights,
              )
    print(net)
    if 'tpool' in this_weights:
        from first.ops.temporal_shift import make_temporal_pool
        make_temporal_pool(net.base_model, this_test_segments)  # since DataParallel
    checkpoint = torch.load(this_weights)
    checkpoint = checkpoint['state_dict']

    # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}

    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)
    # Model loading ends here; see my other post if anything in it is unclear.

    input_size = net.scale_size if args.full_res else net.input_size

    # Choose the test-time cropping scheme; a brief overview to build intuition:
    if args.test_crops == 1:
        # 1: scale first, then keep only the center crop of the target size
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif args.test_crops == 3:  # do not flip, so only 3 crops
        # 3: scale, then keep the left, right, and center crops; no flipping, so each frame becomes 3
        cropping = torchvision.transforms.Compose([
            GroupFullResSample(input_size, net.scale_size, flip=False)
        ])
    elif args.test_crops == 5:  # do not flip, so only 5 crops
        # 5: scale, then keep the four corners plus the center; no flipping, so each frame becomes 5
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size, flip=False)
        ])
    elif args.test_crops == 10:
        # 10: the same five crops plus their horizontal flips, so each frame becomes 10
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size)
        ])
    else:
        raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(args.test_crops))

    # Define the data_loader. data can be read as [K, batch_size, test_crops, test_segments, 224, 224, 3];
    # the key point is that K drives the outer loop below. The details do not affect reading the code.
    # dense_sample: densely sample 10 clips per video, each frame then expanded by the crops above (10 -> 10 * test_crops)
    # twice_sample: sample twice, each frame then expanded by the crops above (2 -> 2 * test_crops)
    data_loader = torch.utils.data.DataLoader(
        TSNDataSet(root_path, test_file if test_file is not None else val_list,
                   num_segments=this_test_segments,
                   new_length=1 if modality == "RGB" else 5,
                   modality=modality,
                   image_tmpl=prefix,
                   test_mode=True,
                   remove_missing=len(weights_list) == 1,
                   transform=torchvision.transforms.Compose([
                       cropping,
                       Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
                       GroupNormalize(net.input_mean, net.input_std),
                   ]),
                   dense_sample=args.dense_sample, twice_sample=args.twice_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
    )

    # Set up GPUs.
    if args.gpus is not None:
        devices = [args.gpus[i] for i in range(args.workers)]
    else:
        devices = list(range(args.workers))

    # Data parallelism.
    net = torch.nn.DataParallel(net.cuda())
    net.eval()

    # enumerate(data_loader) yields tuples of the form (index, (data, label)).
    data_gen = enumerate(data_loader)

    if total_num is None:
        total_num = len(data_loader.dataset)
    else:
        assert total_num == len(data_loader.dataset)

    data_iter_list.append(data_gen)  # data part: [1, K, batch_size, test_crops, test_segments, 224, 224, 3]
    net_list.append(net)
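Before moving on to testing, note how these sampling options multiply the number of views per video. This tiny helper (written here purely for illustration; it is not in the repo) mirrors the num_crop bookkeeping you will see in eval_video below:

def views_per_video(test_crops, dense_sample=False, twice_sample=False):
    num_crop = test_crops
    if dense_sample:
        num_crop *= 10  # 10 clips per video under dense sampling
    if twice_sample:
        num_crop *= 2   # two sampling passes
    return num_crop

print(views_per_video(3, dense_sample=True))   # 30 views per video
print(views_per_video(1, twice_sample=True))   # 2 views per video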
3. Testing:

def eval_video(video_data, net, this_test_segments, modality):
    net.eval()
    with torch.no_grad():
        i, data, label = video_data
        batch_size = label.numel()  # numel() returns the number of elements in the tensor
        num_crop = args.test_crops  # here the test_crops value comes into play
        if args.dense_sample:
            num_crop *= 10  # 10 clips for testing when using dense sample -- by now it should be clear why
        if args.twice_sample:
            num_crop *= 2

        if modality == 'RGB':
            length = 3
        elif modality == 'Flow':
            length = 10
        elif modality == 'RGBDiff':
            length = 18
        else:
            raise ValueError("Unknown modality " + modality)

        data_in = data.view(-1, length, data.size(2), data.size(3))
        if is_shift:  # with the temporal shift module, reshape the input into the following format
            data_in = data_in.view(batch_size * num_crop, this_test_segments, length, data_in.size(2), data_in.size(3))
        rst = net(data_in)  # [batch_size * num_crop, 174]
        rst = rst.reshape(batch_size, num_crop, -1).mean(1)  # average over crops: [batch_size, 174]

        if args.softmax:
            # take the softmax to normalize the output to probability
            rst = F.softmax(rst, dim=1)  # row-wise softmax

        rst = rst.data.cpu().numpy().copy()

        if net.module.is_shift:
            rst = rst.reshape(batch_size, num_class)
        else:
            rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

        return i, rst, label

proc_start_time = time.time()
max_num = args.max_num if args.max_num > 0 else total_num

output = []  # collects one [prediction, label] pair per video
top1 = AverageMeter()
top5 = AverageMeter()

for i, data_label_pairs in enumerate(zip(*data_iter_list)):
    # * unpacks the list of iterators; K controls the number of outer-loop iterations, K = total_num / batch_size
    with torch.no_grad():
        if i >= max_num:
            break
        this_rst_list = []
        this_label = None
        for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list):
            rst = eval_video((i, data, label), net, n_seg, modality)  # eval_video returns a tuple
            this_rst_list.append(rst[1])
            # rst[1] is the prediction part. I stared at this for a long time: this rst shares its
            # name with the rst inside eval_video, but they are not the same object.
            this_label = label
        assert len(this_rst_list) == len(coeff_list)  # 1 == 1
        for i_coeff in range(len(this_rst_list)):
            this_rst_list[i_coeff] *= coeff_list[i_coeff]
        ensembled_predict = sum(this_rst_list) / len(this_rst_list)  # sum adds the arrays elementwise: [batch_size, 174]

        for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
            output.append([p[None, ...], g])
            # output becomes [[pred[0], label], [pred[1], label], ...], total_num elements in all,
            # each of shape [[1, 174], 1]; note that batch_size elements are appended per iteration
        cnt_time = time.time() - proc_start_time
        prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))  # see the accuracy function below
        top1.update(prec1.item(), this_label.numel())  # see AverageMeter below
        top5.update(prec5.item(), this_label.numel())
        if i % 20 == 0:
            print('video {} done, total {}/{}, average {:.3f} sec/video, '
                  'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num,
                                                              float(cnt_time) / (i + 1) / args.batch_size, top1.avg, top5.avg))

video_pred = [np.argmax(x[0]) for x in output]  # see where output is built above!
video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]
video_labels = [x[1] for x in output]
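To see what rst.reshape(batch_size, num_crop, -1).mean(1) is doing, here is a toy example with made-up shapes:

import torch

batch_size, num_crop, num_class = 2, 3, 4
rst = torch.randn(batch_size * num_crop, num_class)  # one score row per (video, crop) pair

# group the rows by video, then average the scores over each video's crops
avg = rst.reshape(batch_size, num_crop, -1).mean(1)
print(avg.shape)  # torch.Size([2, 4]): one averaged score row per video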
4. Related functions and classes:
The function you must know here is PyTorch's topk, which is extremely handy:
a, b = data.topk(maxk, dim, True, True) returns a, the maxk largest elements of data along dimension dim, and b, their indices; with dim=1 the top-k is taken within each row, with dim=0 within each column.
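A quick demo of topk on a small tensor:

import torch

data = torch.tensor([[0.1, 0.7, 0.2],
                     [0.5, 0.3, 0.2]])
values, indices = data.topk(2, dim=1, largest=True, sorted=True)
print(values)   # tensor([[0.7000, 0.2000], [0.5000, 0.3000]])
print(indices)  # tensor([[1, 2], [0, 1]])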
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):  # maintain a running average
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)  # e.g. 72
    _, pred = output.topk(maxk, 1, True, True)  # top maxk values along dim 1; the first return is the values, the second their indices
    pred = pred.t()  # transpose to [5, 72]
    correct = pred.eq(target.view(1, -1).expand_as(pred))  # [5, 72] tensor of 0s and 1s
    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)  # .view(-1) flattens to a row vector
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def parse_shift_option_from_log_name(log_name):
    if 'shift' in log_name:
        strings = log_name.split('_')
        for i, s in enumerate(strings):
            if 'shift' in s:
                break
        return True, int(strings[i].replace('shift', '')), strings[i + 1]
    else:
        return False, None, None

# For reference only; corrections and discussion are welcome.
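A small usage example of the accuracy function above (scores and labels made up for illustration):

import torch

output = torch.tensor([[0.1, 0.8, 0.3],
                       [0.6, 0.1, 0.3],
                       [0.2, 0.3, 0.6]])  # [batch, num_class] scores
target = torch.tensor([1, 2, 2])          # ground-truth class indices

prec1, prec2 = accuracy(output, target, topk=(1, 2))
print(prec1.item(), prec2.item())  # ~66.67 100.0: rows 0 and 2 are right at top-1, all three within top-2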