使用tensorflow serving部署keras模型(tensorflow 2.0.0)
點擊上方“AI搞事情”關注我們
內容轉載自知乎:https://zhuanlan.zhihu.com/p/96917543
Justin ho
〉
Tensorflow 2.0.0出來后,1.x版本的API有些已經改變,19年年初寫的這一篇《TensorFlow Serving + Docker + Tornado機器學習模型生產級快速部署》?文章,在tf 2.0.0版本里面有較大的變動,另外Tensorflow官方也推薦大家使用tf.keras,因此本文將會教大家如何使用tensorflow serving部署keras模型,適用tensorflow 2.0.0以后的版本。注:下面“tensorflow serving”將會簡寫為“tfs”。
一、導出Keras模型
keras模型訓練完畢后,一般我們都會使用model.save(filepath)儲存為h5文件,包含模型的結構和參數,而我們需要把這個h5文件導出為tensorflow serving所需要的模型格式:
from keras import backend as K from keras.models import load_model import tensorflow as tf# 首先使用tf.keras的load_model來導入模型h5文件 model_path = 'v7_resnet50_19-0.9068-0.8000.h5' model = tf.keras.models.load_model(model_path, custom_objects=dependencies) model.save('models/resnet/', save_format='tf') # 導出tf格式的模型文件注意,這里要使用tf.keras.models.load_model來導入模型,不能使用keras.models.load_model,只有tf.keras.models.load_model能導出成tfs所需的模型文件。導出的文件路徑結構如下:
. └── 0├── assets├── saved_model.pb└── variables├── variables.data-00000-of-00002├── variables.data-00001-of-00002└── variables.index最大的改變是,以往導出keras模型需要寫一大段定義builder的代碼,如這篇文章《keras、tensorflow serving踩坑記》?的那樣,現在只需使用簡單的model.save就可以導出了。當然以前這種寫法應該還能繼續使用,能自定義signature、input、output的名稱,但為了簡單起見(用keras不就是為了簡單嘛),直接一鍵導出更舒服。
導出以后,我們在終端執行以下命令,查看模型的signature、input、output的名稱,后面要用到:
saved_model_cli show --dir 0/ --all# 將會輸出: MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:signature_def['__saved_model_init_op']:The given SavedModel SignatureDef contains the following input(s):The given SavedModel SignatureDef contains the following output(s):outputs['__saved_model_init_op'] tensor_info:dtype: DT_INVALIDshape: unknown_rankname: NoOpMethod name is:signature_def['serving_default']:The given SavedModel SignatureDef contains the following input(s):inputs['input_1'] tensor_info:dtype: DT_FLOATshape: (-1, 224, 224, 3)name: serving_default_input_1:0The given SavedModel SignatureDef contains the following output(s):outputs['fc2'] tensor_info:dtype: DT_FLOATshape: (-1, 4)name: StatefulPartitionedCall:0Method name is: tensorflow/serving/predict可以看到,signature name為“serving_default”,input name為“input_1”,output name為“fc2”。記下來,一會用到。
二、Docker部署模型
一律使用docker來部署你的模型,如果還不知道docker是什么,不知道怎么用docker來拉取tfs的鏡像,請查閱我之前的文章:
Justin ho:TensorFlow Serving + Docker + Tornado機器學習模型生產級快速部署zhuanlan.zhihu.com
這里默認你已經會拉取tfs模型到本地,現在執行以下容器啟動命令:
sudo nvidia-docker run -p 8500:8500 \-v /home/projects/resnet/weights/:/models \--name resnet50 \-itd --entrypoint=tensorflow_model_server tensorflow/serving:2.0.0-gpu \--port=8500 --per_process_gpu_memory_fraction=0.5 \--enable_batching=true --model_name=resnet --model_base_path=/models/season &這里涉及的參數意義一律看上面那篇文章,這里不解釋了。
三、請求客戶端
模型部署起來后,我們要寫一個grpc客戶端來請求模型,代碼參考:
from tensorflow_serving.apis import predict_pb2 from tensorflow_serving.apis import prediction_service_pb2_grpc import grpcdef request_server(img_resized, server_url):'''用于向TensorFlow Serving服務請求推理結果的函數。:param img_resized: 經過預處理的待推理圖片數組,numpy array,shape:(h, w, 3):param server_url: TensorFlow Serving的地址加端口,str,如:'0.0.0.0:8500' :return: 模型返回的結果數組,numpy array'''# Request.channel = grpc.insecure_channel(server_url)stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)request = predict_pb2.PredictRequest()request.model_spec.name = "resnet" # 模型名稱,啟動容器命令的model_name參數request.model_spec.signature_name = "serving_default" # 簽名名稱,剛才叫你記下來的# "input_1"是你導出模型時設置的輸入名稱,剛才叫你記下來的request.inputs["input_1"].CopyFrom(tf.make_tensor_proto(img_resized, shape=[1, ] + list(img_resized.shape)))response = stub.Predict(request, 5.0) # 5 secs timeoutreturn np.asarray(response.outputs["fc2"].float_val) # fc2為輸出名稱,剛才叫你記下來的tensorflow 2.0.0請使用以上這段代碼,之前那篇tfs的部署文章里面的api已經變了。
我們測試一下,讀取一張圖片,發送請求到tfs:
from PIL import Image import numpy as npimgpath = '20171101110450_48901.jpg' x = Image.open(imgpath) x = np.array(x).astype('float32') x = (x - 128.) / 128.# grpc地址及端口,為你容器所在機器的ip + 容器啟動命令里面設置的port server_url = '0.0.0.0:8500' request_server(x, server_url)我們將會得到(我這里的resnet只輸出4類多標簽結果):
array([0.58116215, 0.04240507, 0.74790353, 0.1388033 ])
總結
以上是生活随笔為你收集整理的使用tensorflow serving部署keras模型(tensorflow 2.0.0)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【推荐】一款快速预览神器:QuickLo
- 下一篇: TensorFlow Serving +