2021年9月18日土曜日

Tensorflow.js でモデルを学習する

JavaScript (Node.js / Deno / on Browser) でもモデルの学習もやりやすくなってきました。 以前は推論だけならいける状態だったのですが、学習方面まで伸びてきたということです。 今のところ Node.js と on Browser では Tensorflow.js が GPU なしならほぼすべての機能が使えて、 モデルによっては GPU 付きで動く時も出てきて、Deno も WebGPU 付きでそのうち動くかも?みたいな状態です。

私は Node.js より Deno を開発に使う機会が増えてきていますが、深層学習したい場合は Node.js はわかりやすい気がしてます。 何もしてないのに勝手に動かなくなる深層学習では、Node.js のようにバージョンが明確に管理されているほうが良いかも知れません。 ちなみに公式には既にいい感じのサンプルはたくさんあります。

深層学習と言えば Python な昨今ですが、JavaScript で開発する利点は以下と思います。
  • 言語を統一できる
  • 依存性の問題が起きにくい


Python だとすぐ依存性が壊れるので、依存性の問題が起きにくいのは大きいです。 あとフロントエンドに関わる技術をすべて JavaScript に統一できるのはロマンです。 なお以下のような欠点もあります。
  • 情報が少ない
  • Keras や Python の資産を使いにくい
  • on Browser ではモデル保存がしにくい
  • すこし遅い→将来に期待


情報量が少ないのは仕方ありません。 Keras のような気軽に使える仕組みがないのも欠点ですが、Tensorflow が強力なのでだいぶ使いやすくはなってます。 ブラウザだとモデル保存がしにくいので学習は Node.js を使うのが自然です。

巷の実行速度の情報はあまり当てにならないので、いい加減にベンチマークしてみると、real / user / sys は以下のようになりました。 MNIST をありがちな 2層の CNN で、しょぼ CPU で、10 epochs 学習させたときの実行時間です。 上掲の Node.js のサンプルのモデルと epochs を変えただけ。

  • Keras 2.3.1 (Python 3.6.13): 14m / 53m / 53s
  • Tensorflow.js 3.7.0 (Node.js 16.9.1, V8 9.3): 27m / 83m / 120s
  • Tensorflow.js 3.7.0 (Node.js 16.2.0, V8 9.0): 28m / 81m / 155s
  • Tensorflow.js 3.7.0 (Node.js 12.22.6, V8 7.8): 29m / 91m / 34s


Tensorflow.js は Node.js 内部の V8 のバージョン によって速度が変わります。 9.0 から 9.3 までにもかなりパフォーマンスの最適化が入っているので、どうなるかと思いましたが、 思ったよりは変わりませんでした。 結論としてはまだまだですかねえ。 早いと思えば早いし、遅いと思えばやや遅い。 理由としては Node.js では Tensorflow の C++ adapter を呼んでいますが、そこに結構なコストが掛かるからだと思っています。 wasm に変わるとどうなるのでしょうか。 ちなみにこの結果は wasm は効いてないのですが、効くようになればもっと早くなるかも?と書かれています。 なお現状では関数が不足していて動かなかったです。 Node.js での wasm はこのへんに書いてありますが、まだ完璧ではないようです。 最近 ONNX の推論が wasm で動くようになったので、先を越された Tensorflow.js もそのうちサポートするのではないかと思っています。 まあブラウザは WebGL で動くから十分早いんですけどね。

ちなみにいま最も参考になるベンチマーク結果は以下なのかな。 この情報とか、巷のいろいろな情報を総合的に見ると、現状では PyTorch の 3倍遅いくらいになりそう。 wasm が効いた時に 本当に 10倍早くなるなら、一気に JavaScript のほうが早くなりそうですが、どうなるでしょう。 JAX などもあるのでまだまだどうなるかわかりませんが、JavaScript が Python を置き換える時代に期待したいものです。

個人的な認識としては、すべてをサポートしている CPU 版は遅い。 Node.js 版は CPU 版の延長線上で実装されているので、ファイルシステム上の利点は得られるけど、そんなに早くない。 WebGL を使うと一番早いけど、HTTPS 経由の実装になる欠点がある。 ファイルシステム上で効率の良い学習をするには Headless GL などが必要で、そこが詰まっているのでなかなか進まない。 WebGPU も似たような立ち位置にあって、Dawn なるものがあるので期待はできるかも知れないけど、特に何もプランはされていない。 Wasm を使うとすべてに対応できて早そうで期待できるけど、ネイティブ実装よりは遅そうだし、なかなか実装は進んでいない。 そんな感じに見える。

そういや MXNet は使ったことなかったなあ。

0 件のコメント: