ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
クソでかモデルでGPUがメモリ足りない時のための手法。ZeRO-DPとZeRO-Rがある。
ZeRO-DP
DP=data parallel。optimizerのstate(momentumなど)、gradient、paramを分散して持てば良い。
- stateを分散して持つ場合、各担当ワーカーくんがparamを更新した後、paramをall-gatherすればいい。
- gradientも分散して持つ場合、NNをブロックに分け、各ブロックのbackwardが終わり次第、gradientをreduce-scatterしながら進めばいい。
- paramも分散して保つ場合、各ブロックの冒頭で持ち主ワーカーくんからparamをbcastしてもらえばいい。
ZeRO-R
モデルパラレルで同じようなことをする話。省略。
感想
optimizerのstateを分担して持てる、って言われた瞬間にどうやるか分かるぐらい簡単な話な訳だが、言われるまで自分でこれが出来るって気づいてなかったのが悔しい(そういう人は多そう)。まぁモデルのパラメタサイズが問題になるような状況に今まであんま遭遇してなかったからという言い訳で一つ。
メモリ使用量と通信時間をトレードオフする手法な訳だが、通信データ量の変化しか書かれていない。実際にはlatency項もあるはずなので通信回数も一瞬気にならなくもなかったけど、全然無視できそうだな。
論文の書き方が、「概要から詳細に入っていく」みたいな流れを異様に重んじて概要がめちゃくちゃ繰り返されまくっててクソ冗長。謎。