見出し画像

JAXとはPythonで何ですか?

高速なNumPyとして使いこなすためのチュートリアル~


Google製のライブラリで、AutogradとXLAからなる、機械学習のための数値計算ライブラリ。簡単に言うと「自動微分に特化した、GPUやTPUに対応した高速なNumPy」。NumPyとほとんど同じ感覚で書くことができます。自動微分については解説が多いので、この記事では単なる高速なNumPyの部分を中心に書いていきます。

関連記事

GPU対応のNumPyという観点では、似たライブラリとしてPFN製のCuPyや、AnacondaがスポンサーとなっているNumbaもあります。

配列の初期化

最初はCPUに限定して書きます。JAXの導入はとてもシンプルで、あたかもNumPyのように使うことができます。

import jax.numpy as jnp

# NumPyではnp.arange(25, dtype=np.float32).reshape(5, 5)
x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
print(x)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[[ 0.  1.  2.  3.  4.]
 [ 5.  6.  7.  8.  9.]
 [10. 11. 12. 13. 14.]
 [15. 16. 17. 18. 19.]
 [20. 21. 22. 23. 24.]]

JAXでのNumPy関数

.block_until_ready()

NumPy関数はnpをjnpに書き換えるだけ。ただし、JAXでは非同期処理で計算されるため、計算の最後に.block_until_ready()を追加します。

# NumPyではnp.dot(x, x.T)
x_gram = jnp.dot(x, x.T).block_until_ready()
print(x_gram)
[[  30.   80.  130.  180.  230.]
 [  80.  255.  430.  605.  780.]
 [ 130.  430.  730. 1030. 1330.]
 [ 180.  605. 1030. 1455. 1880.]
 [ 230.  780. 1330. 1880. 2430.]]

特に理由がなければ.block_until_ready()はJAXの計算の最後のみ入れればOKです。

y = x + 1
x_gram = jnp.dot(x, y.T).block_until_ready() # 最後だけブロッキングを入れればOK

詳細は下記へ

ref



この記事が気に入ったらサポートをしてみませんか?