← 機械学習テキスト 一覧

🎓 レベル:標準 | 重要度:A(必須)

📎 前提:再帰型ニューラルネットワーク(その限界を克服する)

要点(BLUF)


1. なぜ注意機構が必要か:RNN の3つの限界

再帰型ニューラルネットワークは系列を一歩ずつ処理しますが、機械翻訳のような seq2seq(系列を入力して別の系列を出力する)課題で、次の限界に突き当たります。

(1) 固定長文脈ベクトルのボトルネック

素朴な seq2seq は、エンコーダが入力系列全体を1本の固定長ベクトル cc(最後の隠れ状態)に押し込め、デコーダはそれだけを頼りに出力を生成します。

問題は、入力が長くなるほど cc に詰め込む情報が増えて溢れること。50 単語の文も 5 単語の文も同じ次元のベクトルに圧縮するため、長文では文頭の情報が失われやすい。これが「文脈ベクトルのボトルネック」です。

(2) 長期依存の困難

勾配消失により、遠く離れた位置どうしの依存(例:主語と遠い述語の一致)を学習しにくい。情報は時刻ステップを1つずつ伝播するため、距離 nn だけ離れた2語をつなぐには nn 回の伝播が必要です。

(3) 逐次処理で並列化できない

隠れ状態 hth_tht1h_{t-1} に依存するため、t=1,2,t=1,2,\dots順番にしか計算できません。GPU は並列計算が得意なのに、RNN はその利点を活かせない。これが学習速度の致命的な制約になります。

要するに:RNN は「1本の細い管に情報を流す」ので、長文で詰まり(ボトルネック)、遠い関係が薄れ(長期依存)、順番待ちが発生する(並列化不可)。


2. 注意の発想:全部見て、重み付きで取り出す

ボトルネックの根本原因は「入力全体を1本のベクトルに潰すこと」です。ならば潰さなければよい

注意機構の発想はこうです。出力側の各位置(クエリ)が、入力の全位置を直接参照し、「今どこが関係するか」を表す重みを計算して、関係する位置の情報を加重和で取り出す。

最初にこれを実現したのが Bahdanau ら(2014)の加法注意です。デコーダの各ステップ ss ごとに、エンコーダの全隠れ状態 hih_i との整合スコアを計算します。

ei=vtanh(W1hi+W2s)e_i = \mathbf{v}^\top \tanh(W_1 h_i + W_2 s)

これを softmax で重み αi\alpha_i に変え、ステップごとに変わる文脈ベクトル c=iαihic = \sum_i \alpha_i h_i を作ります。

要するに:固定の cc を1本だけ持つのをやめ、「デコーダの今の状態に応じて、入力のどこを見るかを毎回作り直す」。これでボトルネックが消えます。

このあと **Luong ら(2015)**が、tanh\tanh と小さなネットワークを使う加法注意より、内積(ドット積)で類似度を測る方が単純かつ高速だと示しました。これがスケール化ドット積注意に直結します。


3. Query / Key / Value:検索の比喩

現代の注意は3つの役割に分けて定式化します。検索エンジンの比喩が分かりやすいです。

記号名前役割(検索の比喩)
QQクエリ (Query)何を探しているか」。問い合わせ
KKキー (Key)各項目の見出し」。クエリと突き合わせる対象
VVバリュー (Value)実際に取り出す中身」。重みに応じて混ぜる本体

検索では、問い合わせ(クエリ)を各文書の見出し(キー)と照合して関連度を測り、関連度の高い文書の中身(バリュー)を返します。注意機構はこれを「ハード(1件だけ返す)」ではなく「ソフト(全件を重み付きで混ぜる)」に行います。

具体的には、入力の各トークン埋め込み xx から、3つの学習可能な行列で射影して Q,K,VQ,K,V を作ります。

Q=XWQ,K=XWK,V=XWVQ = X W^Q,\quad K = X W^K,\quad V = X W^V

WQ,WK,WVW^Q, W^K, W^V が学習対象です。何をクエリ・キー・バリューとして使うかをモデルが学習する点が重要です。


4. スケール化ドット積注意の導出

4.1 全体の式

 Attention(Q,K,V)=softmax ⁣(QKdk)V \boxed{\ \mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\dfrac{QK^\top}{\sqrt{d_k}}\right)V\ }

