我的编程空间,编程开发者的网络收藏夹
学习永远不晚

pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)

近期需要将pytorch模型运行到android手机上实验,在查阅网上博客后,发现大多数流程需要借助多个框架或软件,横跨多个编程语言、IDE。本文参考以下两篇博文,力求用更简洁的流程实现模型部署。

https://blog.csdn.net/xiaodidididi521/article/details/123985612
https://blog.csdn.net/m0_67391683/article/details/125401357

向两位作者表示感谢!本文进一步详细描述了实现流程。

一、pytorch模型转化

pytorch模型无法直接被Android调用,需要转化为特定格式.pt。本文使用pycharm IDE完成这一步,工程目录结构如下:
![pycharm目录结构](https://img-blog.csdnimg.cn/d67266301c3f43bfa20d3585dc5fe836.png#pic_center
其中,vgg16bn_CIFAR10.pth和另一个pth文件是需要部署到手机上的模型,models.py是自己定义的网络结构。在此默认读者熟悉pytorch,对models.py不做赘述。
pycharm目录结构

执行以下代码实现转换:

import torch.utils.data.distributed'定义转化后的模型名称'model_ori_pt ='model_ori.pt'model_pruned_pt ='model_pruned.pt''加载pytorch模型'model_ori = torch.load('vgg16bn_CIFAR10.pth')model_pruned = torch.load('vgg16bn_CIFAR10_pruned.pth')'模型在cpu上运行'device = torch.device('cpu')model_ori.to(device)model_pruned.to(device)model_ori.eval()model_pruned.eval()'定义输入图片的大小'input_tensor = torch.rand(1, 3, 32, 32)'转化模型并存储'mobile_ori = torch.jit.trace(model_ori, input_tensor)model_pruned = torch.jit.trace(model_pruned, input_tensor)mobile_ori.save(model_ori_pt)model_pruned.save(model_pruned_pt)

请注意,让模型在cpu上,或cuda上执行eval()均可,但要保证模型与input_tensor在同一设备上,否则将运行出错。运行后,会得到model_ori.ptmodel_pruned.pt两个文件,即可以用于android上的文件。此时目录结构如下:
在这里插入图片描述

二、新建Android Studio工程

首先,需要在本地安装Android Studio,安装流程建议参照:

https://m.runoob.com/android/android-studio-install.html?ivk_sa=1024320u
然后打开Android Studio新建Empy Activity
在这里插入图片描述

点击Next。
在这里插入图片描述

点击Finsh。SDK建议选择7.0以往的安卓版本。**首次新建工程底部会长时间出现加载进度条,请耐心等待加载完成。**接下来,我们需要有一部手机调试工程,本文使用Android Studio自带的模拟器。首先点击顶部工具栏的Device Manager。
在这里插入图片描述
点击create device
在这里插入图片描述
接下来选择机型、安卓版本、内存等,如不想麻烦可一直点击next。
在这里插入图片描述
finsh后,Android Studio需要下载安卓版本包,需要耐心等待。下载完成后即可启动虚拟机。

在这里插入图片描述
再shift+F10即可在模拟机里运行程序。
在这里插入图片描述

三、转化后的模型部署安卓

首先,新建assets文件夹,请不要直接新建,需右键app->Folder->Assets Folder。
在这里插入图片描述
之后将转化好的两个模型及侧视图放入assets文件夹。本文使用的是CIFAR10数据集,可在以下网址下载:

http://www.cs.toronto.edu/~kriz/cifar.html
然后在gradle Scripts 文件夹中的build.gradle(Module :app)文件中的depencies里添加:

implementation 'org.pytorch:pytorch_android:1.12.1'implementation 'org.pytorch:pytorch_android_torchvision:1.12.1'

请注意**1.12.1是本文使用的pytorch版本,读者应该为对应的版本号。**然后点击工具栏下的sync now,再耐心等待运行按钮变绿。
在这里插入图片描述
双击res->layout->activity_main.xml并切换到code。
在这里插入图片描述
删除所有代码,复制以下代码段:

<?xml version="1.0" encoding="utf-8"?><FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"    xmlns:tools="http://schemas.android.com/tools"    android:layout_width="match_parent"    android:layout_height="match_parent"    tools:context=".MainActivity">    <ImageView        android:id="@+id/image"        android:layout_width="match_parent"        android:layout_height="match_parent"        android:scaleType="fitCenter" />    <TextView        android:id="@+id/text"        android:layout_width="match_parent"        android:layout_height="wrap_content"        android:layout_gravity="top"        android:textSize="24sp"        android:textColor="@android:color/holo_red_light" /></FrameLayout>

然后右键java里的com.example.工程名 文件夹,New->Java Class。本文新建的类名是CIfarClassed,类内代码:

package com.example.工程名;public class CifarClassed {    public static String[] IMAGENET_CLASSES = new String[]{            "ddd",            "automobile",            "bird",            "cat",            "deer",            "dog",            "frog",            "horse",            "ship",            "truck",    };}

最后打开java->com.example.工程名->MainActivity,删除原代码,用以下代码替代:

package com.example.dnna;import android.content.Context;import android.graphics.Bitmap;import android.graphics.BitmapFactory;import android.os.Bundle;import android.util.Log;import android.widget.ImageView;import android.widget.TextView;import org.pytorch.IValue;import org.pytorch.Module;import org.pytorch.Tensor;import org.pytorch.torchvision.TensorImageUtils;import org.pytorch.MemoryFormat;import java.io.File;import java.io.FileOutputStream;import java.io.IOException;import java.io.InputStream;import java.io.OutputStream;import androidx.appcompat.app.AppCompatActivity;import com.example.dnna.CifarClassed;public class MainActivity extends AppCompatActivity {    @Override    protected void onCreate(Bundle savedInstanceState) {        super.onCreate(savedInstanceState);        setContentView(R.layout.activity_main);        Bitmap bitmap = null;        Module module_ori = null;        Module module_pruned = null;        try {            // creating bitmap from packaged into app android asset 'image.jpg',            // app/class="lazy" data-src/main/assets/image.jpg            bitmap = BitmapFactory.decodeStream(getAssets().open("x.png"));            // loading serialized torchscript module from packaged into app android asset model.pt,            // app/class="lazy" data-src/model/assets/model.pt            module_ori = Module.load(assetFilePath(this, "model_ori.pt"));            module_pruned = Module.load(assetFilePath(this, "model——pruned.pt"));        } catch (IOException e) {            Log.e("PytorchHelloWorld", "Error reading assets", e);            finish();        }        // showing image on UI        ImageView imageView = findViewById(R.id.image);        imageView.setImageBitmap(bitmap);        // preparing input tensor        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);        // running the model        long startTime_ori = System.currentTimeMillis();        final Tensor outputTensor_ori = module_ori.forward(IValue.from(inputTensor)).toTensor();        long endTime_ori = System.currentTimeMillis();        long InferenceTimeOri=endTime_ori - startTime_ori;        long startTime_pruned = System.currentTimeMillis();        final Tensor outputTensor_pruned = module_pruned.forward(IValue.from(inputTensor)).toTensor();        long endTime_pruned = System.currentTimeMillis();        long InferenceTimePruned=endTime_pruned - startTime_pruned;        // getting tensor content as java array of floats        final float[] scores = outputTensor_ori.getDataAsFloatArray();        // searching for the index with maximum score        float maxScore = -Float.MAX_VALUE;        int maxScoreIdx = -1;        for (int i = 0; i < scores.length; i++) {            if (scores[i] > maxScore) {                maxScore = scores[i];                maxScoreIdx = i;            }        }        System.out.println(maxScoreIdx);        String className = CifarClassed.IMAGENET_CLASSES[maxScoreIdx];        // showing className on UI        TextView textView = findViewById(R.id.text);        String tex="推理结果:"+className+"\n原始模型推理时间:"+InferenceTimeOri+"ms"+"\n剪枝模型推理时间:"+InferenceTimePruned+"ms";        textView.setText(tex);    }        public static String assetFilePath(Context context, String assetName) throws IOException {        File file = new File(context.getFilesDir(), assetName);        if (file.exists() && file.length() > 0) {            return file.getAbsolutePath();        }        try (InputStream is = context.getAssets().open(assetName)) {            try (OutputStream os = new FileOutputStream(file)) {                byte[] buffer = new byte[4 * 1024];                int read;                while ((read = is.read(buffer)) != -1) {                    os.write(buffer, 0, read);                }                os.flush();            }            return file.getAbsolutePath();        }    }}

运行效果如下图:

在这里插入图片描述

四、结语

本文的主要流程是:

  • 使用pytorch转化模型
  • 新建Android Studio工程与虚拟机
  • 修改Android Studio工程代码

本人目前希望提升自己的博客撰写水平,如读者在实现过程中遇到困难,或在阅读本文时感到困惑,欢迎留言或添加我的QQ:1106295085。我将在周日下午回复,并积极修改本文。

来源地址:https://blog.csdn.net/qq_39068200/article/details/129231207

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)

下载Word文档到电脑,方便收藏和打印~

下载Word文档

编程热搜

  • Android:VolumeShaper
    VolumeShaper(支持版本改一下,minsdkversion:26,android8.0(api26)进一步学习对声音的编辑,可以让音频的声音有变化的播放 VolumeShaper.Configuration的三个参数 durati
    Android:VolumeShaper
  • Android崩溃异常捕获方法
    开发中最让人头疼的是应用突然爆炸,然后跳回到桌面。而且我们常常不知道这种状况会何时出现,在应用调试阶段还好,还可以通过调试工具的日志查看错误出现在哪里。但平时使用的时候给你闹崩溃,那你就欲哭无泪了。 那么今天主要讲一下如何去捕捉系统出现的U
    Android崩溃异常捕获方法
  • android开发教程之获取power_profile.xml文件的方法(android运行时能耗值)
    系统的设置–>电池–>使用情况中,统计的能耗的使用情况也是以power_profile.xml的value作为基础参数的1、我的手机中power_profile.xml的内容: HTC t328w代码如下:
    android开发教程之获取power_profile.xml文件的方法(android运行时能耗值)
  • Android SQLite数据库基本操作方法
    程序的最主要的功能在于对数据进行操作,通过对数据进行操作来实现某个功能。而数据库就是很重要的一个方面的,Android中内置了小巧轻便,功能却很强的一个数据库–SQLite数据库。那么就来看一下在Android程序中怎么去操作SQLite数
    Android SQLite数据库基本操作方法
  • ubuntu21.04怎么创建桌面快捷图标?ubuntu软件放到桌面的技巧
    工作的时候为了方便直接打开编辑文件,一些常用的软件或者文件我们会放在桌面,但是在ubuntu20.04下直接直接拖拽文件到桌面根本没有效果,在进入桌面后发现软件列表中的软件只能收藏到面板,无法复制到桌面使用,不知道为什么会这样,似乎并不是很
    ubuntu21.04怎么创建桌面快捷图标?ubuntu软件放到桌面的技巧
  • android获取当前手机号示例程序
    代码如下: public String getLocalNumber() { TelephonyManager tManager =
    android获取当前手机号示例程序
  • Android音视频开发(三)TextureView
    简介 TextureView与SurfaceView类似,可用于显示视频或OpenGL场景。 与SurfaceView的区别 SurfaceView不能使用变换和缩放等操作,不能叠加(Overlay)两个SurfaceView。 Textu
    Android音视频开发(三)TextureView
  • android获取屏幕高度和宽度的实现方法
    本文实例讲述了android获取屏幕高度和宽度的实现方法。分享给大家供大家参考。具体分析如下: 我们需要获取Android手机或Pad的屏幕的物理尺寸,以便于界面的设计或是其他功能的实现。下面就介绍讲一讲如何获取屏幕的物理尺寸 下面的代码即
    android获取屏幕高度和宽度的实现方法
  • Android自定义popupwindow实例代码
    先来看看效果图:一、布局
  • Android第一次实验
    一、实验原理 1.1实验目标 编程实现用户名与密码的存储与调用。 1.2实验要求 设计用户登录界面、登录成功界面、用户注册界面,用户注册时,将其用户名、密码保存到SharedPreference中,登录时输入用户名、密码,读取SharedP
    Android第一次实验

目录