PyTorch で開発中の JIT 機能

概要

https://github.com/pytorch/pytorch/tree/v0.4.0/torch/csrc/jit

PyTorch は NN 定義をコンパイルPython のユーザーコードを通らずモデルを実行できる機能を開発中。実は開発は普通にオープンに GitHub で行われているし master にじゃんじゃん入っているので試せる。以下の 2 種類の両方が実装されつつある。

  • 実際に実行して作成された計算グラフを IR に変換する torch.jit.compile
  • Python のコードを制御構文も含めて直接 IR に変換する torch.jit.script

以下は PyTorch 0.4.0 で実行。

In [1]: import torch

In [2]: torch.__version__
Out[2]: '0.4.0'

torch.jit.compile

まずは試す

In [4]: @torch.jit.compile(nderivs=0)
   ...: def f(x):
   ...:     return x * 2.0 + 1.0
   ...:

In [5]: x = torch.rand(2, 3)

In [6]: f(x)
Out[6]:
tensor([[ 2.5645,  1.5038,  2.3040],
        [ 1.6422,  2.9829,  1.8165]])

In [7]: f(x)
clang: error: unsupported option '-fopenmp'
clang: error: unsupported option '-fopenmp'
warning: pytorch jit fuser failed to compile with openmp, trying without it...
Out[7]:
tensor([[ 2.5645,  1.5038,  2.3040],
        [ 1.6422,  2.9829,  1.8165]])

2 回目の実行であからさまにコンパイルされている。

速度の比較

In [29]: def g(x):
    ...:     return x + 1.0
    ...:

In [30]: def f_nojit(x):
    ...:     for _ in range(1000):
    ...:         x = g(x)
    ...:     return x
    ...:

In [31]: @torch.jit.compile(nderivs=0)
    ...: def f_jit(x):
    ...:     for _ in range(1000):
    ...:         x = g(x)
    ...:     return x
    ...:

In [32]: %timeit f_nojit(x)
1000 loops, best of 3: 1.74 ms per loop

In [33]: %timeit f_jit(x)
The slowest run took 14983.91 times longer than the fastest. This could mean that an intermediate result is being cached.
10 loops, best of 3: 10.8 µs per loop

In [34]: %timeit f_jit(x)
The slowest run took 7.04 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 11 µs per loop

Python の関数呼び出しは非常にオーバーヘッドが大きい。敢えて Python の関数呼び出しを行いまくる書き方をした関数を JIT 有りと JIT 無しで比較。100 倍近く高速になっている。Numba とかで再帰呼び出しで書いたフィボナッチとかをコンパイルしても大体 100 倍ぐらい速くなるので、ちょうどこんなもんだと思う。

IR を見る

In [35]: @torch.jit.compile(nderivs=0)
    ...: def f(x):
    ...:     for _ in range(5):
    ...:         x = x * x
    ...:     return x
    ...:

In [36]: f(x)
Out[36]:
tensor([[ 3.8660e-04,  6.9221e-20,  1.1373e-06],
        [ 1.6299e-16,  7.5940e-01,  3.5456e-13]])

In [37]: f(x)
Out[37]:
tensor([[ 3.8660e-04,  6.9221e-20,  1.1373e-06],
        [ 1.6299e-16,  7.5940e-01,  3.5456e-13]])

In [38]: f.graph_for(x)
Out[38]:
graph(%0 : Float(2, 3)) {
  %6 : Float(2, 3) = prim::FusionGroup_0(%0)
  return (%6);
}
with prim::FusionGroup_0 = graph(%8 : Float(2, 3)) {
  %9 : Float(2, 3) = aten::mul(%8, %8)
  %7 : Float(2, 3) = aten::mul(%9, %9)
  %5 : Float(2, 3) = aten::mul(%7, %7)
  %3 : Float(2, 3) = aten::mul(%5, %5)
  %1 : Float(2, 3) = aten::mul(%3, %3)
  return (%1);
}

トレースを取ってそこから IR を作るので、ループは展開されている。 graph_for という関数名からわかる通り、入力の shape に応じてコンパイルが行われており、IR の中で shape が決まっている。

逆伝搬の IR を見る