ここで QRn×dkQ \in \mathbb{R}^{n \times d_k}KRm×dkK \in \mathbb{R}^{m \times d_k}VRm×dvV \in \mathbb{R}^{m \times d_v}nn はクエリ数、mm はキー/バリュー数、dkd_k はクエリ・キーの次元です。

4.2 ステップごとの意味

  1. 類似度(スコア)QKRn×mQK^\top \in \mathbb{R}^{n \times m}(i,j)(i,j) 成分は、ii 番目のクエリと jj 番目のキーの内積 qikjq_i \cdot k_j。内積が大きいほど方向が揃っており「関連が強い」。
  2. スケーリングdk\sqrt{d_k} で割る(理由は次節)。
  3. softmax:各クエリ行を確率分布に変換。αij=exp(qikj/dk)jexp(qikj/dk)\alpha_{ij} = \dfrac{\exp(q_i \cdot k_j / \sqrt{d_k})}{\sum_{j'} \exp(q_i \cdot k_{j'} / \sqrt{d_k})}。これが「クエリ ii がキー jj にどれだけ注目するか」の重み。
  4. 加重和jαijvj\sum_j \alpha_{ij} v_j。重み αij\alpha_{ij} でバリューを混ぜ、クエリ ii の出力ベクトルを得る。

要するに:「クエリとキーの内積で関連度を測り → softmax で配分率に変え → バリューをその率で混ぜる」。これを行列1発で全クエリ同時に計算します。

softmax の確率的な背景(指数で重みを作り、正規化して分布にする理由)は、各スコアを「大きいほど指数的に強調し、全体を 11 に正規化して確率分布にする」変換だと理解すれば十分です。


5. なぜ dk\sqrt{d_k} で割るのか(導出)

ここが本トピックの核心です。dk\sqrt{d_k}分散を一定に保つための正規化で、これがないと softmax が飽和して勾配が消えます。

5.1 分散が次元に比例して増える

クエリ qq とキー kk の各成分が、独立で平均 00・分散 11 と仮定します。内積は

qk=i=1dkqikiq \cdot k = \sum_{i=1}^{d_k} q_i k_i

各項 qikiq_i k_i は独立な2変数の積なので、平均は E[qiki]=E[qi]E[ki]=0\mathbb{E}[q_i k_i] = \mathbb{E}[q_i]\mathbb{E}[k_i] = 0、分散は

Var(qiki)=E[qi2ki2]0=E[qi2]E[ki2]=11=1\mathrm{Var}(q_i k_i) = \mathbb{E}[q_i^2 k_i^2] - 0 = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = 1 \cdot 1 = 1

独立な dkd_k 個の和なので分散は加算され、

Var(qk)=i=1dkVar(qiki)=dk,σ(qk)=dk\mathrm{Var}(q \cdot k) = \sum_{i=1}^{d_k} \mathrm{Var}(q_i k_i) = d_k, \qquad \sigma(q \cdot k) = \sqrt{d_k}

つまり内積の標準偏差は dk\sqrt{d_k} に比例して大きくなる。例えば dk=64d_k = 64 なら σ=8\sigma = 8dk=512d_k = 512 なら σ22.6\sigma \approx 22.6。次元が増えるほど内積の値が大きくばらつきます。

5.2 大きな入力は softmax を飽和させ、勾配を消す

softmax は入力の差が大きいと、最大要素の出力がほぼ 11、他がほぼ 00極端に尖った分布になります。この飽和領域では softmax のヤコビアンが 00 に近づきます。softmax の勾配は出力 pip_i を用いて

pizj=pi(δijpj)\frac{\partial p_i}{\partial z_j} = p_i(\delta_{ij} - p_j)

と書け、pi1p_i \to 1pi0p_i \to 0 では pi(1pi)0p_i(1-p_i) \to 0、すなわち勾配がほぼ消失します。学習初期にこうなると、注意の重みがほぼ更新されず学習が進みません。

5.3 dk\sqrt{d_k} で割ると標準偏差が 1 に戻る

スコアを dk\sqrt{d_k} で割ると、

