Scaling Vision Transformers to 22 Billion Parameters
Google Research
Introduction
- LLMは10B〜540Bって感じだけど、Vision Transformerは4Bとかしか見たことないから頑張るわ
- ちょっと発散しないためとか工夫必要だったわ
- 性能良かったわ
Model Architecture
- GPT-J風のparallel layer
- QK normalization: Q, Kにlayer normかける。attentionがめっちゃhard気味になって発散してたからこれで解決。
- biasがあったりなかったり
- 224x224 image, 解像度14x14のパッチを16x16=256個
Training Infrastructure and Efficiency
- JAX, FLAX, Scenicというライブラリを使ったよ
- 他の論文で言ういわゆるtensor parallelismについて深堀りしている。例えばmegatronとかでは「bcastとかall-reduceしてから、計算」ってしてる。でも、よく考えると、行列積の計算とring reduce-scatterを重ね合わせることができる。確かに。
- MFU (model flops utilization) が54.9%だって。TPUよくわからんが。
Experiments
Training Details
- label hierarchyを潰してsigmoid cross entropy
- reciprocal squared root learning rate(てなに?)
- few-shot adaptationの性能を良くするためにheadに高いweight decayをかけるらしい、知らんかったけどまぁ確かにという感じがするね