LLMにおける固定的な累積を、学習可能な、入力依存の深層方向アテンションに置き換える。
LLM(大規模言語モデル)における標準的な残差接続は、すべての層の出力を固定された重みで累積するため、制御不能な成長を引き起こし、各層の貢献を薄めてしまいます。Attention Residuals (AttnRes)は、この固定された累積を先行する層の出力に対するsoftmax attentionに置き換えることで、各層が入力に依存した学習可能な重みを用いて、以前の表現を選択的に集約することを可能にします。Block AttnResは、この手法を最小限のオーバーヘッドで大規模に適用することを可能にします。
標準的な残差接続は、現代のLLM(大規模言語モデル)の基盤です。更新ルールhl = hl-1 + f(hl-1)は、安定した学習を可能にする勾配ハイウェイを提供します。しかし、PreNorm(主流のパラダイム)では、この固定的な累積により、隠れ状態の大きさが深さとともにO(L)に比例して増加し、各層の相対的な寄与が徐々に希薄化されます。
AttnResは、以下の基本的な洞察に基づいています。残差ネットワークにおける深さ方向の累積は、形式的にはRNNにおける逐次的な再帰と双対である。Transformerが、固定された再帰をシーケンスの位置に対するアテンションで置き換えることでRNNを改善したように、AttnResは、固定された深さ方向の累積を、層の出力に対するアテンションで置き換えます。
深層ニューラルネットワークを、一連の処理ステップ(層)からなるチェーンと考えると、残差接続では、各層の出力が、次の層に渡される前に、その入力に加算されます。これは、情報を変更せずに層をスキップさせるショートカットのようなものです。これにより、学習がより安定しますが、元の設計では、すべての層の貢献を均等に扱います。これは、従業員の意見を、専門知識に関係なく均等な重みで平均化するようなものです。一方、AttnResでは、ネットワークがどの層が最も有用であるかを学習させることができます。これは、各従業員の意見に、現在の質問との関連性に基づいて重みを付けるようなものです。
固定された重みは、すべてのレイヤーの出力を均一に集約します。隠れ状態は、深さとともにO(L)に増加し、各レイヤーの寄与を薄めてしまいます。深さ方向の混合を調整するメカニズムはありません。
hl = hl-1 + f(hl-1)
すべての先行レイヤーの出力に対して、Softmax attentionを適用します。入力に依存する、学習可能な重みを使用し、疑似クエリによって実現されます。最適な性能を発揮しますが、メモリ使用量はO(Ld)です。
αl = softmax(φ(wl, kj))
レイヤーをブロックに分割し、ブロックレベルでの表現に注目します。メモリ使用量をO(L)からO(N)に削減します。最小限のオーバーヘッドで、既存のシステムに容易に導入できます。
O(N) memory, N << L
標準的な残差接続は、現代のLLM(大規模言語モデル)の事実上の基本構成要素です。更新式 hl = hl-1 + fl-1(hl-1) は、勾配ハイウェイを提供し、恒等写像を通じて勾配が変換をバイパスできるようにすることで、深層における安定した学習を可能にします。しかし、残差接続は、2つ目の、あまり議論されない役割も果たします。それは、各層の出力がどのように単一の、徐々に成長する隠れ状態に集約されるかを定義することです。
実際には、PreNormが主流なパラダイムとなっていますが、その加重されていない累積によって、隠れ状態の大きさは深さとともにO(L)に増加します。これにより、各層の相対的な寄与が徐々に希薄化されます。初期層の情報は埋もれてしまい、選択的に取り出すことができません。実験的に、著者らは、最初の層と最後の層がしばしば大きな影響力を持つ一方で、中間層はほとんど寄与していないことを観察しています。
この論文では、深層方向に沿った累積と、RNNにおける逐次的な再帰との間の形式的な二重性が観察されています。この二重性を基に、彼らはAttention Residuals (AttnRes)を提案しており、これは固定された累積をhl = Σ αl→j · vjに置き換えるものです。ここで、αは、学習された各層のクエリと、先行する層の出力との間の単一のドット積から計算される、softmaxの注意の重みです。
標準的なトレーニングにおいて、Full AttnResは、バックプロパゲーションのために必要なレイヤーの出力が既に保持されているため、無視できるオーバーヘッドしか発生しません。しかし、大規模なモデルでは、アクティベーションの再計算やパイプライン並列処理が一般的に使用されます。Block AttnResは、レイヤーをブロックに分割し、キャッシュベースのP2P通信と2段階の推論戦略を使用することで、この問題を解決します。
残差学習は、深層ニューラルネットワークの学習において非常に重要です。各層は、隠れ状態を次のように更新します。hl = hl-1 + fl-1(hl-1)。この再帰を展開すると、層 l の隠れ状態は、埋め込みベクトルと、それ以前のすべての層の出力の合計に等しくなります。hl = h1 + Σfi(hi)。恒等写像は、損失から任意の層への直接的な勾配経路を提供します。
しかし、固定された係数を持つユニットは、各層の寄与を均一に扱います。Highwayネットワークは、学習された要素ごとのゲートを使用して、この制約を緩和し、変換と同一の間の補間を行います。しかし、どちらのアプローチも基本的な制約を共有します。各層は、直前の入力 hl-1 のみを利用でき、これはすべての以前の出力を混在させた単一の圧縮された状態です。
これは、(1) 特定の初期層の特徴を個別に抽出しないこと、(2) 深い層から個々の初期層への直接的な勾配経路がないこと、そして (3) すべての先行計算が単一の状態ベクトルに圧縮されるという、表現のボトルネックがあることを意味します。
これらの制限は、シーケンスモデリングにおけるRNNのよく知られたボトルネックを反映しており、固定された逐次的な再帰が最終的にアテンションによって置き換えられました。この類似性が、本論文の核心的な提案の動機となっています。固定された深さ方向の累積を、アテンションベースの集約に置き換える。
RNN (Recurrent Neural Networks: 再帰型ニューラルネットワーク) は、シーケンスを一度に1ステップずつ処理し、以前のすべての情報を単一の隠れ状態に圧縮します。このボトルネックは、よく知られた制限事項でした。つまり、古い情報は「忘れ去られ」、新しい情報が到着します。 Transformerは、アテンションメカニズムを使用して、各位置がすべての以前の位置を参照できるようにすることで、この問題をシーケンスに対して解決しました。
この論文の重要な洞察は、残差接続も、時間ではなく深さ方向に同じボトルネックを持つということです。各層は、前のすべての層の出力の圧縮された合計しか見ることができず、これはRNNが圧縮された状態しか見ることができないのと同様です。AttnResは、同じ修正を適用します。つまり、各層が個別にすべての前の層を参照できるようにします。
重要な洞察は、時間と深さの間の双対性です。時間方向のRNNと同様に、残差接続は、深さ方向において、すべての過去の情報を単一の状態に圧縮します。系列モデリングにおいて、Transformerは、再帰をアテンションに置き換えることでRNNを改善し、各位置がすべての過去の位置を選択的に参照できるようにしました。AttnResは、この原則を深さの次元に適用します。
一般的な形式では、固定された累積和を次のように置き換えます。hl = Σ αl→j · vj。ここで、α は、Σα = 1 を満たす、層ごとの注意の重みです。シーケンス長(これは数百万に達する可能性があります)とは異なり、ネットワークの深さは通常はそれほど大きくありません(L < 1000)。そのため、深さ方向の O(L2) の注意機構は、計算上実行可能です。
注意係数は、カーネル関数 φ を用いて、αl→j = φ(ql, kj) として計算されます。 著者は、φ(q, k) = exp(qT RMSNorm(k)) を、softmax 正規化とともに採用しています。 クエリ q は、層ごとに特有の 学習可能なパラメータ (入力に依存しない) であり、これは並列計算を可能にするための意図的な設計上の選択です。
RMSNorm が φ の内部に組み込まれており、これにより、出力の大きさが大きい層が注意の重みに過度に影響を与えるのを防ぎます。各トークンについて、Full AttnRes は O(L2d) の計算量と O(Ld) のメモリ量を必要とします。深さがシーケンス長よりもはるかに小さいことから、このコストは比較的わずかです。
標準的な学習におけるオーバーヘッドゼロ: O(Ld) のメモリオーバーヘッドは、バックプロパゲーションのために保持されているアクティベーションと完全に一致します。また、疑似的なクエリの独立性により、任意のグループのレイヤーに対する注意の重みは、シーケンシャルなレイヤーの実行を待つことなく、並行して計算できます。
通常の注意機構(Transformerなど)では、クエリは現在の入力データから生成されます。一方、Full AttnResでは、クエリ wl は、学習可能なパラメータです。これは、モデルが学習中に獲得する、入力から派生しない固定ベクトルです。これは意図的な選択であり、異なるレイヤーの注意機構の重みを並行して計算できるようにするためです。なぜなら、それらは互いの結果に依存しないからです。トレードオフは、わずかに表現力が低下すること(クエリが特定の入力に適合しない)ですが、アブレーションスタディでは、このコストは小さいことが示されています。
Block AttnResは、L層を、各ブロックがS = L/N層からなるN個のブロックに分割します。各ブロック内では、層の出力を合計によって単一の表現に削減します。ブロック間で、N個のブロックレベルの表現とトークン埋め込みに対して、フルアテンションが適用されます。これにより、メモリの使用量はO(L)からO(N)に、計算量はO(L2)からO(N2)にそれぞれ削減されます。
ブロック数 N は、2つの極値の間を補間します。N = L の場合、Full AttnRes が得られます。N = 1 の場合、標準的な残差接続に帰着します。実際には、S = 4(ブロックあたり4層)の設定が、ほとんどの利点をもたらしつつ、オーバーヘッドを最小限に押さえます。
2段階の計算戦略により、効率的な推論が可能になります。フェーズ1では、バッチ処理されたクエリを使用して、キャッシュされたブロック表現に対して、すべてのS層のブロック間アテンションを同時に計算します。フェーズ2では、ブロック内アテンションを逐次的に計算し、オンラインsoftmaxを通じて、フェーズ1の結果と統合します。これにより、メモリアクセスコストをブロック全体で償却できます。
54人の従業員(レイヤー)が、9つの部署(ブロック)に組織された会社を経営していると想像してください。各従業員の概要レポートが必要です。
この戦略の利点は、第1段階にかかるコストが、ブロック内のすべてのレイヤーに分散されるため、各個々のレイヤーが負担するのは、ブロック間の処理にかかるコストのごく一部であることです。
小規模なトレーニングの場合、AttnResはごくわずかな計算オーバーヘッドと、追加のメモリ使用量をもたらします。一方、大規模な分散トレーニングでは、パイプライン並列処理が主要なインフラストラクチャ上の課題となります。AttnResを完全に活用するには、パイプラインの各ステージが、先行するすべてのステージのレイヤー出力を参照する必要がありますが、パイプライン並列処理では、これらの出力がローカルで利用できないという問題があります。
クロスステージキャッシュがこの問題を解決します。各物理ステージは、複数の仮想ステージを順番に処理するため、初期の仮想ステージで受信したブロックはローカルにキャッシュされ、再送信する必要がありません。これにより、トランジションごとのピークコストがO(C)からO(P)に削減され、V倍の改善となり、計算との完全なオーバーラップが可能になります。測定されたエンドツーエンドのオーバーヘッドは、4%未満です。
パイプライン並列処理 は、モデルを複数のGPUに分割し、各GPUがモデルのサブセット(レイヤー)を処理する手法です。データは、工場の組立ラインのように、それらのGPUを通過します。AttnResにおける課題は、各「ステーション」(GPU)が、前のステーションからの出力を知る必要があることです。これは、追加の通信を必要とします。クロスステージキャッシュは、既に送信された情報を記憶することで、これを軽減します。これにより、新しい情報のみを送信するだけで済みます。
メモリオーバーヘッドはごくわずかです。クロスステージキャッシュにより、各ブロックはすべての仮想ステージ全体でちょうど1回だけ保存されるため、標準的なレイヤーごとのアクティベーションキャッシュと比較して非常に小さいです。
この二段階の計算戦略は、FullとBlock AttnResの両方に適用されます。単純な実装では、すべてのレイヤーでアテンションを計算するため、ブロック表現全体を毎回完全に処理する必要があります。しかし、代わりに、フェーズ1では、ブロック内のすべてのSクエリをまとめて1回のパスで処理し、フェーズ2では、オンラインのsoftmaxマージを用いた、ブロック内での逐次的な参照を行います。
この設計により、Block AttnRes の各層あたりの総 I/O コストはわずか 5.5d (読み込み + 書き込み) であり、これは標準的な残差ブロックの 3d と比較して、さらに mHC の 34d と比較しても大幅に低い値です。また、フェーズ 1 は計算と部分的にオーバーラップさせることで、そのコストをさらに隠すことができます。
現代のGPUにおいて、ボトルネックは多くの場合、計算処理ではなく、メモリ帯域幅—つまり、データがメモリから読み込まれたり、メモリに書き込まれたりする速度です。 「I/Oコスト」は、各レイヤーが必要とするデータの総バイト数を測定します。 Block AttnResは、1レイヤーあたり5.5dを達成します(ここでdはモデルの次元で、通常は〜1024〜4096)。これは、3dのベースラインコストに非常に近く、mHCの34dよりもはるかに優れています。 モデルの次元dが4096の場合、これは各レイヤーが約22KBのデータを移動することになり、これはベースラインの12KBやmHCの139KBと比較して、大幅に小さい値です。
5つのモデルサイズ(194Mから528Mの活性パラメータ)について、それぞれ3つのバリエーションで学習を行いました。バリエーションは、PreNorm baseline、Full AttnRes、および約8ブロックを持つBlock AttnResです。すべてのバリエーションは、各サイズグループ内で同一のハイパーパラメータとデータを使用しており、これにより残差機構の効果のみを分離して評価することができます。
調整されたスケーリング曲線は、以下の結果を示しています。Baseline は L = 1.891 × C-0.057、Block AttnRes は L = 1.870 × C-0.058、そして Full AttnRes は L = 1.865 × C-0.057 という関係を示します。これら3つはすべて類似した傾きを持っていますが、AttnRes は常に低い損失値を達成しています。最も大きなスケールにおいて、Full AttnRes と Block AttnRes の差はわずか 0.001 に縮小されます。
スケーリング則とは、モデルの学習に費やした計算リソースの量と、モデルの性能がどれだけ向上するかという、予測可能な関係性を表すものです。式 L = a × Cb は、損失が計算リソースの累乗に反比例して減少することを示しています。AttnRes がより低い係数 'a' を達成した場合、これはモデルがすべての計算レベルでより優れた性能を発揮することを示し、つまり、同じ GPU リソースで 25% 以上の価値を得られることを意味します。
Kimi Linear 48Bのフル構成では、MoE(Mixture of Experts)を使用した27のTransformerブロック(54層)が使用され、合計480億パラメータ、有効パラメータは30億パラメータとなります。各ブロックには6層のAttnResが適用されており、合計9つのブロックが構成されています。このモデルは、4096トークンのコンテキストウィンドウで、1.4兆トークンで事前学習されています。
トレーニングの動態分析から、以下の3つの重要な利点が明らかになりました。(1) トレーニング全体を通して、検証損失が低い状態が維持され、特に減衰段階に入るとその差が大きくなります。(2) 深層ネットワークの各層で、出力の大きさが均一になり、PreNormによる影響が軽減されます。これにより、深層のレイヤーは、より大きな出力を学習する必要がなくなります。(3) 安定した勾配分布が実現され、初期のレイヤーにおける過剰な勾配の発生を防ぎます。
これらのグラフは、標準的な残差(residuals)における根本的な問題、すなわちPreNorm の希釈(dilution)を明らかにします。
これは、標準的な残差を使用すると、一番前の席の学生がどんどん大きな声で叫び、一番後ろの席の学生がささやいているような教室に似ています。AttnRes は、全員に同じ音量に調整されたマイクを与えます。
Block AttnResは、15のすべてのベンチマークで、ベースラインと同等またはそれ以上の性能を示しました。 特に、GPQA-Diamond (+7.5) やMath (+3.6) などの多段階推論タスク、およびHumanEval (+3.1) などのコード生成タスクにおいて、大幅な改善が見られます。 また、MMLUやHellaSwagなどの知識を必要とするベンチマークでも、わずかながら改善が見られました。
436Mモデルに関するアブレーション実験は、主要な設計上の選択を検証しています。すべてのバリエーションは、同一のハイパーパラメータと計算リソースを使用しており、各コンポーネントの貢献を個別に評価できるように設計されています。
固定された計算量(〜6.5 × 1019 FLOPs)下での制御されたアーキテクチャの探索により、AttnResが最適な深さと幅のトレードオフをどのように変化させるかが明らかになりました。BaselineとAttnResの両方とも、H/Lb ≈ 0.3で最適値に達しますが、AttnResはテストされた25種類の構成すべてで、損失が低くなるという結果を示しました。その改善幅は0.019から0.063の範囲でした。AttnResはより深く、狭いモデルを好む傾向があり、これはAttnResが追加の深さをより効果的に活用できることを示唆しています。
注目すべき点として、低い dmodel/Lb 比率は、より深く狭いネットワークに対応します。AttnResが深さを重視する傾向は、そのメカニズムと一致しています。より深いモデルは、アテンションが選択できる層の出力をより多く生成し、深さ方向の集約の表現力を高めます。ただし、より深いモデルは、一般的に推論レイテンシを増加させます。
学習された注意の重みを可視化することで、AttnResが過去の情報源に対してどのように注意を配っているかを理解することができます。各ヒートマップは、l番目の注意層またはMLP層(行)が、過去の情報源(列)に対してどのように注意を配っているかを示しており、注意層とMLP層はそれぞれ個別に表示されます。
残差接続は、固定された再帰によって、深さ方向に情報を伝播させます。これは、RNNが時間方向に情報を伝播させるのと同様です。この二重性は、より高度なバリエーションにも及んでいます。シーケンス側におけるデータ依存のゲートは、深さ側におけるHighwayネットワークに対応し、デルタ則はDDLに対応し、MRLAはゲート付き線形注意を反映しています。これらの手法はすべて、層を時間ステップとして扱うものであり、同じ代数構造を共有しています。
AttnResは、Transformersがシーケンス次元にもたらしたのと同様に、深さ次元にfull softmax attentionを導入することで、このアナロジーを完結させます。Block AttnResは、ブロックスパースアテンションに対応しており、表現力をある程度犠牲にすることで、計算効率を向上させています。
このセクションでは、洗練された理論的洞察が明らかになります。シーケンス処理のために発明されたすべての技術には、深さ処理のための直接的な対応関係が存在するのです。
この二重性は単なる比喩ではありません。数学的な形式は同一です。これは、シーケンスモデリングの将来的な改善が、直接的に深度の次元に適用できる可能性を示唆しています。
すべての残差接続のバリエーションは、深さ混合行列 M ∈ RL×L として統合することができます。ここで、Ml→j は、層 l が層 j の出力に割り当てる重みを表します。標準的な残差接続は、すべて 1 の下三角行列 M を生成します。Highway ネットワークは、ランク 1 の因子を生成します。AttnRes は、softmax 正規化を用いた、入力に依存する密な下三角行列 M を生成します。
この視点から見ると、既存の残差バリアントは、深さ軸方向の線形アテンションの具体例であることがわかります。展開された(m)HC重みは、数学的にゲート付き線形アテンション遷移と同等です。AttnResはさらに、フルソフトマックスアテンションを使用することで、より優れた正規化と、ソース層間のより明確な選択を可能にしています。
シーケンスと深さの間の双対性をヒントに、AttnResは、固定された均一な残差の累積を、学習可能な入力依存の深さ方向アテンションに置き換えます。この手法は、アブレーション実験、スケーリング則の実験、および1.4Tトークンで事前学習された480億パラメータの生産規模モデルへの統合を通じて検証されています。Block AttnResは、実用的なバリエーションとして登場し、最小限のオーバーヘッドでほとんどの利点をもたらします。