PyTorch MobileのAndroid版 HelloWorldを試す
1. PyTorch Mobile
「PyTorch 1.3」では、実験的リリースですが、iOSとAndroidをサポートするようになりました。
2. HelloWorld
AndroidでPyTorch Android APIを使用するシンプルな画像分類アプリケーション「HelloWorld」が提供されています。
今回はこれを実行してみます。
3. モデルの準備
事前訓練された画像分類モデルである「Resnet18」を使用します。
これは、「TorchVision」にパッケージ化されています。
(1)Anacondaなどの仮想環境で次のコマンドを実行し、「TorchVision」をインストール。
$ pip install torchvision
(2)HelloWorldをダウンロードしてHelloWorldフォルダに移動し、「trace_model.py」を実行。
Androidプロジェクトのassetsフォルダ(HelloWorldApp/app/src/main/assets)に、モバイルで実行できるTorchScriptモデル「model.pt」が生成されます。
$ python trace_model.py
4. PyTorch Android APIのインストール
「PyTorch Android API」は、Android Studioのgradleでインストールできます。
(1)Anroid StudioでAndroidプロジェクトを開く。
「Android Studio」でAndroidプロジェクト(HelloWorldApo)を開いてください。
(2)「Android SDK」と「Android NDK」をまだインストールしてない場合は、「Android Studio」を使ってインストール。
(3)gradleに以下の依存関係を指定。
HelloWorldは指定済みのため、プロジェクトを開くだけで「PyTorch Android API」が自動的にインストールされます。
repositories {
jcenter()
}
dependencies {
implementation 'org.pytorch:pytorch_android:1.3.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'
}
「org.pytorch:pytorch_android」は、「PyTorch Android API」の主な依存関係です。これには、4つのAndroid Abis(armeabi-v7a、arm64-v8a、x86、x86_64)のlibtorchネイティブライブラリが含まれます。
5. HelloWorldの実行
アプリを実行してください。画面に予測結果が表示される...予定でしたが、module.forward()の戻り値が戻らず、エラーなしで止まったまま。
今日のところはあきらめる。
6. コードの説明
コードをステップごとに説明します。
◎画像の読み込み
はじめに、画像の読み込みを行います。
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
◎学習済みTorchScriptモデルの読み込み
次に、学習済みTorchScriptモデルを読み込みます。
Module module = Module.load(assetFilePath(this, "model.pt"));
「org.pytorch.Module」は、シリアル化されたモデルのファイルパスを指定するload()でロードできる「torch::jit::script::Module」を示します。
7. 入力の準備
次に、入力の準備を行います。
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
「org.pytorch.torchvision.TensorImageUtils」は「org.pytorch:pytorch_android_torchvisionライブラリ」の一部です。
TensorImageUtils#bitmapToFloat32Tensor()は、android.graphics.Bitmapをソースとして使用して、TorchVision形式でテンソルを作成します。
inputTensorの形状は1x3xHxWです。HとWはビットマップの高さと幅です。
8. 推論の実行
次に、推論を実行して結果を取得します。
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();
org.pytorch.Module.forward()は、ロードされたモジュールのforward()を実行し、形状1x1000の「org.pytorch.Tensor outputTensor」として結果を取得します。
9. 結果の処理
結果は、org.pytorch.Tensor.getDataAsFloatArray()を使用して取得します。全クラスのスコアを持つfloat型のJava配列になります。そして、最大スコアのインデックスを見つけて、ImageNetClasses.IMAGENET_CLASSES配列からクラス名を取得します。
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
10. APIリファレンス
「PyTorch Android API」の詳細については、Javadocを参照してください。
この記事が気に入ったらサポートをしてみませんか?