デコレータで何階微分コンパイルするか指定させるインターフェースは非常に使いにくい。

In [44]: x.requires_grad = True

In [45]: @torch.jit.compile(nderivs=1)
    ...: def f(x):
    ...:     for _ in range(5):
    ...:         x = x * x
    ...:     return x
    ...:

In [46]: f(x).sum().backward()

In [47]: f(x).sum().backward()

In [48]: f.graph_for(x)
Out[48]:
graph(%0 : Float(2, 3)
      -------- stage 1 --------
      %6 : Float(2, 3!)) {
  %23 : Float(2, 3), %24 : Float(2, 3), %25 : Float(2, 3), %26 : Float(2, 3), %27 : Float(2, 3) = prim::FusionGroup_0(%0)
  ---------------- stage 1 ----------------
  %22 : Float(2, 3) = prim::FusionGroup_1(%0, %27, %26, %25, %6, %24)
  return (%23, %22);
}
with prim::FusionGroup_0 = graph(%8 : Float(2, 3)) {
  %9 : Float(2, 3) = aten::mul(%8, %8)
  %7 : Float(2, 3) = aten::mul(%9, %9)
  %5 : Float(2, 3) = aten::mul(%7, %7)
  %3 : Float(2, 3) = aten::mul(%5, %5)
  %1 : Float(2, 3) = aten::mul(%3, %3)
  return (%1, %3, %5, %7, %9);
}
with prim::FusionGroup_1 = graph(%4 : Float(2, 3)
      %11 : Float(2, 3)
      %18 : Float(2, 3)
      %25 : Float(2, 3)
      %31 : Float(2, 3!)
      %32 : Float(2, 3)) {
  %34 : Float(2, 3) = aten::mul(%31, %32)
  %33 : Float(2, 3) = aten::mul(%31, %32)
  %30 : Float(2, 3) = aten::add[alpha={1}](%33, %34)
  %27 : Float(2, 3) = aten::mul(%30, %25)
  %26 : Float(2, 3) = aten::mul(%30, %25)
  %23 : Float(2, 3) = aten::add[alpha={1}](%26, %27)
  %20 : Float(2, 3) = aten::mul(%23, %18)
  %19 : Float(2, 3) = aten::mul(%23, %18)
  %16 : Float(2, 3) = aten::add[alpha={1}](%19, %20)
  %13 : Float(2, 3) = aten::mul(%16, %11)
  %12 : Float(2, 3) = aten::mul(%16, %11)
  %9 : Float(2, 3) = aten::add[alpha={1}](%12, %13)
  %6 : Float(2, 3) = aten::mul(%9, %4)
  %5 : Float(2, 3) = aten::mul(%9, %4)
  %2 : Float(2, 3) = aten::add[alpha={1}](%5, %6)
  return (%2);
}

stage は何階微分かに相当する。 https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir.h#L175

forward だけでコンパイルした時は、stage 0 の return は 1 つだけだったのに対し、今回は stage 0 の return で一杯値を返している。これは、関数の返り値だけでなく、逆伝搬で必要になる情報を覚えておいてもらうために出力している。逆に、stage 0 でこれらが出てこないのは、ちゃんと生存解析的なことが行われているから?(要調査)

ちなみに:print とかは消滅する

In [52]: @torch.jit.compile(nderivs=0)
    ...: def f(x):
    ...:     for _ in range(3):
    ...:         x = x * x
    ...:         print('YOYO', x)
    ...:     return x
    ...:

In [53]: f(x)
YOYO tensor([[ 0.6119,  0.0635,  0.4251],
        [ 0.1031,  0.9829,  0.1667]])
YOYO tensor([[ 0.3745,  0.0040,  0.1807],
        [ 0.0106,  0.9662,  0.0278]])
YOYO tensor([[ 0.1402,  0.0000,  0.0327],
        [ 0.0001,  0.9335,  0.0008]])
Out[53]:
tensor([[ 0.1402,  0.0000,  0.0327],
        [ 0.0001,  0.9335,  0.0008]])

In [54]: f(x)
Out[54]:
tensor([[ 0.1402,  0.0000,  0.0327],
        [ 0.0001,  0.9335,  0.0008]])

トレースに print は出てこないので、コンパイルされた IR にも print は含まれず、コンパイル後は出力が行われなくなる。他にも、当然だが、内容によって分岐していたりループ回数が違ったりするコードをコンパイルすると正しく動作しなくなる。

torch.jit.script

試す&速度計測

In [1]: import torch

In [2]: x = torch.rand(2, 3)

In [3]: def g(x):
   ...:     return x + 1.0
   ...:

In [4]: def f_nojit(x):
   ...:     for _ in range(1000):
   ...:         x = g(x)
   ...:     return x
   ...:

In [5]: f_jit2 = torch.jit.script(f_nojit)

In [6]: f_jit2(x)
Out[6]:
tensor([[ 1000.0438,  1000.9478,  1000.7737],
        [ 1000.9696,  1000.5553,  1000.5251]])

In [7]: %timeit f_jit2(x)
100 loops, best of 3: 13.4 ms per loop

上の f_nojit の 1 ms より、むしろ遅くなってる!?

何で遅いの?

In [8]: @torch.jit.compile(nderivs=0)
   ...: def f_jit(x):
   ...:     for _ in range(1000):
   ...:         x = g(x)
   ...:     return x
   ...:

In [10]: f_nojit
Out[10]: <function __main__.f_nojit>

In [11]: f_jit
Out[11]: <torch._C.CompiledFunction at 0x10facfa40>

In [12]: f_jit2
Out[12]: <torch._C.GraphExecutor at 0x10fd11f10>

torch.jit.compile をかけた関数は、 torch._C.CompiledFunction になり、コンパイルが行われる。一方、現状、 torch.jit.script で出てくるのは torch._C.GraphExecutor であり、IR をインタプリタ実行していると思われる。

IR を見てみる

In [13]: f_jit2.graph
Out[13]:
graph(%x : Dynamic) {
  %1 : Dynamic = prim::Constant[value={1000}]()
  %2 : Dynamic = prim::Constant[value={1}]()
  %7 : Dynamic = prim::Loop(%1, %2, %x)
    block0(%3 : Dynamic, %4 : Dynamic) {
      %5 : Dynamic = ^g()(%4)
      %6 : Dynamic = prim::Constant[value={1}]()
      -> (%6, %5)
    }
  return (%7);
}

torch.jit.compile で出てきた IR との違いが興味深い。

  • shape が決まっていない状態で変換されるので、shape のところが Dynamic になっている。
  • ループがループのまま IR で表現されている。

ちなみに:こっちだとなんと print できる

In [14]: def f(x):
    ...:     for _ in range(3):
    ...:         x = x * x
    ...:         print(x)
    ...:     return x
    ...:

In [15]: f2 = torch.jit.script(f)

In [16]: f2(x)
 0.0019  0.8982  0.5986
 0.9402  0.3084  0.2758
[ CPUFloatTensor{2,3} ]
 3.6755e-06  8.0684e-01  3.5828e-01
 8.8395e-01  9.5081e-02  7.6056e-02
[ CPUFloatTensor{2,3} ]
 1.3509e-11  6.5100e-01  1.2837e-01
 7.8137e-01  9.0403e-03  5.7845e-03
[ CPUFloatTensor{2,3} ]
Out[16]:
tensor([[ 1.3509e-11,  6.5100e-01,  1.2837e-01],
        [ 7.8137e-01,  9.0403e-03,  5.7845e-03]])

コンパイルした後も print が無視されない。

In [17]: f2.graph
Out[17]:
graph(%x : Dynamic) {
  %1 : Dynamic = prim::Constant[value={3}]()
  %2 : Dynamic = prim::Constant[value={1}]()
  %7 : Dynamic = prim::Loop(%1, %2, %x)
    block0(%3 : Dynamic, %4 : Dynamic) {
      %5 : Dynamic = aten::mul(%4, %4)
       = prim::Print(%5)
      %6 : Dynamic = prim::Constant[value={1}]()
      -> (%6, %5)
    }
  return (%7);
}

IR を見てみると、 prim::Print という命令が普通に emit されている。確かに、デバッグとかでみんな使いたいよね。

ソースコードなど

ソースコードを解説するのは大変なので省略。大まかな構成だけ。