見出し画像

PyTorch MobileのAndroid版 HelloWorldを試す

1. PyTorch Mobile

「PyTorch 1.3」では、実験的リリースですが、iOSとAndroidをサポートするようになりました。

2. HelloWorld

AndroidでPyTorch Android APIを使用するシンプルな画像分類アプリケーション「HelloWorld」が提供されています。

今回はこれを実行してみます。

3. モデルの準備

事前訓練された画像分類モデルである「Resnet18」を使用します。
これは、「TorchVision」にパッケージ化されています。

(1)Anacondaなどの仮想環境で次のコマンドを実行し、「TorchVision」をインストール。

$ pip install torchvision

(2)HelloWorldをダウンロードしてHelloWorldフォルダに移動し、「trace_model.py」を実行。

Androidプロジェクトのassetsフォルダ(HelloWorldApp/app/src/main/assets)に、モバイルで実行できるTorchScriptモデル「model.pt」が生成されます。

$ python trace_model.py

4. PyTorch Android APIのインストール

「PyTorch Android API」は、Android Studioのgradleでインストールできます。

(1)Anroid StudioでAndroidプロジェクトを開く。
「Android Studio」でAndroidプロジェクト(HelloWorldApo)を開いてください。

(2)「Android SDK」と「Android NDK」をまだインストールしてない場合は、「Android Studio」を使ってインストール。

(3)gradleに以下の依存関係を指定。
HelloWorldは指定済みのため、プロジェクトを開くだけで「PyTorch Android API」が自動的にインストールされます。

repositories {
   jcenter()
}

dependencies {
   implementation 'org.pytorch:pytorch_android:1.3.0'
   implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'
}

「org.pytorch:pytorch_android」は、「PyTorch Android API」の主な依存関係です。これには、4つのAndroid Abis(armeabi-v7a、arm64-v8a、x86、x86_64)のlibtorchネイティブライブラリが含まれます。

5. HelloWorldの実行

アプリを実行してください。画面に予測結果が表示される...予定でしたが、module.forward()の戻り値が戻らず、エラーなしで止まったまま。

今日のところはあきらめる。

6. コードの説明

コードをステップごとに説明します。

◎画像の読み込み
はじめに、画像の読み込みを行います。

Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));

◎学習済みTorchScriptモデルの読み込み
次に、学習済みTorchScriptモデルを読み込みます。

Module module = Module.load(assetFilePath(this, "model.pt"));

「org.pytorch.Module」は、シリアル化されたモデルのファイルパスを指定するload()でロードできる「torch::jit::script::Module」を示します。

7. 入力の準備

次に、入力の準備を行います。

Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
   TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);

「org.pytorch.torchvision.TensorImageUtils」は「org.pytorch:pytorch_android_torchvisionライブラリ」の一部です。
TensorImageUtils#bitmapToFloat32Tensor()は、android.graphics.Bitmapをソースとして使用して、TorchVision形式でテンソルを作成します。
inputTensorの形状は1x3xHxWです。HとWはビットマップの高さと幅です。

8. 推論の実行

次に、推論を実行して結果を取得します。

Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();

org.pytorch.Module.forward()は、ロードされたモジュールのforward()を実行し、形状1x1000の「org.pytorch.Tensor outputTensor」として結果を取得します。

9. 結果の処理

結果は、org.pytorch.Tensor.getDataAsFloatArray()を使用して取得します。全クラスのスコアを持つfloat型のJava配列になります。そして、最大スコアのインデックスを見つけて、ImageNetClasses.IMAGENET_CLASSES配列からクラス名を取得します。

float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
 if (scores[i] > maxScore) {
   maxScore = scores[i];
   maxScoreIdx = i;
 }
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

10. APIリファレンス

「PyTorch Android API」の詳細については、Javadocを参照してください。

この記事が気に入ったらサポートをしてみませんか?