在Android设备部署PyTorch模型
Pytorch Mobile Android
- Demo 1 HelloWorldApp
- 1 模型準備
- 2 源碼分析
- 3 讀取圖片數據
- 4 讀取模型
- 5 將圖像轉換為Tensor
- 6 運行模型
- 7 處理結果
- Demo2 Pytorch Demo APP
- 1 攝像頭API
- 2 圖像分類
- 3 顯示結果
- Demo3 Image Segmentation
- Semantic Image Segmentation DeepLabV3 with Mobile Interpreter on Android
- 1.Prepare the Model
- 2.Use Android Studio
- 3.Run the app
- 參考文獻
現如今,在邊緣設備上運行機器學習/深度學習變得越來越流行,它需要更低的時延。
而從Pytorch 1.3開始,我們就可以使用Pytorch將模型部署到Android或者ios設備中。
Pytorch官方文檔:https://pytorch.org/mobile/home/
Pytorch官方文檔中提供關于Pytorch-mobile的Demo:https://github.com/pytorch/android-demo-app
主要包含了兩個APP應用,一個簡單的在神經網絡領域中的“hello world"項目,另一個就更復雜了一些,有圖形識別和語言識別。
我們接下來研究一下Pytorch Mobile的項目流程。
Demo 1 HelloWorldApp
1 模型準備
首先我們需要先訓練好的模型保存好。比如我在Pycharm寫了經典CNN模型MobileNet-v3。
import torch import torchvision from torch.utils.mobile_optimizer import optimize_for_mobilemodel = torchvision.models.mobilenet_v3_small(pretrained=True) model.eval() example = torch.rand(1, 3, 224, 224) traced_script_module = torch.jit.trace(model, example) optimized_traced_model = optimize_for_mobile(traced_script_module) optimized_traced_model._save_for_lite_interpreter("./mobilenet_v3_small_model.pt")在 checkpoints/ 文件夾中保存了 mobilenet_v3_small_model.pt ,有了這個模型,我們就可以進行Android的部署了。
2 源碼分析
Clone 源碼
我們先在本地clone一下github上的源碼(吐槽一下git clone的速度,龜速!):
git clone https://github.com/pytorch/android-demo-app.git然后便得到這個項目。
前提先確保一下Android安裝好了SDK和NDK。
向 Gradle 添加依賴
然后我們會在 app 下的 build.gradle 中發現這樣的依賴:
最下面兩行中的
- org.pytorch:pytorch_android : Pytorch Android API 的主要依賴,包含為4個Android abis (armeabi-v7a, arm64-v8a, x86, x86_64) 的 libtorch 本地庫。
- org.pytorch:pytorch_android_torchvision :它是具有將 android.media.image 和 android.graphics.bitmap 轉換為 Tensor 的附加庫。
3 讀取圖片數據
在 MainActivity.java 文件中,有這么一行:
Bitmap 為位圖,其包括像素以及長、寬、顏色等描述信息。長、寬、像素位數用來描述圖片,并可以通過這些信息計算出圖片的像素占用內存的大小。
通過 BitmapFactory.decodeStream( ) 這一函數加載圖像。
4 讀取模型
同樣在 MainActivity.java文件中,有這么一行:
當然我們需要 import org.pytorch.Module
然后通過 Module 定義一個對象后使用 Module.load() 來讀取模型。
5 將圖像轉換為Tensor
org.pytorch.torchvision.TensorImageUtils 就是org.pytorch:pytorch_android_torchvision庫中的一部分,TensorImageUtils.bitmapToFloat32Tensor 創建一個Tensor類型。
inputTensor 的 大小為 1x3xHxW, 其中 H 和 W 分別為 Bitmap 的高和寬。
6 運行模型
將 inputTensor 放到模型中運行,通過 module.forward() 得到一個 outputTensor。
7 處理結果
// getting tensor content as java array of floatsfinal float[] scores = outputTensor.getDataAsFloatArray();// searching for the index with maximum scorefloat maxScore = -Float.MAX_VALUE;int maxScoreIdx = -1;for (int i = 0; i < scores.length; i++) {if (scores[i] > maxScore) {maxScore = scores[i];maxScoreIdx = i;}}String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];// showing className on UITextView textView = findViewById(R.id.text);textView.setText(className);判斷最高分數,并將結果顯示到textView中。
Demo2 Pytorch Demo APP
這是另一個Demo App,它可以進行圖像分類和文字分類。而圖像分類就需要利用攝像頭。
1 攝像頭API
攝像頭API通過使用 org.pytorch.demo.vision.AbstractCameraXActivity 類。
在 AbstractCameraXActivity.java 中的具體源碼如下:
2 圖像分類
而在 ImageClassificationActivity.java 中的源碼如下:
protected AnalysisResult analyzeImage(ImageProxy image, int rotationDegrees) {if (mAnalyzeImageErrorState) {return null;}try {if (mModule == null) {final String moduleFileAbsoluteFilePath = new File(Utils.assetFilePath(this, getModuleAssetName())).getAbsolutePath();// 導入模型mModule = Module.load(moduleFileAbsoluteFilePath);mInputTensorBuffer =Tensor.allocateFloatBuffer(3 * INPUT_TENSOR_WIDTH * INPUT_TENSOR_HEIGHT);mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[]{1, 3, INPUT_TENSOR_HEIGHT, INPUT_TENSOR_WIDTH});}final long startTime = SystemClock.elapsedRealtime();// 將以YUV420形式的Image類型轉化為輸入TensorTensorImageUtils.imageYUV420CenterCropToFloatBuffer(image.getImage(), rotationDegrees,INPUT_TENSOR_WIDTH, INPUT_TENSOR_HEIGHT,TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,TensorImageUtils.TORCHVISION_NORM_STD_RGB,mInputTensorBuffer, 0);final long moduleForwardStartTime = SystemClock.elapsedRealtime();// 利用模型進行運算final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;// 從模型中得到預測分數final float[] scores = outputTensor.getDataAsFloatArray();// 找到得分最高的前k個類final int[] ixs = Utils.topK(scores, TOP_K);final String[] topKClassNames = new String[TOP_K];final float[] topKScores = new float[TOP_K];for (int i = 0; i < TOP_K; i++) {final int ix = ixs[i];topKClassNames[i] = Constants.IMAGENET_CLASSES[ix];topKScores[i] = scores[ix];}final long analysisDuration = SystemClock.elapsedRealtime() - startTime;return new AnalysisResult(topKClassNames, topKScores, moduleForwardDuration, analysisDuration);} catch (Exception e) {Log.e(Constants.TAG, "Error during image analysis", e);mAnalyzeImageErrorState = true;runOnUiThread(() -> {if (!isFinishing()) {showErrorDialog(v -> ImageClassificationActivity.this.finish());}});return null;}}3 顯示結果
最后將得到的前k個類加載到UI上。
protected void applyToUiAnalyzeImageResult(AnalysisResult result) {mMovingAvgSum += result.moduleForwardDuration;mMovingAvgQueue.add(result.moduleForwardDuration);if (mMovingAvgQueue.size() > MOVING_AVG_PERIOD) {mMovingAvgSum -= mMovingAvgQueue.remove();}for (int i = 0; i < TOP_K; i++) {final ResultRowView rowView = mResultRowViews[i];rowView.nameTextView.setText(result.topNClassNames[i]);rowView.scoreTextView.setText(String.format(Locale.US, SCORES_FORMAT,result.topNScores[i]));rowView.setProgressState(false);}mMsText.setText(String.format(Locale.US, FORMAT_MS, result.moduleForwardDuration));if (mMsText.getVisibility() != View.VISIBLE) {mMsText.setVisibility(View.VISIBLE);}mFpsText.setText(String.format(Locale.US, FORMAT_FPS, (1000.f / result.analysisDuration)));if (mFpsText.getVisibility() != View.VISIBLE) {mFpsText.setVisibility(View.VISIBLE);}if (mMovingAvgQueue.size() == MOVING_AVG_PERIOD) {float avgMs = (float) mMovingAvgSum / MOVING_AVG_PERIOD;mMsAvgText.setText(String.format(Locale.US, FORMAT_AVG_MS, avgMs));if (mMsAvgText.getVisibility() != View.VISIBLE) {mMsAvgText.setVisibility(View.VISIBLE);}}}Demo3 Image Segmentation
Semantic Image Segmentation DeepLabV3 with Mobile Interpreter on Android
This repo offers a Python script that converts the PyTorch DeepLabV3 model to the Lite Interpreter version of model, also optimized for mobile, and an Android app that uses the model to segment images.
1.Prepare the Model
import torch from torch.utils.mobile_optimizer import optimize_for_mobile# 加載訓練好的模型 model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet50', pretrained=True) # 設置為推理模式 model.eval()# 將訓練好的模型轉換為jit腳本模型 scripted_module = torch.jit.script(model) # 優化jit腳本模型,提高在移動設備上的推理性能 optimized_scripted_module = optimize_for_mobile(scripted_module)# 導出完整的jit版本模型(不兼容輕量化解釋器) scripted_module.save("deeplabv3_scripted.pt") # 導出輕量化解釋器版本模型(與輕量化解釋器兼容) scripted_module._save_for_lite_interpreter("deeplabv3_scripted.ptl") # 使用優化的輕量化解釋器模型比未優化的輕量化解釋器模型推理速度快60%左右,比未優化的jit腳本模型推理速度快6%左右 optimized_scripted_module._save_for_lite_interpreter("deeplabv3_scripted_optimized.ptl")2.Use Android Studio
使用Android Studio打開ImageSegment項目。注意應用程序的build.gradle文件有以下行:
implementation 'org.pytorch:pytorch_android_lite:1.9.0' implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'在MainActive . java中,下面的代碼用于加載模型:
mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "deeplabv3_scripted_optimized.ptl"));3.Run the app
參考文獻
總結
以上是生活随笔為你收集整理的在Android设备部署PyTorch模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python all()函数
- 下一篇: pyinstaller使用方法及案例