PyTorch で開発中の JIT 機能
概要
https://github.com/pytorch/pytorch/tree/v0.4.0/torch/csrc/jit
PyTorch は NN 定義をコンパイルし Python のユーザーコードを通らずモデルを実行できる機能を開発中。実は開発は普通にオープンに GitHub で行われているし master にじゃんじゃん入っているので試せる。以下の 2 種類の両方が実装されつつある。
- 実際に実行して作成された計算グラフを IR に変換する
torch.jit.compile
- Python のコードを制御構文も含めて直接 IR に変換する
torch.jit.script
以下は PyTorch 0.4.0 で実行。
In [1]: import torch In [2]: torch.__version__ Out[2]: '0.4.0'
torch.jit.compile
まずは試す
In [4]: @torch.jit.compile(nderivs=0) ...: def f(x): ...: return x * 2.0 + 1.0 ...: In [5]: x = torch.rand(2, 3) In [6]: f(x) Out[6]: tensor([[ 2.5645, 1.5038, 2.3040], [ 1.6422, 2.9829, 1.8165]]) In [7]: f(x) clang: error: unsupported option '-fopenmp' clang: error: unsupported option '-fopenmp' warning: pytorch jit fuser failed to compile with openmp, trying without it... Out[7]: tensor([[ 2.5645, 1.5038, 2.3040], [ 1.6422, 2.9829, 1.8165]])
2 回目の実行であからさまにコンパイルされている。
速度の比較
In [29]: def g(x): ...: return x + 1.0 ...: In [30]: def f_nojit(x): ...: for _ in range(1000): ...: x = g(x) ...: return x ...: In [31]: @torch.jit.compile(nderivs=0) ...: def f_jit(x): ...: for _ in range(1000): ...: x = g(x) ...: return x ...: In [32]: %timeit f_nojit(x) 1000 loops, best of 3: 1.74 ms per loop In [33]: %timeit f_jit(x) The slowest run took 14983.91 times longer than the fastest. This could mean that an intermediate result is being cached. 10 loops, best of 3: 10.8 µs per loop In [34]: %timeit f_jit(x) The slowest run took 7.04 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 11 µs per loop
Python の関数呼び出しは非常にオーバーヘッドが大きい。敢えて Python の関数呼び出しを行いまくる書き方をした関数を JIT 有りと JIT 無しで比較。100 倍近く高速になっている。Numba とかで再帰呼び出しで書いたフィボナッチとかをコンパイルしても大体 100 倍ぐらい速くなるので、ちょうどこんなもんだと思う。
IR を見る
In [35]: @torch.jit.compile(nderivs=0) ...: def f(x): ...: for _ in range(5): ...: x = x * x ...: return x ...: In [36]: f(x) Out[36]: tensor([[ 3.8660e-04, 6.9221e-20, 1.1373e-06], [ 1.6299e-16, 7.5940e-01, 3.5456e-13]]) In [37]: f(x) Out[37]: tensor([[ 3.8660e-04, 6.9221e-20, 1.1373e-06], [ 1.6299e-16, 7.5940e-01, 3.5456e-13]]) In [38]: f.graph_for(x) Out[38]: graph(%0 : Float(2, 3)) { %6 : Float(2, 3) = prim::FusionGroup_0(%0) return (%6); } with prim::FusionGroup_0 = graph(%8 : Float(2, 3)) { %9 : Float(2, 3) = aten::mul(%8, %8) %7 : Float(2, 3) = aten::mul(%9, %9) %5 : Float(2, 3) = aten::mul(%7, %7) %3 : Float(2, 3) = aten::mul(%5, %5) %1 : Float(2, 3) = aten::mul(%3, %3) return (%1); }
トレースを取ってそこから IR を作るので、ループは展開されている。 graph_for
という関数名からわかる通り、入力の shape に応じてコンパイルが行われており、IR の中で shape が決まっている。
逆伝搬の IR を見る
デコレータで何階微分をコンパイルするか指定させるインターフェースは非常に使いにくい。
In [44]: x.requires_grad = True In [45]: @torch.jit.compile(nderivs=1) ...: def f(x): ...: for _ in range(5): ...: x = x * x ...: return x ...: In [46]: f(x).sum().backward() In [47]: f(x).sum().backward() In [48]: f.graph_for(x) Out[48]: graph(%0 : Float(2, 3) -------- stage 1 -------- %6 : Float(2, 3!)) { %23 : Float(2, 3), %24 : Float(2, 3), %25 : Float(2, 3), %26 : Float(2, 3), %27 : Float(2, 3) = prim::FusionGroup_0(%0) ---------------- stage 1 ---------------- %22 : Float(2, 3) = prim::FusionGroup_1(%0, %27, %26, %25, %6, %24) return (%23, %22); } with prim::FusionGroup_0 = graph(%8 : Float(2, 3)) { %9 : Float(2, 3) = aten::mul(%8, %8) %7 : Float(2, 3) = aten::mul(%9, %9) %5 : Float(2, 3) = aten::mul(%7, %7) %3 : Float(2, 3) = aten::mul(%5, %5) %1 : Float(2, 3) = aten::mul(%3, %3) return (%1, %3, %5, %7, %9); } with prim::FusionGroup_1 = graph(%4 : Float(2, 3) %11 : Float(2, 3) %18 : Float(2, 3) %25 : Float(2, 3) %31 : Float(2, 3!) %32 : Float(2, 3)) { %34 : Float(2, 3) = aten::mul(%31, %32) %33 : Float(2, 3) = aten::mul(%31, %32) %30 : Float(2, 3) = aten::add[alpha={1}](%33, %34) %27 : Float(2, 3) = aten::mul(%30, %25) %26 : Float(2, 3) = aten::mul(%30, %25) %23 : Float(2, 3) = aten::add[alpha={1}](%26, %27) %20 : Float(2, 3) = aten::mul(%23, %18) %19 : Float(2, 3) = aten::mul(%23, %18) %16 : Float(2, 3) = aten::add[alpha={1}](%19, %20) %13 : Float(2, 3) = aten::mul(%16, %11) %12 : Float(2, 3) = aten::mul(%16, %11) %9 : Float(2, 3) = aten::add[alpha={1}](%12, %13) %6 : Float(2, 3) = aten::mul(%9, %4) %5 : Float(2, 3) = aten::mul(%9, %4) %2 : Float(2, 3) = aten::add[alpha={1}](%5, %6) return (%2); }
stage は何階微分かに相当する。 https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir.h#L175
forward だけでコンパイルした時は、stage 0 の return は 1 つだけだったのに対し、今回は stage 0 の return で一杯値を返している。これは、関数の返り値だけでなく、逆伝搬で必要になる情報を覚えておいてもらうために出力している。逆に、stage 0 でこれらが出てこないのは、ちゃんと生存解析的なことが行われているから?(要調査)
ちなみに:print とかは消滅する
In [52]: @torch.jit.compile(nderivs=0) ...: def f(x): ...: for _ in range(3): ...: x = x * x ...: print('YOYO', x) ...: return x ...: In [53]: f(x) YOYO tensor([[ 0.6119, 0.0635, 0.4251], [ 0.1031, 0.9829, 0.1667]]) YOYO tensor([[ 0.3745, 0.0040, 0.1807], [ 0.0106, 0.9662, 0.0278]]) YOYO tensor([[ 0.1402, 0.0000, 0.0327], [ 0.0001, 0.9335, 0.0008]]) Out[53]: tensor([[ 0.1402, 0.0000, 0.0327], [ 0.0001, 0.9335, 0.0008]]) In [54]: f(x) Out[54]: tensor([[ 0.1402, 0.0000, 0.0327], [ 0.0001, 0.9335, 0.0008]])
トレースに print は出てこないので、コンパイルされた IR にも print は含まれず、コンパイル後は出力が行われなくなる。他にも、当然だが、内容によって分岐していたりループ回数が違ったりするコードをコンパイルすると正しく動作しなくなる。
torch.jit.script
試す&速度計測
In [1]: import torch In [2]: x = torch.rand(2, 3) In [3]: def g(x): ...: return x + 1.0 ...: In [4]: def f_nojit(x): ...: for _ in range(1000): ...: x = g(x) ...: return x ...: In [5]: f_jit2 = torch.jit.script(f_nojit) In [6]: f_jit2(x) Out[6]: tensor([[ 1000.0438, 1000.9478, 1000.7737], [ 1000.9696, 1000.5553, 1000.5251]]) In [7]: %timeit f_jit2(x) 100 loops, best of 3: 13.4 ms per loop
上の f_nojit
の 1 ms より、むしろ遅くなってる!?
何で遅いの?
In [8]: @torch.jit.compile(nderivs=0) ...: def f_jit(x): ...: for _ in range(1000): ...: x = g(x) ...: return x ...: In [10]: f_nojit Out[10]: <function __main__.f_nojit> In [11]: f_jit Out[11]: <torch._C.CompiledFunction at 0x10facfa40> In [12]: f_jit2 Out[12]: <torch._C.GraphExecutor at 0x10fd11f10>
torch.jit.compile
をかけた関数は、 torch._C.CompiledFunction
になり、コンパイルが行われる。一方、現状、 torch.jit.script
で出てくるのは torch._C.GraphExecutor
であり、IR をインタプリタ実行していると思われる。
IR を見てみる
In [13]: f_jit2.graph Out[13]: graph(%x : Dynamic) { %1 : Dynamic = prim::Constant[value={1000}]() %2 : Dynamic = prim::Constant[value={1}]() %7 : Dynamic = prim::Loop(%1, %2, %x) block0(%3 : Dynamic, %4 : Dynamic) { %5 : Dynamic = ^g()(%4) %6 : Dynamic = prim::Constant[value={1}]() -> (%6, %5) } return (%7); }
torch.jit.compile
で出てきた IR との違いが興味深い。
- shape が決まっていない状態で変換されるので、shape のところが
Dynamic
になっている。 - ループがループのまま IR で表現されている。
ちなみに:こっちだとなんと print できる
In [14]: def f(x): ...: for _ in range(3): ...: x = x * x ...: print(x) ...: return x ...: In [15]: f2 = torch.jit.script(f) In [16]: f2(x) 0.0019 0.8982 0.5986 0.9402 0.3084 0.2758 [ CPUFloatTensor{2,3} ] 3.6755e-06 8.0684e-01 3.5828e-01 8.8395e-01 9.5081e-02 7.6056e-02 [ CPUFloatTensor{2,3} ] 1.3509e-11 6.5100e-01 1.2837e-01 7.8137e-01 9.0403e-03 5.7845e-03 [ CPUFloatTensor{2,3} ] Out[16]: tensor([[ 1.3509e-11, 6.5100e-01, 1.2837e-01], [ 7.8137e-01, 9.0403e-03, 5.7845e-03]])
コンパイルした後も print
が無視されない。
In [17]: f2.graph Out[17]: graph(%x : Dynamic) { %1 : Dynamic = prim::Constant[value={3}]() %2 : Dynamic = prim::Constant[value={1}]() %7 : Dynamic = prim::Loop(%1, %2, %x) block0(%3 : Dynamic, %4 : Dynamic) { %5 : Dynamic = aten::mul(%4, %4) = prim::Print(%5) %6 : Dynamic = prim::Constant[value={1}]() -> (%6, %5) } return (%7); }
IR を見てみると、 prim::Print
という命令が普通に emit されている。確かに、デバッグとかでみんな使いたいよね。
ソースコードなど
ソースコードを解説するのは大変なので省略。大まかな構成だけ。
- 型一覧 https://github.com/pytorch/pytorch/blob/v0.4.0/torch/csrc/jit/type.h#L14-L18
- 命令一覧 https://github.com/pytorch/pytorch/blob/v0.4.0/torch/csrc/jit/interned_strings.h
- コンパイラ呼んでるとこ https://github.com/pytorch/pytorch/blob/v0.4.0/torch/csrc/jit/fusion_compiler.cpp#L655
- C++ のソースコード生成してるところ https://github.com/pytorch/pytorch/blob/v0.4.0/torch/csrc/jit/fusion_compiler.cpp#L60
- 最適化 https://github.com/pytorch/pytorch/tree/v0.4.0/torch/csrc/jit/passes