Mímisbrunnr知恵の泉

← ベイズ統計 一覧

🎓 レベル:発展 | 重要度:A(必須)

📎 前提:変分推論の考え方 | 関連:ギブスサンプリング

要点(BLUF)

1. 平均場近似:q を積に分解する

変分推論の考え方 で「qq を最適化して事後に近づける」と決めましたが、qq をどんな分布族にするか。最も使われるのが平均場(mean-field)近似——成分どうしが独立だと割り切り、積に分解します。

q(θ)=iqi(θi)q(\theta)=\prod_{i} q_i(\theta_i)

qiq_i の形は決め打ちせず、ELBO を最大にするよう関数として最適化します。すると次の美しい更新式が出ます。

2. CAVI の更新式:他の因子の期待値で書ける

ELBO を qjq_j について(他を固定して)最大化すると、最適因子は

 logqj(θj)=Eqj[logp(θ,D)]+const \boxed{\ \log q_j^*(\theta_j)=\mathbb E_{q_{-j}}\big[\log p(\theta,D)\big]+\text{const}\ }

θj\theta_j 以外を qjq_{-j} で期待値を取った対数同時分布」。これを j=1,,dj=1,\dots,d と順に回して ELBO を単調に押し上げるのが座標上昇変分推論(CAVI)です。ギブスサンプリング(ギブスサンプリング)が「完全条件付きからサンプル」だったのに対し、CAVI は「完全条件付きの期待値で因子を更新」——サンプリングを期待値に置き換えた決定的版だと見ると分かりやすい。共役モデルでは qjq_j^* が標準分布(正規・ガンマなど)になり、更新がパラメータの計算式で書けます。

3. コードで CAVI を回す:平均は正確、分散は過小

2次元ガウス N(μ,Σ)\mathcal N(\mu,\Sigma)ρ=0.8\rho=0.8)を目標に、平均場 q(x1)q(x2)q(x_1)q(x_2) で近似します。最適な各因子の分散は条件付き分散 1/Λii1/\Lambda_{ii}Λ=Σ1\Lambda=\Sigma^{-1} は精度行列)に固定され、平均だけ CAVI で更新されます。

import numpy as np

# 目標:2次元ガウス N(μ, Σ), ρ=0.8。平均場 q(x1)q(x2) で近似
mu = np.array([1.0, 2.0]); rho = 0.8
Sigma = np.array([[1.0, rho], [rho, 1.0]])
Lam = np.linalg.inv(Sigma)                          # 精度行列 Λ=Σ⁻¹

def kl_mf(m, s2):                                    # KL(q || p), q=N(m, diag(s2))
    return 0.5*(np.trace(Lam@np.diag(s2)) + (m-mu)@Lam@(m-mu)
                - 2 + np.log(np.linalg.det(Sigma)/np.prod(s2)))

m = np.array([0.0, 0.0])                             # 平均の初期値
s2 = 1/np.diag(Lam)                                  # 最適分散=条件付き分散(固定)
print(f"真の周辺分散={np.diag(Sigma)} / 平均場の分散={s2}")
print(f"{'iter':<6}{'m1':>8}{'m2':>8}{'KL(q||p)':>11}")
for it in range(21):
    if it in (0, 1, 2, 5, 10, 20):
        print(f"{it:<6}{m[0]:>8.4f}{m[1]:>8.4f}{kl_mf(m, s2):>11.5f}")
    for i in range(2):                               # 座標上昇:1成分ずつ平均を更新
        j = 1 - i
        m[i] = mu[i] - (1/Lam[i,i]) * Lam[i,j] * (m[j] - mu[j])
print(f"収束: m={m.round(3)}(真{mu}回収), 分散={s2.round(3)}(真{np.diag(Sigma)}より過小)")

出力:

真の周辺分散=[1. 1.] / 平均場の分散=[0.36 0.36]
iter        m1      m2   KL(q||p)
0       0.0000  0.0000    3.01083
1      -0.6000  0.7200    1.79083
2      -0.0240  1.1808    1.03511
5       0.7316  1.7853    0.54685
10      0.9712  1.9769    0.51124
20      0.9997  1.9997    0.51083
収束: m=[1. 2.](真[1. 2.]回収), 分散=[0.36 0.36](真[1. 1.]より過小)

出力の意味:CAVI を回すと KL が単調に下がり(3.010.513.01\to0.51)、平均 mm は真の (1,2)(1,2)正確に回収します。ところが分散は 0.360.36 で固定——真の周辺分散 1.01.01ρ2=0.361-\rho^2=0.36に過小評価されています。KL が 00 まで下がらず 0.510.51 で止まるのも、平均場が相関を表現できないためです。「位置は当てるが、自信過剰(区間が狭すぎ)」が平均場 VI の典型的な癖です。

4. なぜ分散を過小評価するのか

平均場は成分を独立と仮定するので、相関 ρ\rho がつくる「斜めに伸びた」事後を、軸に平行な楕円でしか近似できません。KL(qp)\mathrm{KL}(q\,\|\,p) の最小化は、pp がほぼゼロの所で qq が大きくならないよう罰するため(mode-seeking)、qq は相関の谷に沿って細くフィットし、各軸の分散は条件付き分散(1/Λii=1ρ21/\Lambda_{ii}=1-\rho^2)まで縮みます。結果、周辺分散より狭くなる。相関が強い(ρ1\rho\to1)ほど過小評価が激しくなります。

flowchart LR
  P["真の事後(相関で斜めに伸びた楕円)"] --> Q["平均場 q(軸平行・相関なし)"]
  Q --> R["平均は正確<br/>分散は過小(自信過剰)"]

この過小評価は VI の既知の限界です。緩和には、相関を一部残す構造化変分や、全共分散の正規 qq確率的変分推論と再パラメータ化 の手法で最適化)を使います。不確実性を厳密に出したいときは MCMC(第4章)。

⚠️ よくある誤解

関連ノート