Tae Hyun Kim (Lowell)

S-Learner

3분 읽기 #causal-inference#hte#meta-learner

Definition

S-Learner (Single Learner)는 treatment indicator를 feature로 포함하는 단일 모델로 response function을 추정한 후 CATE를 계산하는 Meta-learners.

Algorithm:

  1. 단일 모델로 combined response function 추정: μ^(x,w)=E^[YX=x,W=w]\hat{\mu}(x, w) = \hat{E}[Y | X = x, W = w]

  2. CATE 추정: τ^S(x)=μ^(x,1)μ^(x,0)\hat{\tau}_S(x) = \hat{\mu}(x, 1) - \hat{\mu}(x, 0)

Intuitive Understanding

핵심 아이디어:

Treatment WW를 단순히 또 하나의 feature로 취급하고, 하나의 모델로 전체 데이터를 학습

Data:  (X, W, Y) for all observations

Model: μ̂(x, w) = f(x, w)  (single model)

CATE:  τ̂(x) = μ̂(x, 1) - μ̂(x, 0)

장점:

  • 가장 간단한 접근
  • 모든 데이터를 함께 사용 (data sharing)
  • Treatment/control 간 공통 패턴 활용

단점:

  • Treatment effect가 작을 때 무시할 수 있음 (regularization이 W를 drop)
  • μ0\mu_0μ1\mu_1의 구조가 매우 다를 때 부적합

Key Properties

Data Sharing

  • 전체 데이터 (n+m)(n + m)개를 사용하여 하나의 모델 학습
  • Control/treatment 공통 패턴 학습 가능

Regularization Bias

μ^(x,w)μ^(x)if treatment effect is small\hat{\mu}(x, w) \approx \hat{\mu}(x) \quad \text{if treatment effect is small}

  • Regularization이 강할수록 WW의 영향을 무시하는 경향
  • CATE ≈ 0일 때 적합, 아닐 때 bias 발생

Convergence Rate

Response function의 smoothness aμa_\mu에 의존: Rate=O((n+m)aμ)\text{Rate} = O((n+m)^{-a_\mu})

Algorithm Detail

def s_learner(X, W, Y, base_learner):
    # Step 1: Combine treatment as feature
    X_combined = np.column_stack([X, W])

    # Step 2: Fit single model
    model = base_learner.fit(X_combined, Y)

    # Step 3: Predict CATE
    def predict_cate(X_new):
        X_treat = np.column_stack([X_new, np.ones(len(X_new))])
        X_ctrl = np.column_stack([X_new, np.zeros(len(X_new))])
        return model.predict(X_treat) - model.predict(X_ctrl)

    return predict_cate

When to Use

Good Scenarios

  • CATE가 대부분 0에 가까울 때: Regularization이 올바르게 작동
  • Response function이 유사할 때: μ0(x)μ1(x)+c\mu_0(x) \approx \mu_1(x) + c
  • 데이터가 제한적일 때: Data sharing의 이점

Bad Scenarios

  • Treatment effect가 명확할 때: Effect를 무시할 수 있음
  • Response function이 매우 다를 때: 구조적 차이 포착 어려움
  • Heterogeneous effect가 중요할 때: 미묘한 차이 놓침

Comparison with T-Learner

AspectS-LearnerT-Learner
Models12
Data usageAll togetherSplit by treatment
SharingYesNo
Best whenCATE ≈ 0Different response structures
RiskIgnore small effectsNo data sharing

Example

시뮬레이션 설정:

  • μ0(x)=x\mu_0(x) = x
  • μ1(x)=x\mu_1(x) = x (즉, τ(x)=0\tau(x) = 0)

S-Learner 결과:

  • Regularized model이 WW를 무시 → τ^(x)0\hat{\tau}(x) \approx 0
  • 올바른 추정

반대 시나리오:

  • μ0(x)=x\mu_0(x) = x, μ1(x)=x+2\mu_1(x) = x + 2 (즉, τ(x)=2\tau(x) = 2)
  • S-Learner가 regularization으로 WW의 영향 축소 가능 → bias

Implementation

Python (econml):

from econml.metalearners import SLearner
from sklearn.ensemble import RandomForestRegressor

s_learner = SLearner(overall_model=RandomForestRegressor())
s_learner.fit(Y, T, X=X)
cate = s_learner.effect(X_test)

R:

library(causalToolbox)
s_rf <- S_RF(feat = X, tr = W, yobs = Y)
cate <- EstimateCate(s_rf, X_test)

References

  • kunzelMetalearnersEstimatingHeterogeneous2019 - S-learner 분석

연결 그래프