Mímisbrunnr知恵の泉

← 因果推論 一覧

🎓 レベル:標準 | 重要度:A(必須)

📎 前提:二重頑健推定AIPW | 識別の仮定 | 数理:正則化(Ridge・Lasso・Elastic Net)(機械学習)

要点(BLUF)


なぜ「ML で交絡を調整」したいのか

交絡変数 XX が高次元(数十〜数千次元)だと、回帰による調整とその限界で見た線形回帰では関数形を取り違えやすい。そこで「E[YX]E[Y \mid X]E[TX]E[T \mid X] のような撹乱母数(nuisance function)を柔軟な機械学習で当て、その分を取り除けば交絡を消せるはず」という発想が自然に出てくる。

しかしここに落とし穴がある。機械学習の予測器をそのまま因果推定の式に差し込む(プラグインする)と、推定値が真の効果から系統的にずれる。 その理由を、構造を完全に把握できる擬似データで突き止める。

部分線形モデルと因果構造

処置 TT と結果 YY の関係を、次の**部分線形モデル(partially linear model)**で考える。

Y=θT+g(X)+ε,T=m(X)+ηY = \theta\, T + g(X) + \varepsilon, \qquad T = m(X) + \eta
flowchart LR
    X["高次元交絡 X(多数の共変量)"] --> T["処置 T(用量)"]
    X --> Y["結果 Y"]
    T -->|"効果 θ"| Y

XTX \to TXYX \to Y の両方が伸びているので、TTYY の素朴な比較にはバックドアパス TXYT \leftarrow X \to Y が混ざる。

識別の仮定(ここを満たして初めて θ は因果)

推定の前に、θ\theta が因果効果として識別できる条件を明示する(識別の仮定)。

この 3 つが成り立つ前提で初めて、θ\theta は構造方程式の係数=因果効果になる。以下で問題にするのは「識別できた θ\theta をどう推定するか」という推定の話であって、識別の失敗(未観測交絡)とは別の問題であることに注意する。


素朴な ML プラグインはなぜ外れるか

識別の仮定の下では、Frisch–Waugh–Lovell の関係から θ\theta両方を残差化した回帰係数として書ける。

θ=E[(YE[YX])(TE[TX])]E[(TE[TX])2]\theta = \frac{E\big[(Y - E[Y\mid X])(T - E[T\mid X])\big]}{E\big[(T - E[T\mid X])^2\big]}

ここで「結果だけ(X)=E[YX]\ell(X)=E[Y\mid X] で残差化し、生の TT に回帰する」という素朴なプラグインを考えると何が起きるか。(X)=θm(X)+g(X)\ell(X) = \theta\, m(X) + g(X) なので、^\hat\ell \to \ell と当てられたとして残差は

Y~=Y(X)=θ(Tm(X))+ε=θη+ε\tilde Y = Y - \ell(X) = \theta\big(T - m(X)\big) + \varepsilon = \theta\,\eta + \varepsilon

これを生の T=m(X)+ηT = m(X)+\eta に回帰すると、

θ^naive=E[Y~T]E[T2]=θVar(η)Var(m(X))+Var(η)  <  θ\hat\theta_{\text{naive}} = \frac{E[\tilde Y\, T]}{E[T^2]} = \theta\,\frac{\mathrm{Var}(\eta)}{\mathrm{Var}(m(X)) + \mathrm{Var}(\eta)} \;<\; \theta

処置の予測可能な分散 Var(m(X))\mathrm{Var}(m(X)) の分だけ、効果が縮小(減衰)する。 処置を残差化し損ねた m(X)m(X) が、そのまま θ^\hat\theta のバイアスとして「漏れた」のである。次のコードで、真値 2.0 が理論上 2×1Var(m)+10.772 \times \frac{1}{\mathrm{Var}(m)+1} \approx 0.77 まで潰れることを確かめる。

コード:高次元交絡で素朴な調整が外れる

下のコードは、20 次元の交絡を仕込んだ部分線形モデルから、(1) 無調整の回帰と (2) 「Lasso で結果だけ調整して生の処置に回帰」する非直交プラグインを計算する。ATE_true = 2.0 を当てられるかを見る。

# === 高次元交絡を仕込み、無調整と「非直交プラグイン」が外すことを確かめる ===
import numpy as np
from sklearn.linear_model import LassoCV, LinearRegression
from sklearn.model_selection import KFold

rng = np.random.default_rng(42)
ATE_true = 2.0
n, p = 500, 20
X = rng.standard_normal((n, p))
beta  = 1.0 / np.arange(1, p + 1)        # g(X) の係数(X→結果Y の交絡経路)
gamma = 1.0 / np.arange(1, p + 1)        # m(X) の係数(X→処置T の交絡経路)
m = X @ gamma
T = m + rng.standard_normal(n)                         # 連続処置(用量): T = m(X) + η
Y = ATE_true * T + X @ beta + rng.standard_normal(n)   # Y = θ·T + g(X) + ε

# (1) 無調整:Y を T だけに回帰(交絡を放置)
b_unadj = LinearRegression().fit(T.reshape(-1, 1), Y).coef_[0]

# (2) 非直交プラグイン:Lasso で E[Y|X] を当て、Y から引いてから「生の T」に回帰
kf = KFold(n_splits=5, shuffle=True, random_state=0)
y_hat = np.zeros(n)
for idx_tr, idx_te in kf.split(X):
    g_model = LassoCV(cv=5, random_state=0).fit(X[idx_tr], Y[idx_tr])
    y_hat[idx_te] = g_model.predict(X[idx_te])
