はやし雑記

はやしです

onnxjsを使ってブラウザで機械学習モデルを実行する

あけましておめでとうございます。 今年もよろしくお願いします。

今年の年末年始は帰省しなかったので、東京で一人で過ごしていました。 一週間以上誰とも会わず、喋らず、ほとんど外に出ずでした。

久しぶりにまとまった時間があったので、この年末年始は機械学習系の諸々を触っていました。 色々とやっていたのですが、onnxとそのjs実装であるonnxjsが結構面白かったのでご紹介します。

ONNX

ONNX (Open Neural Network Exchange) は機械学習モデルを表現するためのフォーマットのことで、 PyTorchやKerasなどのモデルをONNXフォーマットに変換することができます。

onnx.ai

github.com

ja.wikipedia.org

ONNXには様々なOperatorが定義されています。

ONNXは日々更新されていますが、Opset Versionで管理されています。

onnx/Operators.md at master · onnx/onnx · GitHub

例えば、Abs はバージョン1から使えますが、Acoshはバージョン9以降のOpset Versionでないと使えません。 2021年1月3日現在、バージョン13まできています。 また、pytorchからの変換でのopset_versionはデフォルトでは9です。

onnxjs

onnxjsはONNXのjs実装です。

github.com

Backendとして、cpuwebglwasmの3つに対応しています。

cpuは普通にJavascriptでの計算し、webglはWebGLで、wasmはWebAssemblyで計算します。

実装自体はそれぞれ別なので、Backendによって実装されているOperatorは違います。

onnxjs/operators.md at master · microsoft/onnxjs · GitHub

正直なところ、まだまだ実装されているOperatorは少なく、wasmで使えるものはあまりないです。

今回作ったデモアプリ

f:id:hayashikunsan:20210103232720p:plain

今回は定番の手書き数字を判別するやーつを実装しました。

github.com

デプロイしたものはこちらです。

https://demo4onnxjs.hayashikun.com/#/mnist

機械学習フレームワークにはPytorchを使いました。 フロントエンドはReactで実装しています。

今回、足りていないOperatorがあったので自分で実装しました。 なので、onnxjsはローカルのものをみています。

demo4onnxjs/package.json at master · hayashikun/demo4onnxjs · GitHub

機械学習部分

demo4onnxjs/mnist.py at master · hayashikun/demo4onnxjs · GitHub

Net1とNet2の2種類のモデルを用意しています。 それぞれネットワークの構造が違います。

MNISTの手書き数字は28x28の1チャンネル画像なので、入力は(batch_size)x(channel=1)x(width=28)x(height=28)になります。

それぞのモデルは軽く学習させたあと、onnxに変換されます。

Pytorchのモデルからonnxへの変換は、

x = torch.zeros(1, 1, 28, 28)
torch.onnx.export(model, x, "model.onnx", opset_version=opset_version)

でできます。 xは入力で、zerosでもrandomでもなんでも良いです。

今回のアプリケーションでは、/models/data下に学習済みのモデルが出力されます。

onnx形式に変換したモデルは、Netronというソフトを使って可視化できます。

f:id:hayashikunsan:20210104010403p:plain

copy-files.js

学習済みモデルはcopy-files.js/publicにコピーされます。

demo4onnxjs/copy-files.js at master · hayashikun/demo4onnxjs · GitHub

また、WebAssemblyを動かすために、onnx-wasm.wasmなどが必要なので、それも/publicにコピーしています。

手書き&認識

マウスで手書きするために、DrawableCanvasというComponentを実装しました。

demo4onnxjs/DrawableCanvas.tsx at master · hayashikun/demo4onnxjs · GitHub

入力は28x28でないといけないので、非表示の28x28のCanvasに描いた内容をコピーして取得しています。

書き込んだ数字を認識させるためには、モデルを読み込む必要があります。

読み込むモデルとバックエンド、Opset Versionを選んだらLoadを押してモデルを読み込みます。

一部の組み合わせ、例えばNet2 & webgl & version 9では、 Error: TypeError: cannot resolve operator 'LogSoftmax' with opsets: ai.onnx v9 のようなエラーが出ます。

これはLogSoftmaxwebglのbackendでまだ実装されていないからです。 実はLogSoftmaxはまだcpuでも実装されていないので、自分で実装しました。(PR投げた)

モデルを読み込んだら、Evalボタンを押せば結果が出力されます。

表示されるのは確率で、一番高いやつが赤字です。

正直あんまり認識は良くないです。

終わり

ちょっともうちょい見た目とか直したい。

実はPGGANでアイドルの顔写真生成とかも試しているので、それについても書きたい。

f:id:hayashikunsan:20210103234051p:plain

onnxjsにはまだまだOperatorが足りていないので、自分でもっと色々追加していきたいな〜