Mímisbrunnr知恵の泉

← 因果推論 一覧

🎓 レベル:発展 | 重要度:B(標準)

📎 前提:MLをそのまま使うと因果を誤る理由 | 二重頑健推定AIPW | 数理:訓練・検証・テストと交差検証(機械学習)

要点(BLUF)


DML の 3 部品

MLをそのまま使うと因果を誤る理由で見た失敗(結果だけ残差化して生の処置に回帰)の修正は、たった一手「処置も残差化する」ことだった。それを正式化したのが DML(Chernozhukov et al., 2018)で、部品は 3 つ。

  1. 残差化(partialling-out):撹乱母数 (X)=E[YX]\ell(X)=E[Y\mid X]m(X)=E[TX]m(X)=E[T\mid X]任意の機械学習で予測し、両方を引いて残差 Y~,T~\tilde Y, \tilde T を作る。
  2. Neyman 直交モーメント:撹乱母数の誤差に一次で鈍感な推定式を使う。
  3. 交差適合(cross-fitting):撹乱母数の学習と残差の評価を別の標本で行い、過学習バイアス(自分自身を予測に使う相関)を断つ。
flowchart LR
    X["交絡 X"] -->|"ℓ(X)=E[Y∣X] を ML 予測"| RY["残差 Ỹ = Y − ℓ̂(X)"]
    X -->|"m(X)=E[T∣X] を ML 予測"| RT["残差 T̃ = T − m̂(X)"]
    RY --> TH["θ̂:Ỹ を T̃ に回帰"]
    RT --> TH

識別の仮定

DML は推定の技術であって、識別は別に要る。識別の仮定の条件付き交換可能性(XX で交絡が全て塞がる)・正値性・SUTVA が成り立つ前提で、部分線形モデル Y=θT+g(X)+ε,  T=m(X)+ηY=\theta T+g(X)+\varepsilon,\; T=m(X)+\etaθ\theta が因果効果として識別される。未観測交絡があれば DML でも因果は出ない。

Neyman 直交モーメントと推定量

部分線形モデルの直交スコアは

ψ(W;θ,η)=(Y(X)θ(Tm(X)))(Tm(X))\psi(W;\theta,\eta) = \Big(Y-\ell(X)-\theta\,\big(T-m(X)\big)\Big)\big(T-m(X)\big)

である。E[ψ]=0E[\psi]=0 を解くと、Y~=Y(X), T~=Tm(X)\tilde Y = Y-\ell(X),\ \tilde T = T-m(X) として

θ=E[Y~T~]E[T~2]\theta = \frac{E[\tilde Y\,\tilde T]}{E[\tilde T^{\,2}]}

これは両方の残差による回帰(Frisch–Waugh–Lovell)。撹乱母数で微分するとゼロ(直交)なので、^,m^\hat\ell,\hat m の誤差は二次でしか効かない。標本版は

θ^=iT~iY~iiT~i2,SE^=1n1niT~i2(Y~iθ^T~i)2(1niT~i2)2\hat\theta = \frac{\sum_i \tilde T_i\,\tilde Y_i}{\sum_i \tilde T_i^{\,2}}, \qquad \widehat{\mathrm{SE}} = \sqrt{\frac{1}{n}\cdot\frac{\frac1n\sum_i \tilde T_i^{2}\,(\tilde Y_i-\hat\theta\,\tilde T_i)^2}{\big(\frac1n\sum_i \tilde T_i^{2}\big)^2}}

分散は直交スコアの影響関数から来る。交差適合では、標本を KK 分割し、各 fold の残差をその fold を除いて学習した ^,m^\hat\ell,\hat m で作る。

コード:DML を手実装して真値を回収

MLをそのまま使うと因果を誤る理由と同じ高次元交絡データに、結果と処置の両方を交差適合の Lasso で残差化し、残差回帰で θ\theta と信頼区間を出す。

# === 結果Yと処置Tを各々XでML予測して残差化し、交差適合してθを推定 ===
import numpy as np
from sklearn.linear_model import LassoCV
from sklearn.model_selection import KFold

rng = np.random.default_rng(42)
ATE_true = 2.0
n, p = 500, 20
X = rng.standard_normal((n, p))
beta  = 1.0 / np.arange(1, p + 1)
gamma = 1.0 / np.arange(1, p + 1)
T = X @ gamma + rng.standard_normal(n)                 # 連続処置 T = m(X)+η
Y = ATE_true * T + X @ beta + rng.standard_normal(n)   # Y = θT + g(X)+ε

def crossfit_predict(X, target):
    # E[target|X] を交差適合(out-of-fold 予測)で当てる
    kf = KFold(n_splits=5, shuffle=True, random_state=0)
    pred = np.zeros(len(target))
    for idx_tr, idx_te in kf.split(X):
        model = LassoCV(cv=5, random_state=0).fit(X[idx_tr], target[idx_tr])
        pred[idx_te] = model.predict(X[idx_te])
    return pred

