あけましておめでとうございます。 今年もよろしくお願いします。
今年の年末年始は帰省しなかったので、東京で一人で過ごしていました。 一週間以上誰とも会わず、喋らず、ほとんど外に出ずでした。
久しぶりにまとまった時間があったので、この年末年始は機械学習系の諸々を触っていました。 色々とやっていたのですが、onnxとそのjs実装であるonnxjsが結構面白かったのでご紹介します。
ONNX
ONNX (Open Neural Network Exchange) は機械学習モデルを表現するためのフォーマットのことで、 PyTorchやKerasなどのモデルをONNXフォーマットに変換することができます。
ONNXには様々なOperatorが定義されています。
ONNXは日々更新されていますが、Opset Versionで管理されています。
onnx/Operators.md at master · onnx/onnx · GitHub
例えば、Abs
はバージョン1から使えますが、Acosh
はバージョン9以降のOpset Versionでないと使えません。
2021年1月3日現在、バージョン13まできています。
また、pytorchからの変換でのopset_versionはデフォルトでは9です。
onnxjs
onnxjsはONNXのjs実装です。
Backendとして、cpu
とwebgl
とwasm
の3つに対応しています。
cpu
は普通にJavascriptでの計算し、webgl
はWebGLで、wasm
はWebAssemblyで計算します。
実装自体はそれぞれ別なので、Backendによって実装されているOperatorは違います。
onnxjs/operators.md at master · microsoft/onnxjs · GitHub
正直なところ、まだまだ実装されているOperatorは少なく、wasm
で使えるものはあまりないです。
今回作ったデモアプリ
今回は定番の手書き数字を判別するやーつを実装しました。
デプロイしたものはこちらです。
https://demo4onnxjs.hayashikun.com/#/mnist
機械学習フレームワークにはPytorchを使いました。 フロントエンドはReactで実装しています。
今回、足りていないOperatorがあったので自分で実装しました。
なので、onnxjs
はローカルのものをみています。
demo4onnxjs/package.json at master · hayashikun/demo4onnxjs · GitHub
機械学習部分
demo4onnxjs/mnist.py at master · hayashikun/demo4onnxjs · GitHub
Net1とNet2の2種類のモデルを用意しています。 それぞれネットワークの構造が違います。
MNISTの手書き数字は28x28の1チャンネル画像なので、入力は(batch_size)x(channel=1)x(width=28)x(height=28)になります。
それぞのモデルは軽く学習させたあと、onnxに変換されます。
Pytorchのモデルからonnxへの変換は、
x = torch.zeros(1, 1, 28, 28) torch.onnx.export(model, x, "model.onnx", opset_version=opset_version)
でできます。 xは入力で、zerosでもrandomでもなんでも良いです。
今回のアプリケーションでは、/models/data
下に学習済みのモデルが出力されます。
onnx形式に変換したモデルは、Netronというソフトを使って可視化できます。
copy-files.js
学習済みモデルはcopy-files.js
で/public
にコピーされます。
demo4onnxjs/copy-files.js at master · hayashikun/demo4onnxjs · GitHub
また、WebAssemblyを動かすために、onnx-wasm.wasm
などが必要なので、それも/public
にコピーしています。
手書き&認識
マウスで手書きするために、DrawableCanvas
というComponentを実装しました。
demo4onnxjs/DrawableCanvas.tsx at master · hayashikun/demo4onnxjs · GitHub
入力は28x28でないといけないので、非表示の28x28のCanvasに描いた内容をコピーして取得しています。
書き込んだ数字を認識させるためには、モデルを読み込む必要があります。
読み込むモデルとバックエンド、Opset Versionを選んだらLoadを押してモデルを読み込みます。
一部の組み合わせ、例えばNet2 & webgl & version 9
では、
Error: TypeError: cannot resolve operator 'LogSoftmax' with opsets: ai.onnx v9
のようなエラーが出ます。
これはLogSoftmax
がwebgl
のbackendでまだ実装されていないからです。
実はLogSoftmax
はまだcpu
でも実装されていないので、自分で実装しました。(PR投げた)
モデルを読み込んだら、Evalボタンを押せば結果が出力されます。
表示されるのは確率で、一番高いやつが赤字です。
正直あんまり認識は良くないです。
終わり
ちょっともうちょい見た目とか直したい。
実はPGGANでアイドルの顔写真生成とかも試しているので、それについても書きたい。
onnxjsにはまだまだOperatorが足りていないので、自分でもっと色々追加していきたいな〜