An Empirical Model of Large-Batch Training
OpenAIのLLMの論文がbatchsize関連の議論で頻繁に引用している論文なので見てみた。
理解が正しいかあんま自信ない。
実験的な話
critical batch sizeとは、「そのバッチサイズまでなら上げても効率的」というバッチサイズである。
「効率的」というのは、「バッチサイズを上げた分だけステップ数を減らしてもいい」という意味。合計の計算量が変わらないままバッチサイズを上げて計算効率(並列性)を上げることができる範囲ということ。
実験での主張
- 境界(超えたら性能が大きく劣化するところ)が割とはっきりしてる
- critical batch sizeとgradient noise scale(後述)は相関する
理論的な話
critical batch sizeはgradient noise scaleで予測可能だということを説明している。簡単に説明すると:
- batch sizeを上げることで、samplingによるgraidentの近似がtrue gradientに近づく
- その近づき具合がサチるのが、critical batch size
- 従って、gradient noise scaleがそれを決める
引用元でよく見る式のメモ
様々なバッチサイズでの学習を考える。
学習を行いある性能に至るためには、「ステップ数が $S\text{min}$ 以上」かつ「example数が $E\text{min}$ 以上」という条件がある(らしい)。
そして、別のステップ数 $S$ 及び example数 $E$ で同じ性能に至ろうと思うと、$S$ と $E$ 間は以下のような関係になる(らしい)。 $$ (S / S\text{min} - 1) = (E / E\text{min} - 1) ^{-1} $$
この、双曲線にフィットさせるた $B\text{crit} = E\text{min} / S_\text{min}$ がcritical batch sizeらしい。
感想
$B\text{crit} = E\text{min} / S_\text{min}$は結論だけ見るとあまりに当たり前の式すぎて草。
train lossしか登場せず、汎化誤差の話が一切考えられていないというのは重要な留意すべき点だと思った。ただ、確かにLLMの文脈だとあまり重要ではないのかもしれない。この論文自体はLMだけの論文じゃないんだけどね。