Object Detection with FasterRCNN+ResNet50+FPN in TorchVision
TorchVision provides a pretrained Faster R-CNN model built on a ResNet-50-FPN backbone. The weights are hosted at https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth and are downloaded by the fasterrcnn_resnet50_fpn function, which is implemented in torchvision/models/detection/faster_rcnn.py. By default the downloaded file is stored in ~/.cache/torch/hub/checkpoints on Ubuntu and in C:\Users\spring\.cache\torch\hub\checkpoints on Windows, where spring is the user name.
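The cache location above is not hard-coded: torch.hub derives it from environment variables, and `torch.hub.get_dir()` returns the resulting hub directory. The sketch below reproduces the default lookup with the standard library only, assuming `torch.hub.set_dir()` has not been called: `TORCH_HOME` wins, otherwise `$XDG_CACHE_HOME/torch`, otherwise `~/.cache/torch`, with checkpoints under `hub/checkpoints`. The function name `default_checkpoint_path` is hypothetical.

```python
import os

# File name taken from the checkpoint URL quoted above.
CHECKPOINT_URL = (
    "https://download.pytorch.org/models/"
    "fasterrcnn_resnet50_fpn_coco-258fb6c6.pth"
)


def default_checkpoint_path(url, env=None):
    """Hypothetical helper mirroring torch.hub's default cache resolution.

    TORCH_HOME wins; otherwise $XDG_CACHE_HOME/torch; otherwise ~/.cache/torch.
    Downloaded weights land in <torch_home>/hub/checkpoints/<file name>.
    """
    env = os.environ if env is None else env
    cache_root = env.get("XDG_CACHE_HOME", os.path.join("~", ".cache"))
    torch_home = env.get("TORCH_HOME", os.path.join(cache_root, "torch"))
    torch_home = os.path.expanduser(torch_home)  # "~" -> home dir (C:\Users\<name> on Windows)
    return os.path.join(torch_home, "hub", "checkpoints", os.path.basename(url))


if __name__ == "__main__":
    print(default_checkpoint_path(CHECKPOINT_URL))
```

Setting `TORCH_HOME` before the first call to fasterrcnn_resnet50_fpn is therefore one way to move the download to another disk.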
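The model input is a list of tensors, each of shape [c,h,w] and holding one image, with pixel values scaled to the range [0,1]. Different images may have different sizes, i.e. the model accepts inputs of non-fixed size.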
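To make the expected layout concrete, here is a numpy sketch, assuming the image comes from OpenCV (uint8, shape (H, W, 3), BGR order). It produces the (C, H, W) RGB float layout in [0, 1] that `transforms.ToTensor()` yields after a BGR-to-RGB swap; wrap the result with `torch.from_numpy(...)` before passing it to the model. The second function shows, as far as I know, how torchvision's internal transform picks a resize factor for arbitrary input sizes (defaults `min_size=800`, `max_size=1333`). Both function names are hypothetical.

```python
import numpy as np


def bgr_hwc_to_model_input(img_bgr):
    """Convert an OpenCV-style uint8 image (H, W, 3), BGR, into a float32
    array of shape (C, H, W), RGB order, values in [0, 1]."""
    if img_bgr.ndim != 3 or img_bgr.shape[2] != 3 or img_bgr.dtype != np.uint8:
        raise ValueError("expected a uint8 array of shape (H, W, 3)")
    rgb = img_bgr[:, :, ::-1]              # BGR -> RGB
    chw = np.transpose(rgb, (2, 0, 1))     # (H, W, C) -> (C, H, W)
    return np.ascontiguousarray(chw, dtype=np.float32) / 255.0


def resize_scale(h, w, min_size=800, max_size=1333):
    """Scale factor applied before the backbone: scale the short side to
    min_size unless that would push the long side past max_size."""
    return min(min_size / min(h, w), max_size / max(h, w))
```

Because each image gets its own scale factor, images of different sizes can share one input list; the predicted boxes are mapped back to the original image coordinates afterwards.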
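The model's behavior depends on whether it is in training mode or evaluation mode:
(1). During training, the model expects both the input tensors and targets (a list of dicts, one per image), each dict containing boxes and labels.
boxes is a FloatTensor[N,4], where N is the number of ground-truth boxes in that image and the 4 values are [x1,y1,x2,y2], the top-left and bottom-right corners of each ground-truth box. The coordinates must be valid, i.e. 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
labels is an Int64Tensor[N] holding the class label of each ground-truth box.
(2). During inference, the model takes only the input tensors and returns post-processed predictions of type List[Dict[Tensor]], one dict per input image.
Besides boxes and labels, each dict also contains scores.
scores is a Tensor[N] holding the score of each prediction, sorted in descending order.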
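Both formats can be exercised without a GPU or downloaded weights. The numpy sketch below (function names hypothetical, data made up) validates one training target against the box constraints listed above and filters one inference result by score, the same boolean-mask step the test code applies with `scores > 0.5`. As far as I know, torchvision itself only rejects boxes with non-positive width or height during training, so the full bounds check here is stricter than the library's.

```python
import numpy as np


def check_target(boxes, labels, height, width):
    """Validate one training target: boxes (N, 4) as [x1, y1, x2, y2], labels (N,) int64.

    Enforces 0 <= x1 < x2 <= width and 0 <= y1 < y2 <= height; raises ValueError otherwise.
    """
    boxes = np.asarray(boxes, dtype=np.float32)
    labels = np.asarray(labels)
    if boxes.ndim != 2 or boxes.shape[1] != 4:
        raise ValueError(f"boxes must have shape (N, 4), got {boxes.shape}")
    if labels.shape != (boxes.shape[0],) or labels.dtype != np.int64:
        raise ValueError("labels must be an int64 array with one entry per box")
    x1, y1, x2, y2 = boxes.T
    ok = (0 <= x1) & (x1 < x2) & (x2 <= width) & (0 <= y1) & (y1 < y2) & (y2 <= height)
    if not ok.all():
        raise ValueError(f"invalid boxes at rows {np.flatnonzero(~ok).tolist()}")


def filter_prediction(pred, score_thresh=0.5):
    """Keep the detections of one output dict whose score exceeds score_thresh."""
    keep = pred["scores"] > score_thresh
    return {key: value[keep] for key, value in pred.items()}


# Made-up target and prediction for one 640x480 image, for illustration only.
boxes = np.array([[10, 20, 110, 220], [300, 40, 360, 90]], dtype=np.float32)
labels = np.array([1, 3], dtype=np.int64)
check_target(boxes, labels, height=480, width=640)

pred = {
    "boxes": np.array([[10, 18, 112, 221], [298, 41, 361, 88], [5, 5, 50, 60]], dtype=np.float32),
    "labels": np.array([1, 3, 1], dtype=np.int64),
    "scores": np.array([0.97, 0.83, 0.21], dtype=np.float32),
}
kept = filter_prediction(pred)
print(kept["labels"])
```

The same mask expression works unchanged on the torch tensors returned by the model, since tensors support boolean-mask indexing too.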
The model was trained on the COCO dataset; for an introduction to COCO see: https://blog.csdn.net/fengbingchun/article/details/121308708
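One practical detail of COCO is that its 80 detection categories use sparse ids between 1 and 90, and the model's labels are these raw ids. The sketch below lists the ids as I know them from the COCO 2017 annotation files (worth cross-checking against `instances_*.json`) and builds a mapping to contiguous indices 0..79, which some other tools expect. All names are hypothetical.

```python
# Category ids used by COCO 2017 detection annotations (80 classes, 1..90 with gaps).
COCO_IDS = [
    *range(1, 12), *range(13, 26), 27, 28, *range(31, 45),
    *range(46, 66), 67, 70, *range(72, 83), *range(84, 91),
]

# Sparse COCO id -> contiguous index 0..79, and the reverse mapping.
ID_TO_INDEX = {cat_id: idx for idx, cat_id in enumerate(COCO_IDS)}
INDEX_TO_ID = {idx: cat_id for cat_id, idx in ID_TO_INDEX.items()}

# Ids 1..91 that have a name in the original category list but no annotations.
UNUSED_IDS = sorted(set(range(1, 92)) - set(COCO_IDS))

print(len(COCO_IDS), UNUSED_IDS)
```

Counting background (id 0), the unused ids and the 80 real classes gives 1 + 11 + 80 = 92, the length of the name list used in the test code.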
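FPN stands for Feature Pyramid Networks, a multi-scale feature extraction design for object detection that merges coarse, semantically strong feature maps with finer, higher-resolution ones. For an introduction to FPN see: https://blog.csdn.net/fengbingchun/article/details/87359191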
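The core of FPN is its top-down pathway: each backbone level is projected to a common channel count by a 1x1 "lateral" convolution, and the coarser merged map is upsampled by 2 and added to it. The numpy sketch below shows only that merge, with matrices standing in for the 1x1 convolutions and without the 3x3 output convolutions of the real network. It assumes each level is exactly half the size of the previous one. In torchvision's fasterrcnn_resnet50_fpn the output has, as far as I know, 256 channels per level plus an extra max-pooled level; none of that is modelled here.

```python
import numpy as np


def fpn_top_down(features, lateral_weights):
    """Simplified FPN top-down pathway (illustrative only).

    features: backbone maps ordered fine -> coarse, each (C_i, H_i, W_i),
              every map exactly half the spatial size of the previous one.
    lateral_weights: (D, C_i) matrices standing in for the 1x1 lateral convs.
    Returns merged maps, each (D, H_i, W_i), ordered fine -> coarse.
    """
    # A 1x1 convolution is a per-pixel matrix multiply over the channel axis.
    laterals = [np.einsum("dc,chw->dhw", w, f) for w, f in zip(lateral_weights, features)]
    outputs = [laterals[-1]]  # the coarsest level starts the pathway
    for lat in reversed(laterals[:-1]):
        upsampled = outputs[0].repeat(2, axis=1).repeat(2, axis=2)  # nearest-neighbour x2
        outputs.insert(0, lat + upsampled)  # merge by element-wise addition
    return outputs


rng = np.random.default_rng(0)
feats = [rng.standard_normal((4, 8, 8)), rng.standard_normal((8, 4, 4)), rng.standard_normal((16, 2, 2))]
weights = [rng.standard_normal((5, c)) for c in (4, 8, 16)]
print([o.shape for o in fpn_top_down(feats, weights)])
```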
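ResNet (Residual Networks) was proposed to address the "degradation" problem of deep neural networks, where simply stacking more layers makes even the training error worse. The 50 in ResNet-50 refers to its 50 weight layers. For an introduction to ResNet see: https://blog.csdn.net/fengbingchun/article/details/114167581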
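The idea behind the fix for degradation is the shortcut connection: a block computes y = relu(F(x) + x), so the layers only have to learn the residual F(x). If F's weights are zero the block reduces to an identity mapping (for non-negative inputs), so adding blocks need not make things worse. The numpy sketch below uses fully connected layers in place of the convolutions and omits batch norm; it illustrates the shortcut only, not the actual ResNet-50 bottleneck block.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def residual_block(x, w1, w2):
    """y = relu(F(x) + x) with F(x) = w2 @ relu(w1 @ x).

    Dense layers stand in for the convolutions of a real ResNet block,
    and batch normalization is omitted; the shortcut is the point.
    """
    return relu(w2 @ relu(w1 @ x) + x)


x = np.array([0.5, 2.0, 0.0, 1.5])
zeros = np.zeros((4, 4))
print(residual_block(x, zeros, zeros))  # zero residual: the input passes through
```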
Faster R-CNN is an object detection algorithm that combines an RPN (Region Proposal Network) with Fast R-CNN. For an introduction to Faster R-CNN see: https://blog.csdn.net/fengbingchun/article/details/87195597
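The RPN scores and refines a fixed set of anchor boxes placed at every feature-map cell. The numpy sketch below builds such anchors, modelled on how I understand torchvision's AnchorGenerator: aspect ratio taken as height/width, each anchor keeping an area of about size*size, then shifted by the stride of the feature map. For fasterrcnn_resnet50_fpn the defaults are, as far as I know, one size per FPN level (32, 64, 128, 256, 512) with aspect ratios (0.5, 1.0, 2.0). The function names are hypothetical.

```python
import numpy as np


def base_anchors(size, aspect_ratios=(0.5, 1.0, 2.0)):
    """Zero-centred anchors [x1, y1, x2, y2] for one size; ratio = height / width."""
    ratios = np.asarray(aspect_ratios, dtype=np.float64)
    h_ratios = np.sqrt(ratios)
    w_ratios = 1.0 / h_ratios
    ws = w_ratios * size
    hs = h_ratios * size
    return np.round(np.stack([-ws, -hs, ws, hs], axis=1) / 2)


def shift_anchors(anchors, feat_h, feat_w, stride):
    """Place the base anchors at every feature-map cell (stride in input pixels)."""
    xs = np.arange(feat_w) * stride
    ys = np.arange(feat_h) * stride
    sy, sx = np.meshgrid(ys, xs, indexing="ij")
    shifts = np.stack([sx, sy, sx, sy], axis=-1).reshape(-1, 1, 4)
    return (shifts + anchors[None, :, :]).reshape(-1, 4)


base = base_anchors(32)
all_anchors = shift_anchors(base, feat_h=2, feat_w=3, stride=16)
print(base)
print(all_anchors.shape)
```

The RPN then predicts, for every one of these anchors, an objectness score and box offsets; the top-scoring refined anchors become the proposals handed to the Fast R-CNN head.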
The test code is as follows:
import torch
from torchvision import models
from torchvision import transforms
import cv2

'''
Note: installing opencv in a conda pytorch environment
windows: conda install opencv # python=3.8.8, opencv=4.0.1
ubuntu: pip3 install opencv-python # python=3.7.11, opencv=4.5.4
'''

images_path = "../../data/image/"
images_name = ["1.jpg", "2.jpg", "4.jpg"]
images_data = [] # opencv images (BGR), used for drawing the results
tensor_data = [] # pytorch tensors, fed to the model

transform = transforms.Compose([transforms.ToTensor()])
for name in images_name:
    img = cv2.imread(images_path + name)
    if img is None:
        raise FileNotFoundError(images_path + name)
    print(f"name: {images_path+name}, opencv image shape: {img.shape}") # (h,w,c)
    images_data.append(img)

    # OpenCV loads images as BGR, while the model was trained on RGB images
    tensor = transform(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) # tensor image scaled to [0., 1.]
    print(f"tensor shape: {tensor.shape}, max: {torch.max(tensor)}, min: {torch.min(tensor)}") # (c,h,w)
    tensor_data.append(tensor)

# reference: torchvision/models/detection/faster_rcnn.py
# build a Faster R-CNN model with a ResNet-50-FPN (Feature Pyramid Networks) backbone
# (torchvision 0.13+ deprecates pretrained=True in favour of weights="DEFAULT")
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
#print(model) # prints the model structure
model.eval() # inference mode
with torch.no_grad(): # no gradients are needed for inference
    predictions = model(tensor_data) # list of dicts: boxes (FloatTensor[N, 4]), labels (Int64Tensor[N]), scores (Tensor[N])
#print(predictions)

coco_labels_name = [
    "unlabeled", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
    "traffic light", "fire hydrant", "street sign", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
    "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "hat", "backpack", "umbrella", "shoe",
    "eye glasses", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports_ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "plate", "wine glass", "cup", "fork", "knife",
    "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot_dog", "pizza",
    "donut", "cake", "chair", "couch", "potted plant", "bed", "mirror", "dining table", "window", "desk",
    "toilet", "door", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "blender", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
    "toothbrush", "hair brush"] # len = 92

for x in range(len(predictions)):
    pred = predictions[x]
    scores = pred["scores"]
    mask = scores > 0.5 # keep only detections with a score greater than 0.5
    boxes = pred["boxes"][mask].int().numpy() # [x1, y1, x2, y2]
    labels = pred["labels"][mask]
    scores = scores[mask]
    print(f"prediction: boxes:{boxes}, labels:{labels}, scores:{scores}")

    img = images_data[x]
    for idx in range(len(boxes)):
        x1, y1, x2, y2 = (int(v) for v in boxes[idx])
        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0))
        text = f"{coco_labels_name[int(labels[idx])]} {scores[idx].item():.2f}"
        cv2.putText(img, text, (x1 + 10, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)

    cv2.imshow("image", img)
    cv2.waitKey(1000)
    cv2.imwrite(images_path + "result_" + images_name[x], img)

print("test finish")
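Notes:
(1). The input image can be either a color or a grayscale image, i.e. with 3 or 1 channels.
(2). The input image size is unrestricted, and the images in one batch may differ in size. The model rescales each image internally (by default so that the shorter side becomes 800 and the longer side does not exceed 1333) and maps the predicted boxes back to the original image coordinates.
(3). Input images must be scaled to [0., 1.]; the mean/std normalization (mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]) is applied inside the model.
(4). The output only shows detections with a score greater than 0.5.
(5). The test code uses 92 class names rather than 80, since 92 = 1 + 11 + 80: 1 is id 0 with label name unlabeled, 11 are labels removed from COCO (e.g. street sign), and 80 are the actual labels (e.g. person). For details see: https://github.com/nightrome/cocostuff/blob/master/labels.md
(6). The results contain redundant detection boxes, which can be removed with the NMS (Non-Maximum Suppression) algorithm. Note that the model already applies NMS per class internally (box_nms_thresh=0.5 by default), so the remaining overlaps are typically boxes with different labels, or boxes of the same label whose IoU is at or below that threshold.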
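A class-agnostic NMS pass over the filtered detections is one way to remove such leftovers. torchvision already ships `torchvision.ops.nms(boxes, scores, iou_threshold)` for tensors, which is the natural choice in the test code. The numpy sketch below only shows what greedy NMS does: repeatedly keep the highest-scoring box and drop the remaining boxes whose IoU with it exceeds the threshold. The data are made up.

```python
import numpy as np


def box_iou(box, boxes):
    """IoU between one box and an (M, 4) array of boxes, all [x1, y1, x2, y2]."""
    ix1 = np.maximum(box[0], boxes[:, 0])
    iy1 = np.maximum(box[1], boxes[:, 1])
    ix2 = np.minimum(box[2], boxes[:, 2])
    iy2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def nms(boxes, scores, iou_thresh=0.5):
    """Greedy, class-agnostic NMS; returns kept indices, highest score first."""
    order = np.argsort(-scores, kind="stable")
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(best)
        rest = order[1:]
        order = rest[box_iou(boxes[best], boxes[rest]) <= iou_thresh]
    return np.array(keep, dtype=np.int64)


# Two heavily overlapping boxes (e.g. different labels on one object) plus one separate box.
boxes = np.array([[10, 10, 100, 200], [15, 12, 98, 190], [300, 40, 360, 90]], dtype=np.float32)
scores = np.array([0.98, 0.61, 0.70], dtype=np.float32)
print(nms(boxes, scores))
```

Because this pass ignores labels, it will also merge genuinely different objects that overlap strongly; raising `iou_thresh` makes it more conservative.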
The results are shown below (the original test images come from the Internet):
GitHub: fengbingchun/PyTorch_Test (PyTorch's usage)