Var ⁣(qkdk)=1dkVar(qk)=dkdk=1\mathrm{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{1}{d_k}\,\mathrm{Var}(q \cdot k) = \frac{d_k}{d_k} = 1

次元 dkd_k に依らず分散が約 11 に保たれます。これで softmax は飽和しにくい穏当な領域に収まり、勾配が健全に流れます。

要するにdk\sqrt{d_k} は「次元が増えると内積のばらつきが dk\sqrt{d_k} 倍に膨らむ」のをちょうど打ち消す係数。割らないと softmax がカチカチに尖って勾配が死ぬので割る。深さ(次元)が変わってもスケールを揃える正規化です。


6. Q/K/V 計算フロー(図解)

flowchart LR
    X["入力埋め込み X"] --> Q["クエリ Q = X·WQ"]
    X --> K["キー K = X·WK"]
    X --> V["バリュー V = X·WV"]
    Q --> S["スコア = Q·Kᵀ(内積で類似度)"]
    K --> S
    S --> SC["スケーリング(÷ √dk)"]
    SC --> SM["softmax(重み αに正規化)"]
    SM --> W["加重和 Σ α·V"]
    V --> W
    W --> O["出力(文脈ベクトル)"]

自己注意では Q,K,VQ,K,V がすべて同じ入力 XX から作られる点に注目してください(次節)。


7. 自己注意と交差注意

注意は「クエリとキー/バリューがどこから来るか」で2種類に分かれます。

自己注意 (self-attention)

Q,K,VQ,K,Vすべて同じ系列から作られます(Q=XWQ, K=XWK, V=XWVQ=XW^Q,\ K=XW^K,\ V=XW^V、いずれも同じ XX)。各トークンが同じ文の他のトークンを参照し、文脈に応じた表現を得ます。

交差注意 (cross-attention)

クエリと、キー/バリューが別の系列から来ます。典型は機械翻訳のデコーダで、

これにより「出力の各語が、入力のどこに対応するか」を学習します。Bahdanau 注意(第2節)はまさに交差注意の原型でした。

クエリ QQキー・バリュー K,VK,V用途
自己注意同じ系列同じ系列系列内の文脈把握
交差注意系列A系列B系列間の対応付け(翻訳など)

8. 計算量 O(n2d)O(n^2 d) とその含意

系列長を nn、表現次元を dd とします(簡単のため dkdvdd_k \approx d_v \approx d)。

合わせて時間計算量は O(n2d)O(n^2 d)、注意行列のメモリは O(n2)O(n^2) です。

含意は2つあります。

  1. 並列化できる:RNN と違い、全位置のスコアを行列積で一度に計算できます。逐次依存がないので GPU で完全並列。これが学習速度の革命でした。
  2. 系列長の2乗で重いnn が長いと n2n^2 が爆発します。長文書や高解像度画像(nn が大)では計算・メモリが厳しく、これが疎な注意・線形注意・FlashAttention などの効率化研究の動機になっています。

要するに:RNN は「速くないが軽い(O(n)O(n) だが逐次)」、注意は「並列で速いが n2n^2 で重い」。距離に依らず任意の2位置を1ステップで結べる(最大経路長 O(1)O(1))のが、長期依存に効く本質です。


9. Transformer への橋渡し

スケール化ドット積注意そのものは学習可能なパラメータを持つ層ではなく、Q,K,VQ,K,V を与えると重み付き和を返す演算です。これを部品として、Transformerは次を積み上げます。

注意機構の理解は、そのまま Transformer と Phase 12 の大規模言語モデル(LLM)の土台になります。


⚠️ よくある誤解・落とし穴


対応するシミュレーション

simulations/attention_weights.py:スケール付きドット積注意 softmax(QK/d)V\mathrm{softmax}(QK^\top/\sqrt{d})V の重み行列をヒートマップで描き、各クエリが各キーをどれだけ見るか(行和=1)を可視化します。さらにキー次元 dd を変えて、d\sqrt{d} で割らないと内積が大きくなりすぎてソフトマックスが一点に偏り(エントロピーが0へ=飽和)勾配が消えること、d\sqrt{d} で割ると分布が保たれることを示します(Transformer)。

スケール付きドット積注意と√dの効果

関連ノート