JavaScript (Node.js / Deno / on Browser) でもモデルの学習もやりやすくなってきました。
以前は推論だけならいける状態だったのですが、学習方面まで伸びてきたということです。
今のところ Node.js と on Browser では Tensorflow.js が GPU なしならほぼすべての機能が使えて、
モデルによっては GPU 付きで動く時も出てきて、Deno も WebGPU 付きでそのうち動くかも?みたいな状態です。
私は Node.js より Deno を開発に使う機会が増えてきていますが、深層学習したい場合は Node.js はわかりやすい気がしてます。
何もしてないのに勝手に動かなくなる深層学習では、Node.js のようにバージョンが明確に管理されているほうが良いかも知れません。
ちなみに公式には既にいい感じのサンプルはたくさんあります。
深層学習と言えば Python な昨今ですが、JavaScript で開発する利点は以下と思います。
Python だとすぐ依存性が壊れるので、依存性の問題が起きにくいのは大きいです。
あとフロントエンドに関わる技術をすべて JavaScript に統一できるのはロマンです。
なお以下のような欠点もあります。
- 情報が少ない
- Keras や Python の資産を使いにくい
- on Browser ではモデル保存がしにくい
- すこし遅い→将来に期待
情報量が少ないのは仕方ありません。
Keras のような気軽に使える仕組みがないのも欠点ですが、Tensorflow が強力なのでだいぶ使いやすくはなってます。
ブラウザだとモデル保存がしにくいので学習は Node.js を使うのが自然です。
巷の実行速度の情報はあまり当てにならないので、いい加減にベンチマークしてみると、real / user / sys は以下のようになりました。
MNIST をありがちな 2層の CNN で、しょぼ CPU で、10 epochs 学習させたときの実行時間です。
上掲の Node.js のサンプルのモデルと epochs を変えただけ。
- Keras 2.3.1 (Python 3.6.13): 14m / 53m / 53s
- Tensorflow.js 3.7.0 (Node.js 16.9.1, V8 9.3): 27m / 83m / 120s
- Tensorflow.js 3.7.0 (Node.js 16.2.0, V8 9.0): 28m / 81m / 155s
- Tensorflow.js 3.7.0 (Node.js 12.22.6, V8 7.8): 29m / 91m / 34s
Tensorflow.js は Node.js 内部の
V8 のバージョン によって速度が変わります。
9.0 から 9.3 までにもかなりパフォーマンスの最適化が入っているので、どうなるかと思いましたが、
思ったよりは変わりませんでした。
結論としてはまだまだですかねえ。
早いと思えば早いし、遅いと思えばやや遅い。
理由としては Node.js では Tensorflow の C++ adapter を呼んでいますが、そこに結構なコストが掛かるからだと思っています。
wasm に変わるとどうなるのでしょうか。
ちなみにこの結果は wasm は効いてないのですが、
効くようになればもっと早くなるかも?と書かれています。
なお現状では関数が不足していて動かなかったです。
Node.js での wasm はこのへんに書いてありますが、まだ完璧ではないようです。
最近 ONNX の推論が wasm で動くようになったので、先を越された Tensorflow.js もそのうちサポートするのではないかと思っています。
まあブラウザは WebGL で動くから十分早いんですけどね。
ちなみにいま最も参考になるベンチマーク結果は以下なのかな。
この情報とか、巷のいろいろな情報を総合的に見ると、現状では PyTorch の 3倍遅いくらいになりそう。
wasm が効いた時に
本当に 10倍早くなるなら、一気に JavaScript のほうが早くなりそうですが、どうなるでしょう。
JAX などもあるのでまだまだどうなるかわかりませんが、JavaScript が Python を置き換える時代に期待したいものです。
個人的な認識としては、すべてをサポートしている CPU 版は遅い。
Node.js 版は CPU 版の延長線上で実装されているので、ファイルシステム上の利点は得られるけど、そんなに早くない。
WebGL を使うと一番早いけど、HTTPS 経由の実装になる欠点がある。
ファイルシステム上で効率の良い学習をするには Headless GL などが必要で、そこが詰まっているのでなかなか進まない。
WebGPU も似たような立ち位置にあって、
Dawn なるものがあるので期待はできるかも知れないけど、特に何もプランはされていない。
Wasm を使うとすべてに対応できて早そうで期待できるけど、ネイティブ実装よりは遅そうだし、なかなか実装は進んでいない。
そんな感じに見える。
そういや MXNet は使ったことなかったなあ。