Mímisbrunnr知恵の泉

← ベイズ統計 一覧

🎓 レベル:標準 | 重要度:A(必須)

📎 前提:ハミルトニアンモンテカルロとNUTS | 数理:MCMC(マルコフ連鎖モンテカルロ)(統計)

要点(BLUF)

1. なぜ診断するのか

MCMC は理論上は無限ステップで事後 π\pi に収束しますが、有限回では3つの問題が起こります。

これらを見抜く2大指標が R-hat と ESS です。

2・3. R-hat と ESS を実装し、良い連鎖と悪い連鎖を区別する

R-hat は、分散した初期値から mm 本のチェーン(各長さ nn)を回し、

W=1mjsj2 (チェーン内分散),B=nm1j(θˉjθˉ)2 (チェーン間分散)W=\frac1m\sum_{j}s_j^2\ (\text{チェーン内分散}),\quad B=\frac{n}{m-1}\sum_j(\bar\theta_j-\bar\theta)^2\ (\text{チェーン間分散})

から Var^=n1nW+1nB\widehat{\mathrm{Var}}=\frac{n-1}{n}W+\frac1n BR^=Var^/W\hat R=\sqrt{\widehat{\mathrm{Var}}/W}。チェーン間とチェーン内のばらつきが釣り合えば(収束していれば)R^1\hat R\to1ESS は自己相関 ρk\rho_k を使って ESSN/(1+2k1ρk)\mathrm{ESS}\approx N/(1+2\sum_{k\ge1}\rho_k)。両方を実装して、2つのケースで比べます。

import numpy as np

def rhat(chains):                       # chains: 形 (m, n)
    m, n = chains.shape
    W = chains.var(axis=1, ddof=1).mean()           # チェーン内分散
    B = n * chains.mean(axis=1).var(ddof=1)         # チェーン間分散
    var_hat = (n - 1)/n * W + B/n
    return np.sqrt(var_hat / W)

def ess(chains):                        # 自己相関ベースの簡易ESS
    x = (chains - chains.mean()).ravel(); N = len(x)
    acf = np.correlate(x, x, mode='full')[N-1:]; acf /= acf[0]
    s = 1.0
    for k in range(1, N):
        if acf[k] < 0: break                         # 初期正値まで和を取る
        s += 2*acf[k]
    return N / s

def mh_chains(logp, inits, step, n=5000, seed=0):
    rng = np.random.default_rng(seed); out = []
    for x0 in inits:
        x = x0; lp = logp(x); ch = np.empty(n)
        for i in range(n):
            cand = x + rng.normal(0, step); lc = logp(cand)
            if np.log(rng.uniform()) < lc - lp: x, lp = cand, lc
            ch[i] = x
        out.append(ch)
    return np.array(out)

# 良いケース:標準正規・分散した初期・適切な step → よく混合
good = mh_chains(lambda x: -0.5*x**2, inits=[-3,-1,1,3], step=2.0)[:, 1000:]
print(f"良いケース(標準正規):  R-hat={rhat(good):.3f}  ESS={ess(good):.0f}/{good.size}")

# 悪いケース:二峰 0.5N(-4,.5²)+0.5N(4,.5²)・各チェーンが片方の山に閉じ込め
def logp_bimodal(x):
    return np.logaddexp(-0.5*((x+4)/0.5)**2, -0.5*((x-4)/0.5)**2)
bad = mh_chains(logp_bimodal, inits=[-4,-4,4,4], step=0.5)[:, 1000:]
print(f"悪いケース(二峰・閉じ込め): R-hat={rhat(bad):.3f}  ESS={ess(bad):.0f}/{bad.size}")
print(f"  各チェーンの平均={bad.mean(axis=1).round(2)}(±4 に割れている=未収束)")

出力:

良いケース(標準正規):  R-hat=1.000  ESS=3326/16000
悪いケース(二峰・閉じ込め): R-hat=9.305  ESS=3/16000
  各チェーンの平均=[-3.94 -3.96  4.01  4.  ](±4 に割れている=未収束)

