from __future__ import division

# Project-local star imports stay first and in their original order: the
# explicit imports below were originally placed after them, so keeping this
# order preserves which definition wins on any name collision.
from models import *
from utils.utils import *
from utils.datasets import *
from utils.augmentations import *
from utils.transforms import *

import argparse
import datetime
import os
import random
import socket
import sys
import time
from time import sleep

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from flask import Flask, make_response, render_template, request
from matplotlib.ticker import NullLocator
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
# Best-effort IPv4 address of this machine.
# socket.gethostbyname raises socket.gaierror when the local hostname does not
# resolve (common in containers); fall back to loopback rather than aborting
# module import over a value that is only informational.
# NOTE: app.run() is called without host=..., so this value is currently unused.
try:
    myhost = socket.gethostbyname(socket.gethostname())
except socket.gaierror:
    myhost = '127.0.0.1'
app = Flask(__name__)

# Image directory; not referenced by any active route.
igpath = '/home/heziyi/pic/'


@app.route('/', methods=['GET', 'POST'])  # the methods argument lets this view accept both GET and POST
def home():
    """Render the ``index.html`` template."""
    return render_template('index.html')


@app.route('/img/', methods=['GET'])
def display():
    """Return ``he_21.png`` from the app's static folder as a raw image response.

    Returns:
        A Flask response whose body is the PNG bytes, with ``Content-Type: image/png``.

    Raises:
        FileNotFoundError: if ``he_21.png`` is not present in the static folder.
    """
    # Resolve against the app's static folder (the directory the template's
    # /static/... URLs are served from), not the filesystem root "/static".
    image_path = os.path.join(app.static_folder, 'he_21.png')
    with open(image_path, 'rb') as f:
        image = f.read()
    # No request.method check: Flask also dispatches HEAD to this GET view, and
    # a view that falls through and returns None is rejected by Flask.
    response = make_response(image)
    response.headers['Content-Type'] = 'image/png'  # the file is a PNG, not JPEG
    return response
