Tae Hyun Kim (Lowell)

R-Learner

Definition

R-Learner (Residualized Learner) is a meta-learner that estimates the CATE from the residualized outcome and the residualized treatment, building on the Robinson transformation.

Algorithm:

Step 1: Estimate the nuisance functions (with cross-fitting)

$$\hat{m}(x) = \hat{E}[Y \mid X=x], \quad \hat{e}(x) = \hat{P}(W=1 \mid X=x)$$

Step 2: Minimize the R-loss

$$\hat{\tau} = \arg\min_\tau \left\{ \hat{L}_n(\tau) + \Lambda_n(\tau) \right\}$$

where:

$$\hat{L}_n(\tau) = \frac{1}{n} \sum_{i=1}^n \left[ \{Y_i - \hat{m}^{(-q(i))}(X_i)\} - \{W_i - \hat{e}^{(-q(i))}(X_i)\} \tau(X_i) \right]^2$$

  • $\hat{m}^{(-q(i))}$, $\hat{e}^{(-q(i))}$: nuisance estimates fitted without fold $q(i)$, which is the fold containing the $i$-th observation
  • $\Lambda_n(\tau)$: regularization term on $\tau$
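The empirical R-loss can be made concrete with a minimal NumPy sketch of $\hat{L}_n(\tau)$ that leaves out the penalty $\Lambda_n$. The sketch assumes that the cross-fitted predictions $\hat{m}^{(-q(i))}(X_i)$ and $\hat{e}^{(-q(i))}(X_i)$ are already available as arrays, and that $\tau$ is supplied as its values at each $X_i$. The function name `r_loss` is hypothetical.

```python
import numpy as np

def r_loss(Y, W, m_hat, e_hat, tau_values):
    """Empirical R-loss L_n(tau), without the regularization term.

    Y, W         : outcomes and binary treatments, shape (n,)
    m_hat, e_hat : cross-fitted nuisance predictions at each X_i, shape (n,)
    tau_values   : candidate CATE evaluated at each X_i, shape (n,)
    """
    Y_tilde = Y - m_hat      # outcome residual
    W_tilde = W - e_hat      # treatment residual
    return np.mean((Y_tilde - W_tilde * tau_values) ** 2)

# Toy data (made up): residuals satisfy Y_tilde = 2 * W_tilde exactly
W_demo = np.array([1.0, 0.0, 1.0, 0.0])
e_demo = np.full(4, 0.5)
m_demo = np.zeros(4)
Y_demo = 2.0 * (W_demo - e_demo)

loss_at_2 = r_loss(Y_demo, W_demo, m_demo, e_demo, np.full(4, 2.0))  # perfect fit
loss_at_0 = r_loss(Y_demo, W_demo, m_demo, e_demo, np.zeros(4))      # ignores the effect
print(loss_at_2, loss_at_0)
```

The candidate $\tau \equiv 2$ drives the loss to zero, and $\tau \equiv 0$ does not. In the full method, this comparison is carried out over a function class with a penalty added.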

Intuitive Understanding

Key idea:

Remove the part of both the outcome and the treatment that the covariates explain, then learn only the treatment effect from what remains.

Step 1: Estimate nuisance functions
        m̂(x) = E[Y|X]    (outcome model)
        ê(x) = P(W=1|X)  (propensity model)

Step 2: Compute residuals (via cross-fitting)
        Ỹᵢ = Yᵢ - m̂(Xᵢ)        (outcome residual)
        W̃ᵢ = Wᵢ - ê(Xᵢ)        (treatment residual)

Step 3: Minimize R-loss
        τ̂ = argmin Σ[Ỹᵢ - W̃ᵢ·τ(Xᵢ)]² + regularization
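If $\tau(x)$ is restricted to a constant and no regularization is used, setting the derivative of $\sum_i (\tilde{Y}_i - \tilde{W}_i \tau)^2$ to zero gives $\hat{\tau} = \sum_i \tilde{W}_i \tilde{Y}_i / \sum_i \tilde{W}_i^2$. The simulation below is a sketch under assumed settings:

  • a made-up confounded data-generating process with a true effect of 1.5
  • the true $m$ and $e$ plugged in, instead of cross-fitted estimates, to isolate Steps 2 and 3

```python
import numpy as np

rng = np.random.default_rng(0)
n, tau_true = 20_000, 1.5

X = rng.normal(size=n)
e = 1.0 / (1.0 + np.exp(-X))              # true propensity P(W=1|X)
W = rng.binomial(1, e)
baseline = 2.0 * X + X ** 2               # X drives both W and Y: confounding
Y = baseline + tau_true * W + rng.normal(size=n)

m = baseline + tau_true * e               # true E[Y|X] = baseline + tau * e(X)

# Step 2: residuals (oracle nuisances, for illustration only)
Y_tilde = Y - m
W_tilde = W - e

# Step 3: closed-form R-loss minimizer for a constant tau
tau_hat = np.sum(W_tilde * Y_tilde) / np.sum(W_tilde ** 2)

# For contrast: the naive difference in means ignores confounding
naive = Y[W == 1].mean() - Y[W == 0].mean()
print(f"R-learner (constant tau): {tau_hat:.3f}")
print(f"naive difference in means: {naive:.3f}")
```

Treated units have larger $X$ on average, so the naive contrast absorbs part of the baseline. Residualizing both $Y$ and $W$ removes that part.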

Why "R"?

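  • It works with residuals (residualization)
  • It is named in recognition of Robinson (1988), whose transformation it builds on; the "R" is not the initial of the R-learner paper's authors (Nie and Wager)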

Key Properties

Quasi-Oracle Property

Key theorem (Theorem 1): if the nuisance components are estimated at rate $o(n^{-1/4})$, the R-learner achieves the same convergence rate as an oracle that knows the true nuisance functions.

$$\|\hat{\tau} - \tau^*\|^2 = O_P\left(\frac{\log n}{n}\right) + o_P(1) \cdot \text{nuisance error}$$

Implications:

  • Errors in nuisance estimation have no first-order effect on $\hat{\tau}$
  • Slowly converging nuisance estimators are acceptable, as long as they reach the $o(n^{-1/4})$ rate

Orthogonality

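The moment condition behind the Robinson transformation, at the true nuisances $m^*$, $e^*$ and the true effect $\tau^*$:

$$E\left[(W - e^*(X))\{Y - m^*(X) - (W - e^*(X))\,\tau^*(X)\} \mid X\right] = 0$$

The product of the two raw residuals is not mean-zero. With binary $W$ under unconfoundedness, $E[(Y - m^*(X))(W - e^*(X)) \mid X] = \tau^*(X)\,e^*(X)(1 - e^*(X))$, and this is precisely the signal the R-loss extracts.

The R-loss is also Neyman-orthogonal: its derivative with respect to $(m, e)$ vanishes at the true nuisances. Small nuisance errors therefore affect $\hat{\tau}$ only through second-order (product) terms, and this is the source of the R-learner's robustness to nuisance estimation error.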
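The orthogonality claims can be checked exactly in a toy model. Every value in the sketch below is a hypothetical choice:

  • a covariate taking 5 equally likely values, with known $e^*$, $\tau^*$ and $E[Y \mid X, W=0]$
  • binary $W$

Each expression is affine in $Y$ given $(X, W)$, so summing over $W \in \{0, 1\}$ with $Y$ replaced by $E[Y \mid X, W]$ gives exact conditional expectations with no sampling noise. The sketch checks three things:

  • the moment condition is zero at the truth
  • the raw residual product equals $\tau^* e^*(1-e^*)$
  • perturbing both nuisances by $\delta$ moves the pointwise population R-loss minimizer by an amount of order $\delta^2$

```python
import numpy as np

# Hypothetical toy model: X takes 5 equally likely values, W is binary
x = np.linspace(-1.0, 1.0, 5)
e_star = 0.3 + 0.2 * (x + 1.0)         # true propensity, between 0.3 and 0.7
tau_star = 2.0 + x                     # true CATE (nonzero everywhere)
mu0 = x ** 2                           # E[Y | X, W=0]

def mean_y(w):
    """E[Y | X, W=w] at every grid point."""
    return mu0 + w * tau_star

m_star = (1 - e_star) * mean_y(0.0) + e_star * mean_y(1.0)   # E[Y | X]

def cond_expect(g):
    """Exact E[g(W, Y) | X] for g affine in y: average over W in {0, 1}."""
    return (1 - e_star) * g(0.0, mean_y(0.0)) + e_star * g(1.0, mean_y(1.0))

# (i) The moment condition holds at the true nuisances
moment = cond_expect(lambda w, y: (w - e_star) * (y - m_star - (w - e_star) * tau_star))
# (ii) The raw residual product does not vanish: it equals tau * e * (1 - e)
raw_product = cond_expect(lambda w, y: (y - m_star) * (w - e_star))

def tau_tilde(delta):
    """Pointwise population R-loss minimizer when both nuisances are perturbed."""
    m_hat = m_star + delta * np.cos(3 * x)          # hypothetical nuisance errors
    e_hat = e_star + delta * np.sin(2 * x + 0.5)
    num = cond_expect(lambda w, y: (w - e_hat) * (y - m_hat))
    den = cond_expect(lambda w, y: (w - e_hat) ** 2)
    return num / den

# (iii) Doubling delta should roughly quadruple the error (second-order effect)
err_1 = np.max(np.abs(tau_tilde(0.01) - tau_star))
err_2 = np.max(np.abs(tau_tilde(0.02) - tau_star))
print("moment:", moment)
print("raw product:", raw_product)
print("error ratio when delta doubles:", err_2 / err_1)
```

A first-order sensitivity would make this ratio close to 2. A ratio near 4 is the numerical signature of orthogonality.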

Separation of Concerns

  1. Confounding control: estimate $\hat{m}$ and $\hat{e}$ in Step 1
  2. Treatment effect estimation: Step 2 focuses purely on the CATE

A different ML method can be used at each stage.
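Mixing methods across stages is straightforward. All choices in this sketch are assumptions for illustration:

  • gradient boosting for $\hat{m}$, logistic regression for $\hat{e}$, and plain linear regression for $\tau$
  • a made-up data-generating process with a nonlinear baseline and a linear CATE $\tau(x) = 1 + 0.5\,x_2$

scikit-learn's `cross_val_predict` supplies the cross-fitted predictions. It draws separate fold splits for $\hat{m}$ and $\hat{e}$, which simplifies the single shared fold assignment $q(i)$ used in the definition.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(1)
n = 4000
X = rng.uniform(-2, 2, size=(n, 2))
e = 1.0 / (1.0 + np.exp(-X[:, 0]))                  # propensity depends on X_1
W = rng.binomial(1, e)
tau = 1.0 + 0.5 * X[:, 1]                           # simple (linear) CATE
Y = np.sin(2 * X[:, 0]) + X[:, 1] ** 2 + tau * W + rng.normal(scale=0.5, size=n)

# Stage 1a: flexible outcome model, cross-fitted
m_hat = cross_val_predict(GradientBoostingRegressor(random_state=0), X, Y, cv=5)
# Stage 1b: propensity model (classifier), cross-fitted probabilities
e_hat = cross_val_predict(LogisticRegression(), X, W, cv=5, method="predict_proba")[:, 1]
e_hat = np.clip(e_hat, 0.01, 0.99)                  # keep W - e_hat away from 0

Y_t, W_t = Y - m_hat, W - e_hat
# Stage 2: linear CATE model; weighted least squares on Y_t / W_t with
# weights W_t**2 minimizes the (unpenalized) R-loss over linear functions
tau_model = LinearRegression().fit(X, Y_t / W_t, sample_weight=W_t ** 2)
print("intercept:", tau_model.intercept_, "coefficients:", tau_model.coef_)
```

The final stage stays simple and interpretable, while the nonlinear baseline is absorbed entirely by the Stage 1 outcome model.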

Algorithm Detail

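import numpy as np
from sklearn.base import clone
from sklearn.model_selection import KFold

def r_learner(X, W, Y, base_learner, n_folds=5, eps=1e-3, random_state=0):
    # base_learner: an sklearn regressor whose fit() accepts sample_weight
    n = len(Y)
    m_hat = np.zeros(n)  # cross-fitted predictions of E[Y|X]
    e_hat = np.zeros(n)  # cross-fitted predictions of P(W=1|X)

    # Step 1: Cross-fitted nuisance estimation
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)

    for train_idx, val_idx in kf.split(X):
        # clone() gives a fresh, unfitted copy; fit() returns self, so reusing
        # base_learner directly would make every model the same object
        outcome_model = clone(base_learner).fit(X[train_idx], Y[train_idx])
        m_hat[val_idx] = outcome_model.predict(X[val_idx])

        # Propensity model: regressing binary W on X estimates P(W=1|X)
        propensity_model = clone(base_learner).fit(X[train_idx], W[train_idx])
        e_hat[val_idx] = propensity_model.predict(X[val_idx])

    # Keep propensities inside [eps, 1 - eps] so W_tilde is never exactly 0
    e_hat = np.clip(e_hat, eps, 1 - eps)

    # Compute residuals
    Y_tilde = Y - m_hat  # outcome residual
    W_tilde = W - e_hat  # treatment residual (negative for control units)

    # Step 2: Minimize R-loss
    # Σ(Ỹᵢ - W̃ᵢ·τ(Xᵢ))² = Σ W̃ᵢ²·(Ỹᵢ/W̃ᵢ - τ(Xᵢ))², i.e. weighted least squares
    # with pseudo-outcome Ỹᵢ/W̃ᵢ and weight W̃ᵢ²
    pseudo_outcome = Y_tilde / W_tilde  # divide by the signed residual; never clip its sign
    weights = W_tilde ** 2

    # Fit CATE model (weighted regression); the learner's own regularization acts as Λn(τ)
    tau_model = clone(base_learner).fit(X, pseudo_outcome, sample_weight=weights)

    return tau_model.predict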

Comparison with Other Meta-Learners

| Aspect | S-Learner | T-Learner | X-Learner | R-Learner |
|---|---|---|---|---|
| Models | 1 | 2 | 4 | 3 (m, e, τ) |
| Data usage | All together | Split | Cross-group | All + cross-fitting |
| Targets | Response | Response | Imputed effects | Residualized outcome |
| Key feature | Simple | Separate | Imbalance handling | Orthogonality |
| Best when | CATE ≈ 0 | Different μ₀, μ₁ | Unbalanced groups | Simple CATE, complex nuisance |

When to Use

Good Scenarios

  • The CATE is simpler than the nuisance functions: orthogonality minimizes the impact of nuisance complexity
  • Confounding is complex but the treatment effect is simple
  • Cross-validation matters: hyperparameters can be tuned separately for each stage

Bad Scenarios

  • Extreme propensity scores: estimation becomes unstable when $e(x) \approx 0$ or $1$
  • Very small sample sizes: cross-fitting leaves less data for each nuisance fit
  • Hard nuisance estimation: there is no guarantee if the $o(n^{-1/4})$ rate is not achieved
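The instability under extreme propensities follows from the weights $\tilde{W}_i^2$. Their mean is $e(1-e)$, so the information per unit shrinks as $e \to 0$ or $1$. The Monte Carlo sketch below is illustrative only, under these assumptions:

  • a constant-propensity design, with a helper `tau_hat_constant` that is hypothetical
  • oracle nuisances
  • a constant effect

It compares the spread of the constant-$\tau$ R-learner at $e = 0.5$ and at $e = 0.02$.

```python
import numpy as np

def tau_hat_constant(e_value, n, rng, tau_true=1.0):
    """Constant-tau R-learner with oracle nuisances and constant propensity e_value."""
    W = rng.binomial(1, e_value, size=n)
    Y = tau_true * W + rng.normal(size=n)
    m = tau_true * e_value                 # E[Y|X] when the propensity is constant
    W_t = W - e_value                      # treatment residual
    return np.sum(W_t * (Y - m)) / np.sum(W_t ** 2)

rng = np.random.default_rng(2)
n, reps = 2000, 300
sd_balanced = np.std([tau_hat_constant(0.5, n, rng) for _ in range(reps)])
sd_extreme = np.std([tau_hat_constant(0.02, n, rng) for _ in range(reps)])
print(f"sd at e=0.5: {sd_balanced:.3f}, sd at e=0.02: {sd_extreme:.3f}")
```

In this setting the variance scales roughly like $1/(n\,e(1-e))$. Moving from $e = 0.5$ to $e = 0.02$ therefore inflates the standard deviation by a factor of about $\sqrt{0.25/0.0196} \approx 3.6$.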

Rate Conditions

Conditions for the quasi-oracle property

$$\|\hat{m} - m^*\|_2 \cdot \|\hat{e} - e^*\|_2 = o_P(n^{-1/2})$$

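A sufficient condition, and the form usually stated, is that each nuisance converges faster than $n^{-1/4}$:

$$\|\hat{m} - m^*\|_2 = o_P(n^{-1/4}), \quad \|\hat{e} - e^*\|_2 = o_P(n^{-1/4})$$

This implies the product condition, but the reverse does not hold, so the two forms are not equivalent.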
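The origin of the error products can be seen in a simplified case, which is an assumption for illustration. Take a constant effect $\tau$ and binary $W$ under unconfoundedness, and treat $\hat{m}$ and $\hat{e}$ as fixed functions, as cross-fitting allows. Let $\tilde{\tau} = E[(W-\hat{e})(Y-\hat{m})]/E[(W-\hat{e})^2]$ be the population R-loss minimizer over constants, and write $\delta_m = \hat{m} - m^*$ and $\delta_e = \hat{e} - e^*$. Two cross terms vanish because $E[Y - m^* \mid X] = 0$ and $E[W - e^* \mid X] = 0$:

```latex
\begin{aligned}
E[(W-\hat{e})(Y-\hat{m})] &= E[(W-e^*)(Y-m^*)] - E[\delta_e\,(Y-m^*)] - E[(W-e^*)\,\delta_m] + E[\delta_e\,\delta_m] \\
  &= \tau\,E[e^*(1-e^*)] + E[\delta_e\,\delta_m], \\
E[(W-\hat{e})^2] &= E[e^*(1-e^*)] + E[\delta_e^2], \\
\tilde{\tau} - \tau &= \frac{E[\delta_e\,\delta_m] - \tau\,E[\delta_e^2]}{E[e^*(1-e^*)] + E[\delta_e^2]},
\qquad
|\tilde{\tau} - \tau| \le \frac{\|\delta_e\|_2\,\|\delta_m\|_2 + |\tau|\,\|\delta_e\|_2^2}{E[e^*(1-e^*)]}.
\end{aligned}
```

Every nuisance error enters as a product of two errors, so if both are $o_P(n^{-1/4})$, the bias is $o_P(n^{-1/2})$. In this sketch, the $|\tau|\,\|\delta_e\|_2^2$ term additionally requires $\hat{e}$ itself to converge faster than $n^{-1/4}$.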

Convergence Rate

Under suitable regularity conditions:

  • Penalized kernel regression: $O(n^{-\alpha/(2\alpha+d)})$, where $\alpha$ is the smoothness and $d$ the covariate dimension
  • Linear CATE: the parametric rate $O(n^{-1/2})$ is achievable
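Plugging illustrative values into the kernel rate shows how quickly it degrades with dimension. The choices $\alpha = 2$ and $d \in \{1, 10\}$ are arbitrary:

```latex
\alpha = 2,\ d = 1:\quad n^{-\alpha/(2\alpha+d)} = n^{-2/5},
\qquad
\alpha = 2,\ d = 10:\quad n^{-2/14} = n^{-1/7},
\qquad
\text{linear CATE:}\quad n^{-1/2}.
```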

Simulation Results (from Paper)

Setup A (Complex nuisance, simple CATE):

  • R-learner performs best
  • Orthogonality removes the confounding effect

Setup B (RCT, constant propensity):

  • R-learner ≈ T-learner
  • No particular advantage

Setup C (Easy propensity, complex baseline):

  • R-learner is competitive
  • Performance is similar to the X-learner

Setup D (Unrelated arms):

  • T-learner has the advantage
  • R-learner gets no benefit from sharing data across arms

Related Notes

  • Meta-learners - overall framework
  • Robinson Transformation - theoretical foundation
  • R-Loss - optimization objective
  • Quasi-Oracle Property - key theoretical guarantee
  • Cross-fitting - removes overfitting bias
  • S-Learner, T-Learner, X-Learner - alternative methods
  • DR-Learner - doubly robust approach
  • CATE - estimation target
  • Propensity Score - nuisance component

Implementation

Python (econml):

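from econml.dml import NonParamDML
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier

# NonParamDML minimizes the R-loss: it regresses Ỹ/T̃ on X with sample weights T̃²,
# so model_final must accept sample_weight
r_learner = NonParamDML(
    model_y=RandomForestRegressor(),
    model_t=RandomForestClassifier(),  # classifier for a binary treatment
    model_final=RandomForestRegressor(),
    discrete_treatment=True,  # required when model_t is a classifier
    cv=5  # cross-fitting folds
)
r_learner.fit(Y, T, X=X)  # T: binary treatment (W in the notation above)
cate = r_learner.effect(X_test)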

R (rlearner package):

library(rlearner)

# Lasso-based R-learner (rboost uses gradient boosting instead)
r_fit <- rlasso(X, W, Y)
cate <- predict(r_fit, X_test)

References

  • nieQuasiOracleEstimationHeterogeneous2020 - original R-learner paper
  • chernozhukovDoubleDebiasedMachine2018 - related DML theory