出力の意味:良いケースは R^=1.000\hat R=1.000(チェーン間と内のばらつきが一致=みな同じ分布を見ている)、ESS も 160001600033263326 と健全。悪いケースは R^=9.305\hat R=9.305一目で未収束——4本のチェーンが 4-4+4+4 の山に2本ずつ閉じ込められ、チェーン間分散 BB が爆発したためです。ESS はわずか 33(実質3個ぶんの情報しかない)。分散した初期値で複数チェーンを回し R^\hat R を見るだけで、この閉じ込めを検出できます。1本だけ見ていたら「片方の山できれいに収束した」と誤認したでしょう。

4. 実務の目安と対処

指標目安外れたときの対処
R^\hat R<1.01<1.01(厳しめ)。>1.1>1.1 は明確に未収束チェーン数・反復数を増やす、初期値を分散、再パラメータ化
ESS関心量ごとに数百以上(区間端なら多めに)反復を増やす、提案・サンプラー改善(HMC/NUTS)、間引きは基本不要
受容率RW-MH で 0.2〜0.5、HMC で 0.6〜0.9step(提案幅・ε\varepsilon)を調整(メトロポリスヘイスティングス
発散(divergence)0 が理想(NUTS)再パラメータ化(階層モデルの実例と再パラメータ化)、step 小さく

トレースプロットは最初の目視点検——よく混合した連鎖は水平な帯(毛虫状)、未収束はドリフトや帯の分離が見えます。良いケースと悪いケースのトレースを並べます。

import numpy as np
import matplotlib.pyplot as plt
import japanize_matplotlib  # 日本語ラベル用

def mh_chains(logp, inits, step, n=5000, seed=0):
    rng = np.random.default_rng(seed); out = []
    for x0 in inits:
        x = x0; lp = logp(x); ch = np.empty(n)
        for i in range(n):
            cand = x + rng.normal(0, step); lc = logp(cand)
            if np.log(rng.uniform()) < lc - lp: x, lp = cand, lc
            ch[i] = x
        out.append(ch)
    return np.array(out)

good = mh_chains(lambda x: -0.5*x**2, [-3,-1,1,3], 2.0)
bad  = mh_chains(lambda x: np.logaddexp(-0.5*((x+4)/0.5)**2, -0.5*((x-4)/0.5)**2), [-4,-4,4,4], 0.5)
fig, ax = plt.subplots(1, 2, figsize=(11, 4), sharey=False)
for c in good: ax[0].plot(c[:1500], lw=0.6)
ax[0].set_title("良い:4本が重なり帯状(R-hat≈1.00)"); ax[0].set_xlabel("ステップ")
for c in bad: ax[1].plot(c[:1500], lw=0.6)
ax[1].set_title("悪い:±4 の山に分離(R-hat≈9.3)"); ax[1].set_xlabel("ステップ")
plt.tight_layout(); plt.show()

グラフの意味:左は4本のチェーンが同じ帯に重なり合い、よく混合(収束)。右は2本が下の山(4-4)、2本が上の山(+4+4)に貼り付いたまま行き来できず、帯がくっきり分離します。これが R^=9.3\hat R=9.3 の正体で、トレースを見れば一目で分かります。

⚠️ よくある誤解

まとめ(Phase 4)

第4章では、共役が崩れた事後からサンプルを生成する道具を揃えました——なぜサンプリングか(なぜサンプリングか)、規格化なしで引く MH(メトロポリスヘイスティングス)、条件付きで引くギブス(ギブスサンプリング)、勾配で滑空する HMC/NUTS(ハミルトニアンモンテカルロとNUTS)、そして出力を信じてよいか確かめる収束診断(本ノート)。得たサンプルは第3章の要約(事後分布の要約)にそのまま渡せます。次章では、この計算力を前提に、グループ構造を持つデータをまとめて扱う階層ベイズモデルへ進みます。

関連ノート