Mímisbrunnr知恵の泉

← ベイズ統計 一覧

🎓 レベル:発展 | 重要度:B(標準)

📎 前提:情報量規準WAICとDIC | 関連:訓練・検証・テストと交差検証(機械学習)

要点(BLUF)

1. 交差検証で予測性能を直接測る

情報量規準(情報量規準WAICとDIC)は予測性能を「当てはまり − 複雑さの罰」で近似しました。交差検証は、実際に一部を抜いて残りで学習し、抜いた分を予測して、予測性能を直接測ります。ベイズでは予測の良さを対数予測密度で測ります。

elpdLOO=ilogp(yiyi)=ilogp(yiθ)p(θyi)dθ\mathrm{elpd}_{\mathrm{LOO}}=\sum_{i}\log p(y_i\mid y_{-i})=\sum_i\log\int p(y_i\mid\theta)\,p(\theta\mid y_{-i})\,d\theta

各点 ii を、それ以外 yiy_{-i} で学習した事後で予測した対数密度の和。大きいほど予測が良い。2elpdLOO-2\,\mathrm{elpd}_{\mathrm{LOO}}(LOO-IC)が WAIC と同じスケールで比較できます。

2. 再フィットの重さと、重点サンプリングの回避策

LOO を定義どおりやると、各点を抜いた事後 p(θyi)p(\theta\mid y_{-i})NN 個求める=NN 回の再フィットで、MCMC では現実的でありません。そこで重点サンプリング(importance sampling):全データの事後 p(θy)p(\theta\mid y) から取ったサンプルを、yiy_i を抜いた事後へ重み付けで流用します。重みは

wis=1p(yiθs)p(yiyi)swisp(yiθs)swis=11Ss1/p(yiθs)w_{is}=\frac{1}{p(y_i\mid\theta_s)}\quad\Longrightarrow\quad p(y_i\mid y_{-i})\approx\frac{\sum_s w_{is}\,p(y_i\mid\theta_s)}{\sum_s w_{is}}=\frac{1}{\frac1S\sum_s 1/p(y_i\mid\theta_s)}

yiy_i をうまく当てるサンプルほど、yiy_i を抜いた事後では割り引く」という補正。全データの事後を1回計算するだけで全点の LOO が出ます。

3. コード:厳密 LOO = IS-LOO = WAIC

ベイズ線形回帰(2次)で、(a) 厳密 LOO(毎回再フィット)、(b) IS-LOO、(c) WAIC が一致することを確かめます。

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
n = 30; sigma = 0.5; alpha = 0.5
x = np.sort(rng.uniform(-1, 1, n))
y = 1.0 + 0.5*x - 1.5*x**2 + rng.normal(0, sigma, n)
Phi = np.vander(x, 3, increasing=True); beta = 1/sigma**2

def posterior(Ph, yy):
    S_N = np.linalg.inv(alpha*np.eye(Ph.shape[1]) + beta*Ph.T@Ph)
    return beta*S_N@Ph.T@yy, S_N

# (a) 厳密LOO:1点ずつ抜いて再フィット、抜いた点の対数予測密度
elpd_exact = 0.0
for i in range(n):
    idx = np.arange(n) != i
    m, S = posterior(Phi[idx], y[idx])
    pm, pv = Phi[i]@m, sigma**2 + Phi[i]@S@Phi[i]         # LOO予測 N(pm, pv)
    elpd_exact += stats.norm(pm, np.sqrt(pv)).logpdf(y[i])

# (b) IS-LOO:全データ事後から重み 1/p(y_i|θ_s) で補正(再フィットなし)
m_N, S_N = posterior(Phi, y)
W = rng.multivariate_normal(m_N, S_N, 8000)
p_is = stats.norm(W@Phi.T, sigma).pdf(y[None,:])          # p(y_i|θ_s)  (S,n)
elpd_isloo = np.sum(-np.log(np.mean(1/p_is, axis=0)))

# (c) WAIC
ll = np.log(p_is)
elpd_waic = np.sum(np.log(np.mean(p_is, axis=0))) - np.sum(np.var(ll, axis=0, ddof=1))

print(f"elpd(対数予測密度の和・大きいほど良い):")
print(f"  厳密LOO(再フィット)   = {elpd_exact:.3f}")
print(f"  IS-LOO(重点サンプリング)= {elpd_isloo:.3f}")
print(f"  WAIC近似               = {elpd_waic:.3f}")
print(f"  LOO-IC = {-2*elpd_isloo:.2f}   (WAIC = {-2*elpd_waic:.2f})")

出力:

elpd(対数予測密度の和・大きいほど良い):
  厳密LOO(再フィット)   = -19.636
  IS-LOO(重点サンプリング)= -19.618
  WAIC近似               = -19.599
  LOO-IC = 39.24   (WAIC = 39.20)

出力の意味:3つの elpd が 19.6-19.6 付近でぴたりと一致。IS-LOO は再フィットなし(全データ事後1回+重み付け)で、3030 回再フィットした厳密 LOO を再現しました。WAIC も漸近的に LOO に一致するため、ほぼ同じ値(LOO-IC 39.2439.24 と WAIC 39.2039.20)。だから実務では、計算の軽い IS-LOO や WAIC で予測性能を比較できます。

4. PSIS-LOO と k 診断

素朴な IS-LOO の重み 1/p(yiθs)1/p(y_i\mid\theta_s) は、外れ値や影響の強い点で裾が重く不安定になります。これを PSIS(Pareto smoothed importance sampling) が安定化します——大きい重みに一般化パレート分布を当てて平滑化し、同時にその形状パラメータ k^\hat k を診断に使います。

PSIS-LOO は WAIC より頑健で、外れ値・影響点を k^\hat k自動的に警告してくれる利点があります。実務では ArviZ の az.loo(PSIS-LOO)が定番です(API・既定は更新が速く要最新確認確率的プログラミング概観)。

⚠️ よくある誤解

関連ノート