1
/
5

【TECH BLOG】JAXによるスケーラブルな機械学習

はじめに

こんにちは、ZOZO NEXT ZOZO ResearchのSai Htaung Khamです。ZOZO NEXTは、ファッション領域におけるユーザーの課題を想像しテクノロジーの力で解決すること、より多くの人がファッションを楽しめる世界の創造を目指す企業です。

ZOZO NEXTでは多くのアルゴリズムを研究開発しており、その中でJAXというライブラリを使用しています。JAXは高性能な機械学習のために設計されたPythonのライブラリです。NumPyに似ていますが、より強力なライブラリであると考えることができます。NumPyとは異なり、JAXはマルチGPU、マルチTPU、そして機械学習の研究に非常に有用な自動微分(Autograd)をサポートしています。

JAXはNumPyのAPIのほとんどをミラーリングしているので、NumPyライブラリに慣れている人なら非常に導入しやすいです。Autogradを使えば、Pythonのネイティブ関数とNumPyの関数を自動的に微分できます。JAXの詳細な機能については、JAXの公式GitHubリポジトリを参照してください。

そもそも、なぜJAXなのか?

機械学習アルゴリズムを構築する際、多くのMLエンジニアはTensorflowやPyTorchといった信頼性の高いMLフレームワークを利用することでしょう。成熟したMLフレームワークには成熟したエコシステムがあり、本番環境への統合や保守が容易になるため、良い決断です。当研究所では、これらのフレームワークを用いて実装された多くのアルゴリズムを持っています。

しかし、いくつかのアルゴリズムはNumPyライブラリを用いて、純粋なPythonで実装されています。その中には例えば、研究者やMLエンジニアが社内用に設計した埋め込みアルゴリズムがあります。埋め込みアルゴリズムは類似商品を効率よく抽出できるため、商品推薦などに有用です。実装がPythonであるため、このアルゴリズムは計算の実行時間にボトルネックがあります。お気づきのように、フレームワークを使用しない場合、パラメータの更新やモデルのダンプ等も自前で実装する必要があります。そのため、新しいアイデアをすぐに試すことが難しく、なかなか前に進めません。また、ライブラリや学習プロセスもCPUデバイスに限定されるため、拡張性がありません。共有メモリアーキテクチャを利用してマルチプロセスでアルゴリズムを実行できましたが、GPUやTPUなどの複数のホストやデバイスで実行し、垂直方向・水平方向にスケールできる状態が望ましいです。

そこで、拡張性・保守性の高い別のフレームワークにプログラムを移植する方法を検討した結果、以下のような特徴を持つJAXを採用しました。

  1. Single Program Multiple Dataアーキテクチャによる水平方向のスケーラビリティ
  2. NumPyのAPIをミラーリング
  3. Pythonに対応
  4. Autogradのサポート
  5. エコシステムまでオープンソース化されている(FlaxやHaikuなど)

特に(2)の性質によってNumPyで書かれたアルゴリズムを効率よく移植できるという点が、既存の他のフレームワークにはない利点でした。

本記事を読むことで分かること

本記事では、実世界のデータを使った機械学習を、JAXライブラリで実行する方法について説明します。通常、機械学習の理論を学び問題を解くときには、理解を深めるために小さなデータを使用します。しかし、実世界のデータに応用するとデータ量、モデルを格納するメモリサイズ、学習と評価のスケーラビリティなど多くの困難に直面することになります。ありがたいことに、現代のクラウドコンピューティングの革命と価格設定により、スケーラブルな機械学習は誰でも利用できるようになりました。


典型的な機械学習プロジェクトはデータの準備からモデルのサービングまで多くのステージで構成されますが、本記事で取り扱うのはデータの準備とモデルの学習に当たる部分です。特にJAXライブラリのパフォーマンスと、クラウドコンピューティング上でのスケーラビリティを実現する方法について説明します。

データって本当に大きいの? いつ、どこで、どうやって処理するの?

データと質の高いデータ変換が機械学習プロジェクトの成功の中心であることは、すべてのMLエンジニアが理解していることです。現実の機械学習プロジェクトでは、1台のマシンでETL(抽出、変換、ロード)プロセスを行えるような量のデータを扱うことは稀です。当研究所では、Google CloudやAWSなどのクラウドコンピューティングリソースに広く依存しており、通常、クラウドストレージやクラウドデータウェアハウスを使用してデータを管理しています。

ボトルネックに要注意!

