Tae Hyun Kim (Lowell)

T-Learner

3분 읽기 #causal-inference#hte#meta-learner

Definition

T-Learner (Two Learner)는 treatment group과 control group에 대해 별도의 모델을 학습하여 CATE를 추정하는 Meta-learners.

Algorithm:

  1. Control group에서 μ0(x)\mu_0(x) 추정: μ^0(x)=E^[YX=x,W=0]\hat{\mu}_0(x) = \hat{E}[Y | X = x, W = 0]

  2. Treatment group에서 μ1(x)\mu_1(x) 추정: μ^1(x)=E^[YX=x,W=1]\hat{\mu}_1(x) = \hat{E}[Y | X = x, W = 1]

  3. CATE 추정: τ^T(x)=μ^1(x)μ^0(x)\hat{\tau}_T(x) = \hat{\mu}_1(x) - \hat{\mu}_0(x)

Intuitive Understanding

핵심 아이디어:

두 그룹을 완전히 분리하여 각각의 response function을 독립적으로 학습

Control data:  (X₀, Y₀) → μ̂₀(x)
Treatment data: (X₁, Y₁) → μ̂₁(x)

CATE:          τ̂(x) = μ̂₁(x) - μ̂₀(x)

장점:

  • 각 그룹의 고유한 response 구조 포착
  • μ0\mu_0μ1\mu_1이 매우 다를 때 적합
  • 개념적으로 명확

단점:

  • 데이터 공유 없음 (각 모델이 절반 데이터만 사용)
  • CATE가 단순해도 response function의 복잡성에 rate 의존
  • 그룹 크기가 불균형하면 비효율적

Key Properties

No Data Sharing

  • 각 모델이 해당 그룹의 데이터만 사용
  • Control: mm개, Treatment: nn
  • 공통 패턴 학습 불가

Rate Depends on Response Functions

Rate=O(maμ+naμ)\text{Rate} = O(m^{-a_\mu} + n^{-a_\mu})

  • aμa_\mu: response function의 smoothness
  • CATE가 단순해도 (aτ>aμa_\tau > a_\mu), rate는 aμa_\mu에 의존

Minimax Optimal (Theorem 7)

특정 조건 하에서 T-learner는 minimax rate optimal

Algorithm Detail

def t_learner(X, W, Y, base_learner):
    # Split data by treatment
    X_ctrl, Y_ctrl = X[W == 0], Y[W == 0]
    X_treat, Y_treat = X[W == 1], Y[W == 1]

    # Step 1: Fit control model
    model_0 = base_learner.fit(X_ctrl, Y_ctrl)

    # Step 2: Fit treatment model
    model_1 = base_learner.fit(X_treat, Y_treat)

    # Step 3: Predict CATE
    def predict_cate(X_new):
        return model_1.predict(X_new) - model_0.predict(X_new)

    return predict_cate

When to Use

Good Scenarios

  • Response function이 매우 다를 때: μ0(x)\mu_0(x)μ1(x)\mu_1(x)의 구조가 다름
  • 그룹 크기가 균형적일 때: 각 모델이 충분한 데이터
  • Treatment effect가 복잡할 때: 각 그룹의 복잡성을 별도로 모델링

Bad Scenarios

  • CATE가 단순하지만 response가 복잡할 때: Rate가 불필요하게 느림
  • 그룹 크기가 불균형할 때: 작은 그룹의 추정이 부정확
  • 공통 패턴이 많을 때: Data sharing의 이점을 잃음

Comparison with S-Learner

AspectT-LearnerS-Learner
Models2 (separate)1 (combined)
Data per modelmm or nnm+nm + n
StructureCaptures different responsesAssumes similar responses
RiskNo data sharingMay ignore treatment effect

Example

시뮬레이션 설정:

  • μ0(x)=sin(x)\mu_0(x) = \sin(x) (복잡)
  • μ1(x)=cos(x)\mu_1(x) = \cos(x) (복잡, 다른 패턴)
  • τ(x)=cos(x)sin(x)\tau(x) = \cos(x) - \sin(x)

T-Learner:

  • 각 response function을 잘 포착 ✓
  • CATE 추정 양호

S-Learner:

  • 두 패턴의 평균을 학습
  • 각 그룹의 고유 패턴 놓침

Variance Analysis

CATE estimator의 variance: Var(τ^T(x))=Var(μ^1(x))+Var(μ^0(x))\text{Var}(\hat{\tau}_T(x)) = \text{Var}(\hat{\mu}_1(x)) + \text{Var}(\hat{\mu}_0(x))

각 모델의 variance가 독립적으로 기여 → 데이터 분할로 인해 각 variance 증가

Implementation

Python (econml):

from econml.metalearners import TLearner
from sklearn.ensemble import RandomForestRegressor

t_learner = TLearner(models=RandomForestRegressor())
t_learner.fit(Y, T, X=X)
cate = t_learner.effect(X_test)

R:

library(causalToolbox)
t_rf <- T_RF(feat = X, tr = W, yobs = Y)
cate <- EstimateCate(t_rf, X_test)

References

  • kunzelMetalearnersEstimatingHeterogeneous2019 - T-learner 분석 및 minimax optimality

연결 그래프