if __name__ == "__main__":
    # Run YOLO detection over every image in --image_folder, save annotated
    # copies to ../output, then start the Flask server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_folder", type=str, default="data/custom/dd", help="path to dataset")
    parser.add_argument("--model_def", type=str, default="config/yolov3-custom.cfg", help="path to model definition file")
    parser.add_argument("--weights_path", type=str, default="checkpoints/ckpt_88.pth", help="path to weights file")
    parser.add_argument("--class_path", type=str, default="data/custom/classes.names", help="path to class label file")
    parser.add_argument("--conf_thres", type=float, default=0.8, help="object confidence threshold")
    parser.add_argument("--nms_thres", type=float, default=0.4, help="iou threshold for non-maximum suppression")
    parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
    parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    # Parsed but not read anywhere below; --weights_path is what gets loaded.
    parser.add_argument("--checkpoint_model", type=str, default="checkpoints/ckpt_88.pth", help="path to checkpoint model")
    opt = parser.parse_args()
    print(opt)

    # Inference is pinned to the CPU. Inputs are moved to this same device in
    # the loop below, so switching to CUDA only requires changing this line.
    device = torch.device("cpu")
    output_dir = "../output"
    os.makedirs(output_dir, exist_ok=True)

    # Set up model
    model = Darknet(opt.model_def, img_size=opt.img_size).to(device)
    if opt.weights_path.endswith(".weights"):
        # Load darknet weights
        model.load_darknet_weights(opt.weights_path)
    else:
        # Load checkpoint weights; map_location remaps tensors onto `device`
        model.load_state_dict(torch.load(opt.weights_path, map_location=device))
    model.eval()  # Set in evaluation mode

    dataloader = DataLoader(
        ImageFolder(opt.image_folder, transform=transforms.Compose([DEFAULT_TRANSFORMS, Resize(opt.img_size)])),
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.n_cpu,
    )

    classes = load_classes(opt.class_path)  # Extracts class labels from file

    imgs = []  # Stores image paths
    img_detections = []  # Stores detections for each image index

    print("\nPerforming object detection:")
    prev_time = time.time()
    for batch_i, (img_paths, input_imgs) in enumerate(dataloader):
        # Cast to float32 on the model's device. Choosing CUDA tensors merely
        # because a GPU exists would mismatch the CPU-pinned model.
        input_imgs = input_imgs.to(device=device, dtype=torch.float32)

        # Get detections
        with torch.no_grad():
            detections = model(input_imgs)
            detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres)

        # Log progress (the interval also includes data-loading time)
        current_time = time.time()
        inference_time = datetime.timedelta(seconds=current_time - prev_time)
        prev_time = current_time
        print("\t+ Batch %d, Inference Time: %s" % (batch_i, inference_time))

        # Save image and detections
        imgs.extend(img_paths)
        img_detections.extend(detections)

    # Bounding-box colors
    cmap = plt.get_cmap("tab20b")
    colors = [cmap(i) for i in np.linspace(0, 1, 20)]

    print("\nSaving images:")
    # Iterate through images and save plot of detections
    for img_i, (path, detections) in enumerate(zip(imgs, img_detections)):
        print("(%d) Image: '%s'" % (img_i, path))

        img = np.array(Image.open(path))
        # Exactly one figure per image, closed explicitly at the end of the
        # iteration; an extra plt.figure() here would leak a figure each time.
        fig, ax = plt.subplots(1)
        ax.imshow(img)

        # Draw bounding boxes and labels of detections
        if detections is not None:
            # Rescale boxes to original image
            detections = rescale_boxes(detections, opt.img_size, img.shape[:2])
            unique_labels = detections[:, -1].cpu().unique()
            n_cls_preds = len(unique_labels)
            # random.sample cannot draw more items than the palette holds, so
            # sample with replacement once there are more classes than colors.
            if n_cls_preds <= len(colors):
                sampled_colors = random.sample(colors, n_cls_preds)
            else:
                sampled_colors = [random.choice(colors) for _ in range(n_cls_preds)]
            label_colors = {int(label): c for label, c in zip(unique_labels.tolist(), sampled_colors)}

            for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
                label = int(cls_pred)
                print("\t+ Label: %s, Conf: %.5f" % (classes[label], cls_conf.item()))

                box_w = x2 - x1
                box_h = y2 - y1
                color = label_colors[label]
                # Create a Rectangle patch
                bbox = patches.Rectangle((x1, y1), box_w, box_h, linewidth=2, edgecolor=color, facecolor="none")
                # Box area in pixels (debug output)
                print(int(box_w) * int(box_h))
                # Add the bbox to the plot
                ax.add_patch(bbox)
                # Add label
                ax.text(
                    x1,
                    y1,
                    s=classes[label],
                    color="white",
                    verticalalignment="top",
                    bbox={"color": color, "pad": 0},
                )

        # Save generated image with detections
        ax.axis("off")
        ax.xaxis.set_major_locator(NullLocator())
        ax.yaxis.set_major_locator(NullLocator())
        # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"); split(".")[0]
        # would truncate it to "a" and let different inputs overwrite each other.
        filename = os.path.splitext(os.path.basename(path))[0]
        output_path = os.path.join(output_dir, f"{filename}.png")
        fig.savefig(output_path, bbox_inches="tight", pad_inches=0.0)
        plt.close(fig)

    app.run()  # Start the Flask server
其中，函數 display() 用於在瀏覽器中輸入地址（/img/）後直接返回圖片：
@app.route('/img/', methods=['GET'])
def display():
    """Return ``he_21.png`` from the app's static folder as a raw image response.

    Returns:
        A Flask response whose body is the PNG bytes, with ``Content-Type: image/png``.

    Raises:
        FileNotFoundError: if ``he_21.png`` is not present in the static folder.
    """
    # Resolve against the app's static folder (the directory the template's
    # /static/... URLs are served from), not the filesystem root "/static".
    image_path = os.path.join(app.static_folder, 'he_21.png')
    with open(image_path, 'rb') as f:
        image = f.read()
    # No request.method check: Flask also dispatches HEAD to this GET view, and
    # a view that falls through and returns None is rejected by Flask.
    response = make_response(image)
    response.headers['Content-Type'] = 'image/png'  # the file is a PNG, not JPEG
    return response
在 if __name__ == "__main__": 代碼塊中，除了最後一行 app.run()（啟動 Flask 服務器）之外，其餘都是深度學習（YOLOv3 目標檢測）的代碼，學過的讀者應該很容易看懂。
前端代碼（模板 index.html）：
<!DOCTYPE html>
<html lang="en">
<head>
    <!-- Page rendered by home(); shows the detection result images from the app's static folder. -->
    <meta charset="UTF-8">
    <title>Title</title>
    <!-- url_for builds the /static/... URL independent of the current page path. -->
    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='Login.css') }}"/>
</head>
<body>
    <h1>檢測結果</h1>
    <h1>hello!!!!!!!!!!!!</h1>
    <h2>this is the detection result</h2>
    <img src="{{ url_for('static', filename='he_21.png') }}" alt="detection result he_21">
    <img src="{{ url_for('static', filename='he_14.png') }}" alt="detection result he_14">
    <img src="{{ url_for('static', filename='he_4.png') }}" alt="detection result he_4">
</body>
</html>