クラウドストレージは、1台のマシンに収まりきらない大量のデータを保存するのにとても役立ちます。しかし、モデルの学習に利用するためには、ストレージからデータを読み出す効率的な方法を見つける必要があります。多くのMLエンジニアがGPUデバイスを使った学習中に遭遇する問題の1つは、GPUデバイスが十分に活用し切れず、学習プロセスに必要以上の時間がかかってしまうことです。次のTensorFlowモデルのプロファイリング結果をご覧ください。

参照:[モデルのプロファイリング]

よく観察すると、ディスクからデータを取得している間、GPUデバイスはほとんどの時間、アイドル状態であることに気づかれると思います。一般的には、学習中はGPUデバイスをビジー状態にしたいものです。これは、データ入力パイプラインにボトルネックがあることを示しています。

TF Dataってヒーローなの?

データ入力パイプラインのボトルネックを解消するために、TF DataというTensorFlowが提供する便利なツールを利用することにします。従来の方法では、下図のようにディスクからデータを順次読み込んでいました。下図のMapは、正規化、画像補強などのデータの変換処理です。

参照:[モデルへのデータの順次取り込み]


しかし、この方法では学習処理にデータ転送待ちが発生し、GPUデバイスがアイドル状態になってしまうというボトルネックが発生しています。そこで下図のように読み込みとデータ変換を並列に行うことで、学習の待ち時間が少なくなります。

参照:[TF Data Pipelineで効率的なデータ変換]


TFデータパイプラインのコンポーネントは再利用可能です。トレーニングやサービングフェーズに適用できます。TF DataライブラリはホストCPUを利用してデータを並列に処理しているので、CPUの性能が高ければ高いほど、データの読み込みや前処理が高速になることを念頭に置いておくことが重要です。

データ前処理パイプラインとして、Apache BeamやTFX Transformを使用する方法もありますが、今回は説明しません。本記事では、TF DataとJAXを使用して、スケーラブルな機械学習を共有します。

処理を高速化してみよう!

効果的なデータ前処理パイプラインを手に入れたことで、モデルの学習と評価のステップに移行します。JAXの便利なライブラリにvmapとpmapがあります。本記事では、vmapとpmapを使用してマルチGPUデバイスでの学習処理を高速化します。

#vmapによるauto-vectorization
import numpy as np
import jax.numpy as jnp
import jax

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

x = np.arange(5)
w = np.array([3., 1., 3.])
batch_size = 10
xs = np.arange(5 * batch_size).reshape(-1, 5)
ws = np.stack([w] * batch_size)
print(f"The shape of the x and w : {xs.shape, ws.shape}")

print("Process each sample.")
for sample in xs:
    print(convolve(sample, w))

print("Auto-vectorization with vmap:")
print(jax.vmap(convolve)(xs, ws))
#vmap処理とサンプル単位処理の比較結果
The shape of the x and w : ((10, 5), (10, 3))
Process each sample.
[ 7. 14. 21.]
[42. 49. 56.]
[77. 84. 91.]
[112. 119. 126.]
[147. 154. 161.]
[182. 189. 196.]
[217. 224. 231.]
[252. 259. 266.]
[287. 294. 301.]
[322. 329. 336.]
Auto-vectorization with vmap:
[[  7.  14.  21.]
 [ 42.  49.  56.]
 [ 77.  84.  91.]
 [112. 119. 126.]
 [147. 154. 161.]
 [182. 189. 196.]
 [217. 224. 231.]
 [252. 259. 266.]
 [287. 294. 301.]
 [322. 329. 336.]]


まずはvmapに関して説明します。vmapはコードを変更することなく関数をベクトル化(auto-vectorization)するものです。auto-vectorizationにより、vmap APIで関数をラップする以外にコードを変更することなく処理を高速化できます。これは、特にバッチ処理の際に非常に便利です。vmapの機能はまだまだあるので、以下のリンクから確認してください。

Automatic Vectorization in JAX
Authors: Matteo Hessel In the previous section we discussed JIT compilation via the function. This notebook discusses another of JAX's transforms: vectorization via . Consider the following simple code that computes the convolution of two one-dimensional
https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html


pmapの使い方は、vmapとよく似ています。しかし、pmapはMPIのようなCollective operationを提供し、プログラムが複数のデバイス上で通信しデバイスをまたいで合計や平均などの演算「MapReduce」を実行できます。このAPIにより、プログラムはスケールアウトできます。

続きはこちら

株式会社ZOZOでは一緒に働く仲間を募集しています
同じタグの記事
今週のランキング
株式会社ZOZOからお誘い
この話題に共感したら、メンバーと話してみませんか?