Y_resid = Y - crossfit_predict(X, Y)     # Ỹ = Y − Ê[Y|X]
T_resid = T - crossfit_predict(X, T)     # T̃ = T − Ê[T|X](処置も残差化=直交化)

theta = np.mean(T_resid * Y_resid) / np.mean(T_resid * T_resid)   # 残差回帰(FWL)
psi = T_resid * (Y_resid - theta * T_resid)                       # 直交スコア
se = np.sqrt(np.mean(psi ** 2) / np.mean(T_resid ** 2) ** 2 / n)
print("ATE_true     :", ATE_true)
print("DML 推定値   :", round(theta, 3))
print("標準誤差 SE  :", round(se, 3))
print("95%CI        : [", round(theta - 1.96 * se, 3), ",", round(theta + 1.96 * se, 3), "]")

出力は次のとおり。

ATE_true     : 2.0
DML 推定値   : 2.057
標準誤差 SE  : 0.044
95%CI        : [ 1.971 , 2.144 ]

MLをそのまま使うと因果を誤る理由の非直交プラグイン(0.89)と同じデータ・同じ Lasso を使っているのに、処置も残差化しただけで 2.057 と真値 2.0 に戻った。95%信頼区間 [1.971, 2.144] も真値を含む。漏れていた縮小バイアスが直交化で消えた。

コード:econml の LinearDML で再現(要最新確認)

同じ推定を econml(v0.16.0 で確認)の LinearDML でも再現する。econml は急速に動くライブラリなので API は要最新確認。未導入なら pip install econml手実装だけでも本ノートは完結するので、ライブラリは補助。

# === 同じデータに econml の LinearDML を当て、手実装と一致するか確認 ===
import numpy as np
from sklearn.linear_model import LassoCV
from econml.dml import LinearDML

rng = np.random.default_rng(42)
ATE_true = 2.0
n, p = 500, 20
X = rng.standard_normal((n, p))
gamma = 1.0 / np.arange(1, p + 1)
T = X @ gamma + rng.standard_normal(n)
Y = ATE_true * T + X @ (1.0 / np.arange(1, p + 1)) + rng.standard_normal(n)

est = LinearDML(model_y=LassoCV(cv=5), model_t=LassoCV(cv=5),
                discrete_treatment=False, cv=5, random_state=0)
est.fit(Y, T, X=None, W=X)               # X=効果修飾子(なし=定数効果), W=交絡
ate = float(est.ate())
lo, hi = est.ate_interval(alpha=0.05)
print("econml LinearDML ATE:", round(ate, 3))
print("95%CI               : [", round(float(lo), 3), ",", round(float(hi), 3), "]")

出力は次のとおり。

econml LinearDML ATE: 2.061
95%CI               : [ 1.974 , 2.148 ]

手実装(2.057)とほぼ一致する。econml では交絡を W、効果修飾子を X に渡す。X=None だと効果は定数(ATE)になる(異質効果は異質処置効果とメタ学習器(S/T/X-learner)X に修飾子を渡す)。

コード:信頼区間の被覆率が約 95% か

DML の売りは正しい推論。100 回反復して、95%CI が真値 2.0 を含む割合(被覆率)が名目の 95% 付近かを確かめる。

# === 100回反復し、95%CIが真値2.0を含む割合(被覆率)が約95%か確認 ===
import numpy as np
from sklearn.linear_model import LassoCV
from sklearn.model_selection import KFold

ATE_true = 2.0
n, p = 500, 20
gamma = 1.0 / np.arange(1, p + 1)
beta = 1.0 / np.arange(1, p + 1)

def crossfit_predict(X, target):
    kf = KFold(n_splits=5, shuffle=True, random_state=0)
    pred = np.zeros(len(target))
    for idx_tr, idx_te in kf.split(X):
        model = LassoCV(cv=3, random_state=0).fit(X[idx_tr], target[idx_tr])
        pred[idx_te] = model.predict(X[idx_te])
    return pred

covered = 0
for rep in range(100):
    rng = np.random.default_rng(2000 + rep)
    X = rng.standard_normal((n, p))
    T = X @ gamma + rng.standard_normal(n)
    Y = ATE_true * T + X @ beta + rng.standard_normal(n)
    Yr = Y - crossfit_predict(X, Y)
    Tr = T - crossfit_predict(X, T)
    th = np.mean(Tr * Yr) / np.mean(Tr * Tr)
    psi = Tr * (Yr - th * Tr)
    se = np.sqrt(np.mean(psi ** 2) / np.mean(Tr ** 2) ** 2 / n)
    if th - 1.96 * se <= ATE_true <= th + 1.96 * se:
        covered += 1
print("95%CI の被覆率 :", covered, "/ 100")

出力は次のとおり。

95%CI の被覆率 : 94 / 100

94/100 ≈ 95%。 点推定が当たるだけでなく信頼区間も妥当で、pp 値や CI を機械学習由来の撹乱母数の上で正当に使える。これが「非直交プラグイン」では達成できない DML の価値である。


直観:なぜ二手間(直交+交差適合)が効くか


⚠️ よくある誤解・落とし穴


関連ノート