Y_resid = Y - y_hat
b_plugin = np.mean(T * Y_resid) / np.mean(T * T)

var_m = gamma @ gamma                      # Var(m(X)) = Σ γ_j²(X は標準正規・独立)
shrink = 1.0 / (var_m + 1.0)               # 理論上の減衰率 Var(η)/Var(T)
print("ATE_true              :", ATE_true)
print("(1) 無調整 Y~T        :", round(b_unadj, 3))
print("(2) 非直交プラグイン  :", round(b_plugin, 3))
print("理論減衰率 Var(η)/Var(T):", round(shrink, 3),
      " → 予測値 θ×減衰 =", round(ATE_true * shrink, 3))

出力は次のとおり。

ATE_true              : 2.0
(1) 無調整 Y~T        : 2.636
(2) 非直交プラグイン  : 0.894
理論減衰率 Var(η)/Var(T): 0.385  → 予測値 θ×減衰 = 0.77

無調整は 2.64 と過大(交絡 XX が処置と結果を同方向に押すため上振れ)。一方、Lasso でちゃんと結果を調整したはずの非直交プラグインは 0.89 と、真値 2.0 を大きく下回る。理論減衰 0.77 とほぼ一致しており、ずれが偶然ではなく構造的なバイアスであることが分かる。「ML で予測してから引く」だけでは因果は当たらない。

コード:100 回反復で分布の中心を見る

1 回の値が偶然でないことを示すため、データ生成を 100 回繰り返し、各推定量の分布の中心が真値 2.0 に来るかを見る。あわせて、次ノートで扱う DML(処置も残差化して交差適合) を先取りで重ねる。

# === 100回反復して、各推定量の分布(中心が真値2.0か)を見る ===
import numpy as np
import matplotlib.pyplot as plt
import japanize_matplotlib
from sklearn.linear_model import LassoCV, LinearRegression
from sklearn.model_selection import KFold

ATE_true = 2.0
n, p = 400, 20
beta = 1.0 / np.arange(1, p + 1)
gamma = 1.0 / np.arange(1, p + 1)

def generate(rng):
    X = rng.standard_normal((n, p))
    T = X @ gamma + rng.standard_normal(n)
    Y = ATE_true * T + X @ beta + rng.standard_normal(n)
    return X, T, Y

def crossfit_resid(X, Z):                  # E[Z|X] を交差適合で当てて残差を返す
    kf = KFold(n_splits=5, shuffle=True, random_state=0)
    pred = np.zeros(len(Z))
    for idx_tr, idx_te in kf.split(X):
        model = LassoCV(cv=3, random_state=0).fit(X[idx_tr], Z[idx_tr])
        pred[idx_te] = model.predict(X[idx_te])
    return Z - pred

results = {"無調整": [], "非直交プラグイン": [], "DML(直交+交差適合)": []}
for rep in range(100):
    rng = np.random.default_rng(1000 + rep)
    X, T, Y = generate(rng)
    results["無調整"].append(LinearRegression().fit(T.reshape(-1, 1), Y).coef_[0])
    Y_res = crossfit_resid(X, Y)
    results["非直交プラグイン"].append(np.mean(T * Y_res) / np.mean(T * T))
    T_res = crossfit_resid(X, T)           # 処置も残差化するのが直交化(次ノート)
    results["DML(直交+交差適合)"].append(np.mean(T_res * Y_res) / np.mean(T_res * T_res))

for name, vals in results.items():
    vals = np.array(vals)
    print(f"{name:18s} 平均={vals.mean():.3f}  バイアス={vals.mean()-ATE_true:+.3f}")

plt.figure(figsize=(8, 4))
for name, vals in results.items():
    plt.hist(vals, bins=20, alpha=0.6, label=name)
plt.axvline(ATE_true, color="black", linestyle="--", label="真値 ATE_true=2.0")
plt.xlabel("処置効果の推定値"); plt.ylabel("頻度")
plt.title("ML調整の素朴なプラグインは中心がずれる(100反復)")
plt.legend(); plt.tight_layout(); plt.show()

出力は次のとおり。

無調整                平均=2.621  バイアス=+0.621
非直交プラグイン           平均=0.812  バイアス=-1.188
DML(直交+交差適合)       平均=2.000  バイアス=+0.000

ヒストグラムでは、無調整は右に、非直交プラグインは左に外れ、どちらも真値 2.0 に中心が乗らない。一方で DML だけが 2.000(バイアス +0.000)で真値に重なる。「処置も残差化する」というたった一手の違いが、系統的バイアスを消すことを次ノートで掘り下げる。


直観:なぜ「直交性」が決定的か

撹乱母数 η=(,m)\eta = (\ell, m) を含む推定式(モーメント条件)ψ(W;θ,η)\psi(W;\theta,\eta)Neyman 直交であるとは、真値 (θ0,η0)(\theta_0,\eta_0) における η\eta 方向の微分(ガトー微分)がゼロであること、すなわち撹乱母数の推定誤差に対して一次で鈍感であることを言う。

つまり「結果も処置も残差化する」直交モーメントにすれば、二つの遅い誤差のになって無視できる。これが次ノートの DML の核心である。


⚠️ よくある誤解・落とし穴


関連ノート