Tae Hyun Kim (Lowell)

X-Learner

#causal-inference #hte #meta-learner

Definition

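X-Learner is a meta-learner built as a three-stage algorithm around imputed treatment effects. It is designed to exploit two things: imbalance between the treatment and control groups, and structural properties of the CATE (e.g., smoothness or linearity).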

Algorithm:

Stage 1: Estimate the response functions
$$\hat{\mu}_0(x) = \hat{E}[Y \mid X = x, W = 0], \qquad \hat{\mu}_1(x) = \hat{E}[Y \mid X = x, W = 1]$$

Stage 2: Compute the imputed treatment effects and estimate the CATE
$$\tilde{D}_{1i} := Y_{1i} - \hat{\mu}_0(X_{1i}) \quad \text{(treatment group)}$$
$$\tilde{D}_{0i} := \hat{\mu}_1(X_{0i}) - Y_{0i} \quad \text{(control group)}$$
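This is a quick sanity check on Stage 2, under assumptions the note does not state explicitly. Assume unconfoundedness. For this argument only, also assume perfectly estimated response functions $\hat{\mu}_w = \mu_w$, where $\mu_w(x) := E[Y \mid X = x, W = w]$, so that $\tau(x) = \mu_1(x) - \mu_0(x)$. Then both kinds of imputed effect have the CATE as their conditional mean:

```latex
\begin{aligned}
E[\tilde{D}_{1i} \mid X_{1i} = x] &= E[Y \mid X = x, W = 1] - \mu_0(x) = \mu_1(x) - \mu_0(x) = \tau(x), \\
E[\tilde{D}_{0i} \mid X_{0i} = x] &= \mu_1(x) - E[Y \mid X = x, W = 0] = \mu_1(x) - \mu_0(x) = \tau(x).
\end{aligned}
```

With estimated $\hat{\mu}_w$, each $\tilde{D}$ also carries the Stage 1 error of the other group's model. This is why the convergence rates of $\hat{\tau}_0$ and $\hat{\tau}_1$ mix both group sizes $m$ and $n$.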

Then estimate $\tau(x)$ separately in each group:

  • $\hat{\tau}_1(x)$: trained on the treatment group's imputed effects $\tilde{D}_{1i}$
  • $\hat{\tau}_0(x)$: trained on the control group's imputed effects $\tilde{D}_{0i}$
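Here is a minimal numerical sketch of Stage 2 in NumPy. The Stage 1 fits `mu0_hat` and `mu1_hat` are hypothetical hand-written functions standing in for trained regressors; their difference, the true effect, is 2 everywhere. The five data points are made up for illustration:

```python
import numpy as np

# Hypothetical Stage 1 fits; in practice these are regressors trained on
# the control (W=0) and treated (W=1) units respectively.
def mu0_hat(x):
    return 1.0 + 0.5 * x

def mu1_hat(x):
    return 3.0 + 0.5 * x

def impute_effects(X, W, Y, mu0, mu1):
    """Stage 2 of the X-learner: imputed treatment effects for each group."""
    X, W, Y = np.asarray(X, float), np.asarray(W), np.asarray(Y, float)
    treated, control = W == 1, W == 0
    # Treated units: observed Y(1) minus predicted Y(0)
    d1 = Y[treated] - mu0(X[treated])
    # Control units: predicted Y(1) minus observed Y(0)
    d0 = mu1(X[control]) - Y[control]
    return d1, d0

X = np.array([0.0, 2.0, 4.0, 1.0, 3.0])
W = np.array([1, 1, 1, 0, 0])
Y = np.array([3.2, 3.9, 5.1, 1.4, 2.6])
d1, d0 = impute_effects(X, W, Y, mu0_hat, mu1_hat)
print(d1)  # ≈ [2.2 1.9 2.1], regression targets for tau_1
print(d0)  # ≈ [2.1 1.9], regression targets for tau_0
```

Each imputed effect is a noisy version of the true effect 2. Stage 2 then regresses `d1` on the treated covariates and `d0` on the control covariates.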

Stage 3: Weighted combination
$$\hat{\tau}_X(x) = g(x)\hat{\tau}_0(x) + (1 - g(x))\hat{\tau}_1(x)$$

Here $g(x) \in [0, 1]$ is a weight function, typically the estimated propensity score $\hat{e}(x)$.

Intuitive Understanding

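Core idea:

Build “imputed” treatment effects from each unit's observed outcome and the other group's predicted outcome. Estimate the CATE from each group's perspective, then combine the two estimates.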

Stage 1: Estimate response functions (like T-learner)
           μ̂₀(x), μ̂₁(x)

Stage 2: Impute treatment effects
  Treatment group: D̃₁ᵢ = Y₁ᵢ - μ̂₀(X₁ᵢ)  (observed - predicted control)
  Control group:   D̃₀ᵢ = μ̂₁(X₀ᵢ) - Y₀ᵢ  (predicted treatment - observed)

         Train τ̂₁(x) on D̃₁, τ̂₀(x) on D̃₀

Stage 3: Weighted combination
         τ̂(x) = g(x)·τ̂₀(x) + (1-g(x))·τ̂₁(x)

Why “X”?

  • Information from the treatment group (through $\hat{\mu}_1$) feeds the CATE estimate on the control group, and vice versa
  • Information “crosses” between the groups, hence the name

Key Properties

Exploits CATE Structure

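  • When the CATE is simple (e.g., linear), the second-stage models can exploit that structure
  • Even if the response functions are complex, a simple CATE allows a fast rate, provided the first-stage model is trained on a large group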

Handles Imbalanced Groups

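  • The weight function $g(x)$ adjusts for imbalanced group sizes
  • The larger group's accurate Stage 1 model (e.g., $\hat{\mu}_0$ when controls dominate) is leveraged through the smaller group's imputed effects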

Convergence Rates (Conjecture 1)

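Under the conditions of the conjecture:

  • $\hat{\tau}_0$: $O(m^{-a_\tau} + n^{-a_\mu})$
  • $\hat{\tau}_1$: $O(m^{-a_\mu} + n^{-a_\tau})$

where:

  • $a_\mu$: rate exponent for the response functions (the base learner estimates $\mu_0, \mu_1$ at rate $O(N^{-a_\mu})$ from $N$ samples; smoother functions give a larger $a_\mu$)
  • $a_\tau$: the corresponding rate exponent for the CATE function
  • $m$: control group size, $n$: treatment group size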
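The sketch below shows what the conjecture implies by plugging illustrative values into the two bounds, with constants dropped. These are assumed values, not estimates. The scenario has a large control group and a CATE that is easier to learn than the response functions ($a_\tau > a_\mu$):

```python
def conjectured_rates(m, n, a_mu, a_tau):
    """Evaluate the Conjecture 1 bounds with all constants dropped.

    m: control group size, n: treatment group size.
    a_mu, a_tau: rate exponents for the response functions and the CATE.
    """
    # tau_0: second stage on the m controls, mu_1 learned from the n treated
    rate_tau0 = m ** (-a_tau) + n ** (-a_mu)
    # tau_1: second stage on the n treated, mu_0 learned from the m controls
    rate_tau1 = m ** (-a_mu) + n ** (-a_tau)
    return rate_tau0, rate_tau1

# Illustrative (hypothetical) setting: many controls, smoother CATE
r0, r1 = conjectured_rates(m=10_000, n=100, a_mu=0.25, a_tau=0.5)
print(f"tau_0 bound ~ {r0:.3f}, tau_1 bound ~ {r1:.3f}")
```

Here $\hat{\tau}_1$ has the smaller bound (about 0.200 vs. 0.326). This matches leaning on $\hat{\tau}_1$ when controls dominate.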

Parametric Rate (Theorem 2)

Conditions:

  • Linear CATE: $\tau(x) = x^T \beta$
  • Lipschitz response functions
  • $m \geq c_3 n^{1/a}$

Result: the X-learner with $g \equiv 0$ (i.e., $\hat{\tau}_X = \hat{\tau}_1$) achieves the parametric rate $O(n^{-1})$ in the treatment group size $n$.
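This is a small simulation in the spirit of Theorem 2, not a proof of it. The data-generating process is an illustrative assumption: a nonlinear $\mu_0$, a linear CATE $\tau(x) = 1 + 2x$, 2000 controls vs. 50 treated units, and noise level 0.1. NumPy polynomial fits stand in for the base learners. With $g \equiv 0$, only the $\hat{\tau}_1$ branch is needed:

```python
import numpy as np
from numpy.polynomial import Polynomial

rng = np.random.default_rng(0)

# Hypothetical data-generating process: nonlinear response, linear CATE
def mu0(x):
    return np.sin(6 * x) + x ** 2

def tau(x):
    return 1.0 + 2.0 * x

m, n = 2000, 50  # many controls, few treated units
x_ctrl = rng.uniform(0, 1, m)
x_treat = rng.uniform(0, 1, n)
y_ctrl = mu0(x_ctrl) + rng.normal(0, 0.1, m)
y_treat = mu0(x_treat) + tau(x_treat) + rng.normal(0, 0.1, n)

# Stage 1: flexible fit of mu_0 on the large control group (degree-7 polynomial)
mu0_hat = Polynomial.fit(x_ctrl, y_ctrl, deg=7)

# Stage 2: imputed effects on the small treated group, then a linear tau_1
d1 = y_treat - mu0_hat(x_treat)
tau1_hat = Polynomial.fit(x_treat, d1, deg=1).convert()  # coefficients in raw x
intercept, slope = tau1_hat.coef

# Stage 3 with g ≡ 0: the CATE estimate is tau_1_hat alone
print(f"intercept ≈ {intercept:.2f} (true 1.0), slope ≈ {slope:.2f} (true 2.0)")
```

The flexible Stage 1 model only has to be learned on the large control group. The 50 treated units are spent on a two-parameter linear fit, which is why the CATE can be recovered well despite the small $n$.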

Algorithm Detail

import numpy as np
from sklearn.base import clone

def x_learner(X, W, Y, base_learner, propensity_model=None):
    X, W, Y = np.asarray(X), np.asarray(W), np.asarray(Y)
    # Split data by treatment
    X_ctrl, Y_ctrl = X[W == 0], Y[W == 0]
    X_treat, Y_treat = X[W == 1], Y[W == 1]

    # Stage 1: Estimate response functions
    # clone() gives every stage its own unfitted copy: sklearn's fit() returns
    # self, so reusing one instance would overwrite the previously fitted model
    model_0 = clone(base_learner).fit(X_ctrl, Y_ctrl)
    model_1 = clone(base_learner).fit(X_treat, Y_treat)

    # Stage 2: Compute imputed treatment effects
    # For treatment group: observed - predicted control
    D_tilde_1 = Y_treat - model_0.predict(X_treat)
    # For control group: predicted treatment - observed
    D_tilde_0 = model_1.predict(X_ctrl) - Y_ctrl

    # Fit CATE models on imputed effects
    tau_model_1 = clone(base_learner).fit(X_treat, D_tilde_1)
    tau_model_0 = clone(base_learner).fit(X_ctrl, D_tilde_0)

    # Stage 3: Weight function g(x), i.e. the weight placed on tau_0
    if propensity_model is None:
        # Constant fallback: share of treated units, n / (m + n)
        frac_treated = len(X_treat) / len(X)
        g = lambda x: np.full(len(x), frac_treated)
    else:
        # g(x) = e_hat(x) = P(W = 1 | X = x)
        propensity_model = clone(propensity_model).fit(X, W)
        g = lambda x: propensity_model.predict_proba(x)[:, 1]

    # Weighted combination: g * tau_0 + (1 - g) * tau_1
    def predict_cate(X_new):
        X_new = np.asarray(X_new)
        tau_0 = tau_model_0.predict(X_new)
        tau_1 = tau_model_1.predict(X_new)
        weights = g(X_new)
        return weights * tau_0 + (1 - weights) * tau_1

    return predict_cate

Comparison with S/T-Learners

| Aspect | S-Learner | T-Learner | X-Learner |
| --- | --- | --- | --- |
| Models | 1 | 2 | 4 (2 + 2) |
| Data sharing | Full | None | Cross-group |
| Best when | CATE ≈ 0 | Different $\mu_0, \mu_1$ | Imbalanced groups, smooth CATE |
| Rate depends on | $a_\mu$ | $a_\mu$ | Can depend on $a_\tau$ |
| Complexity | Low | Medium | High |

When to Use

Good Scenarios

  • Imbalanced group sizes: the weight function can compensate
  • A CATE simpler than the response functions: $a_\tau > a_\mu$
  • Structural assumptions are available: linear CATE, smoothness, etc.

Bad Scenarios

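  • Balanced group sizes and a simple problem: a T-learner is sufficient
  • CATE close to 0: an S-learner is a better fit
  • Computational cost matters: four models must be trained (five with a propensity model)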

Weight Function Choice

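Choices for $g(x)$:

  1. Propensity score: $g(x) = \hat{e}(x)$

    • When controls dominate ($\hat{e}(x)$ small), the estimate relies more on $\hat{\tau}_1$, whose Stage 1 model $\hat{\mu}_0$ is fit on the large control group
  2. Constant: $g(x) = n / (m + n)$

    • Uses the global share of treated units
  3. Extreme choices: $g \equiv 0$ or $g \equiv 1$

    • Reasonable when one group is much larger than the other: $g \equiv 0$ when controls dominate (as in Theorem 2), $g \equiv 1$ when treated units dominate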
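Here is a sketch of Stage 3 under the three weight choices. The predictions `tau0` and `tau1`, the propensity scores `e_hat`, and the group sizes are hypothetical numbers chosen for illustration:

```python
import numpy as np

def combine(tau0, tau1, g):
    """Stage 3: tau_X(x) = g(x) * tau_0(x) + (1 - g(x)) * tau_1(x)."""
    return g * tau0 + (1 - g) * tau1

tau0 = np.array([1.0, 2.0, 3.0])   # hypothetical tau_0_hat(x) predictions
tau1 = np.array([2.0, 2.0, 1.0])   # hypothetical tau_1_hat(x) predictions
e_hat = np.array([0.1, 0.5, 0.9])  # hypothetical propensity scores e_hat(x)
m, n = 900, 100                    # control / treatment group sizes

g_options = {
    "propensity": e_hat,                  # g(x) = e_hat(x)
    "constant": np.full(3, n / (m + n)),  # g = n / (m + n) = 0.1
    "g = 0 (tau_1 only)": np.zeros(3),    # controls dominate
}
for name, g in g_options.items():
    print(name, combine(tau0, tau1, g))
```

With $m = 900$ controls and $n = 100$ treated units, the constant weight is $g = 0.1$. The combination therefore already sits close to $\hat{\tau}_1$.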

Counterfactual Outcome Estimation

Note: counterfactual outcomes are estimated with the Stage 1 models alone:

  • $\hat{Y}_i(0) = \hat{\mu}_0(X_i)$ (for treated units)
  • $\hat{Y}_i(1) = \hat{\mu}_1(X_i)$ (for control units)

$\hat{\tau}_0$, $\hat{\tau}_1$, and $g(x)$ are used only for CATE estimation.
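Below is a minimal sketch of counterfactual prediction that uses only Stage 1 models. `mu0_hat` and `mu1_hat` are hypothetical stand-ins for fitted regressors:

```python
import numpy as np

def counterfactual_outcomes(X, W, mu0_hat, mu1_hat):
    """Predict each unit's unobserved potential outcome from Stage 1 models only."""
    X, W = np.asarray(X, float), np.asarray(W)
    # Treated units are missing Y(0); control units are missing Y(1)
    return np.where(W == 1, mu0_hat(X), mu1_hat(X))

# Hypothetical Stage 1 fits (the same two models a T-learner would use)
def mu0_hat(x):
    return 1.0 + 0.5 * x

def mu1_hat(x):
    return 3.0 + 0.5 * x

y_cf = counterfactual_outcomes([0.0, 2.0, 1.0], [1, 1, 0], mu0_hat, mu1_hat)
print(y_cf)  # ≈ [1.0 2.0 3.5]
```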

Related Notes

  • Meta-learners - the overall framework
  • S-Learner - alternative: a single model
  • T-Learner - alternative: two separate models
  • R-Learner - alternative: residualized regression
  • DR-Learner - alternative: doubly robust pseudo-outcome
  • CATE - the estimation target
  • Propensity Score - used in the weight function

Implementation

Python (econml):

from econml.metalearners import XLearner
from sklearn.ensemble import RandomForestRegressor

x_learner = XLearner(models=RandomForestRegressor())
x_learner.fit(Y, T, X=X)
cate = x_learner.effect(X_test)

R (causalToolbox):

library(causalToolbox)
x_rf <- X_RF(feat = X, tr = W, yobs = Y)
cate <- EstimateCate(x_rf, X_test)

References

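  • Künzel, S. R., Sekhon, J. S., Bickel, P. J., & Yu, B. (2019). Metalearners for estimating heterogeneous treatment effects using machine learning. PNAS, 116(10), 4156-4165. Proposes the X-learner and provides its theoretical analysis.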
