R-Learner
Definition
R-Learner (Residualized Learner) is a meta-learner that estimates the CATE from residualized outcomes and residualized treatments, based on the Robinson transformation.
Algorithm:
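Step 1: Estimate the nuisance functions (with cross-fitting)
$$m^*(x) = \mathbb{E}[Y \mid X = x], \qquad e^*(x) = \mathbb{P}(W = 1 \mid X = x)$$
Step 2: Minimize the R-loss
$$\hat{\tau}(\cdot) = \arg\min_{\tau} \left\{ \frac{1}{n} \sum_{i=1}^{n} \Big[ \big(Y_i - \hat{m}^{(-q(i))}(X_i)\big) - \big(W_i - \hat{e}^{(-q(i))}(X_i)\big)\,\tau(X_i) \Big]^2 + \Lambda_n\big(\tau(\cdot)\big) \right\}$$
where:
- $\hat{m}^{(-q(i))}, \hat{e}^{(-q(i))}$: estimates of $m^*$ and $e^*$ fitted without the fold $q(i)$ that contains the $i$-th observation
- $\Lambda_n(\tau(\cdot))$: regularization term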
Intuitive Understanding
Key idea:
Remove the influence of the covariates from both the outcome and the treatment, then learn only the treatment effect.
Step 1: Estimate nuisance functions
m̂(x) = E[Y|X] (outcome model)
ê(x) = P(W=1|X) (propensity model)
↓
Step 2: Compute residuals (via cross-fitting)
Ỹᵢ = Yᵢ - m̂(Xᵢ) (outcome residual)
W̃ᵢ = Wᵢ - ê(Xᵢ) (treatment residual)
↓
Step 3: Minimize R-loss
τ̂ = argmin Σ[Ỹᵢ - W̃ᵢ·τ(Xᵢ)]² + regularization
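The flow above can be checked on simulated data. This is a minimal sketch built on three assumptions. The data-generating process is hypothetical, with one confounder `x` and a constant effect τ = 2. The oracle nuisances (the true m* and e*) stand in for cross-fitted estimates. The naive difference in means is biased by confounding, while regressing Ỹ on W̃ recovers τ. For a constant τ, the R-loss minimizer has the closed form Σ W̃ᵢỸᵢ / Σ W̃ᵢ².

```python
import numpy as np

def simulate_confounded(n=100_000, tau=2.0, seed=0):
    """Hypothetical DGP: x drives both treatment assignment and the outcome."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=n)
    e_star = 1.0 / (1.0 + np.exp(-2.0 * x))      # true propensity P(W=1 | X)
    w = rng.binomial(1, e_star).astype(float)
    y = 3.0 * x + tau * w + rng.normal(size=n)    # baseline mu_0(x) = 3x
    m_star = 3.0 * x + tau * e_star               # true E[Y | X] = mu_0(x) + e*(x) * tau
    return x, w, y, m_star, e_star

x, w, y, m_star, e_star = simulate_confounded()

# Naive difference in means ignores confounding by x
naive = y[w == 1].mean() - y[w == 0].mean()

# Residualize outcome and treatment (oracle nuisances here, for clarity)
y_tilde = y - m_star
w_tilde = w - e_star

# Constant tau: argmin sum (y_tilde - w_tilde * tau)^2 has a closed form
tau_r = np.sum(w_tilde * y_tilde) / np.sum(w_tilde ** 2)

print(f"naive: {naive:.2f}, residual-on-residual: {tau_r:.2f}")
```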
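Why “R”?
- It works with Residuals
- It builds on the Robinson transformation
- The name honors Robinson (1988), the author of the transformation, not the initials of the paper's authors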
Key Properties
Quasi-Oracle Property
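Key theorem (Theorem 1): If the nuisance components are estimated at an $o_P(n^{-1/4})$ rate, the R-learner achieves the same convergence rate as an oracle that knows the true nuisance functions.
Implications:
- Errors in nuisance estimation have no first-order effect on $\hat{\tau}$
- Slow nuisance estimation is acceptable, as long as it meets the $o_P(n^{-1/4})$ rate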
Orthogonality
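Orthogonality condition of the Robinson transformation:
$$Y_i - m^*(X_i) = \big(W_i - e^*(X_i)\big)\,\tau^*(X_i) + \varepsilon_i, \qquad \mathbb{E}[\varepsilon_i \mid X_i, W_i] = 0$$
Because $\mathbb{E}[W_i - e^*(X_i) \mid X_i] = 0$, first-order errors in $\hat{m}$ and $\hat{e}$ drop out of the R-loss. They enter only through second-order terms such as $(\hat{e} - e^*)(\hat{m} - m^*)$ and $(\hat{e} - e^*)^2$. This makes the estimator robust to nuisance error.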
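This is a minimal numerical sketch of orthogonality, built on three assumptions. The data-generating process is hypothetical, with a constant τ = 1. The oracle nuisances are deliberately shifted by δ: the propensity up by δ and the outcome model down by δ. The errors depend only on X, as cross-fitted estimates would (they are independent of the held-out units' W and noise). Under these assumptions the constant-τ R-learner moves away from the oracle estimate by roughly δ². Halving δ therefore shrinks the deviation about fourfold rather than twofold.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 500_000
tau = 1.0
x = rng.uniform(-1, 1, size=n)
e_star = 0.5 + 0.2 * x                   # true propensity, in [0.3, 0.7]
w = rng.binomial(1, e_star).astype(float)
mu0 = np.sin(np.pi * x)
y = mu0 + tau * w + rng.normal(size=n)
m_star = mu0 + tau * e_star              # true E[Y | X]

def tau_hat(delta):
    """Constant-tau R-learner when both nuisances are off by a hypothetical shift delta."""
    e_hat = e_star + delta               # propensity error (depends on X only)
    m_hat = m_star - delta               # outcome-model error (depends on X only)
    w_tilde = w - e_hat
    y_tilde = y - m_hat
    return np.sum(w_tilde * y_tilde) / np.sum(w_tilde ** 2)

oracle = tau_hat(0.0)
dev = {d: abs(tau_hat(d) - oracle) for d in (0.2, 0.1, 0.05)}
for d, v in dev.items():
    print(f"delta={d:.2f}  |tau_hat - oracle| = {v:.4f}")
```

A first-order-sensitive estimator would show deviations that halve along with δ. Here the ratio between successive rows is close to 4.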
Separation of Concerns
- Confounding control: handled by the nuisance estimates in Step 1
- Treatment effect estimation: Step 2 focuses purely on the CATE
A different ML method can be used at each stage.
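To illustrate the separation of concerns, here is a sketch that pairs flexible gradient-boosting nuisance models with a simple linear CATE model. It assumes a hypothetical simulated dataset with τ(x) = 1 + x₀, uses scikit-learn's `cross_val_predict` for cross-fitting, and clips the propensity to [0.01, 0.99]. For a linear τ(x) = b₀ + b₁x₀, minimizing the unregularized R-loss is ordinary least squares of Ỹ on W̃·[1, x₀], with no intercept.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_predict

rng = np.random.default_rng(2)
n = 4000
X = rng.uniform(-1, 1, size=(n, 3))
e_star = 1.0 / (1.0 + np.exp(-X[:, 1]))            # confounded assignment via x1
W = rng.binomial(1, e_star)
tau_star = 1.0 + X[:, 0]                            # simple (linear) CATE
Y = np.sin(np.pi * X[:, 1]) + X[:, 2] ** 2 + W * tau_star + rng.normal(scale=0.5, size=n)

folds = KFold(n_splits=5, shuffle=True, random_state=0)

# Step 1: flexible nuisance models, cross-fitted (out-of-fold predictions)
m_hat = cross_val_predict(GradientBoostingRegressor(random_state=0), X, Y, cv=folds)
e_hat = cross_val_predict(GradientBoostingClassifier(random_state=0), X, W,
                          cv=folds, method="predict_proba")[:, 1]
e_hat = np.clip(e_hat, 0.01, 0.99)                  # guard against extreme propensities

y_tilde = Y - m_hat
w_tilde = W - e_hat

# Step 2: a simple linear model tau(x) = b0 + b1 * x0, fitted by minimizing the R-loss
Z = w_tilde[:, None] * np.column_stack([np.ones(n), X[:, 0]])
final = LinearRegression(fit_intercept=False).fit(Z, y_tilde)
b0, b1 = final.coef_
print(f"tau(x) ~ {b0:.2f} + {b1:.2f} * x0")         # data were generated with 1 + x0
```

Swapping either stage, for example a random forest for `e_hat` or a penalized model for τ, does not change the other stage.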
Algorithm Detail
import numpy as np
from sklearn.base import clone
from sklearn.model_selection import KFold

def r_learner(X, W, Y, base_learner, n_folds=5, eps=1e-6, random_state=0):
    # base_learner: an sklearn regressor whose fit() accepts sample_weight
    n = len(Y)
    m_hat = np.zeros(n)  # out-of-fold predictions of E[Y|X]
    e_hat = np.zeros(n)  # out-of-fold predictions of P(W=1|X)
    # Step 1: Cross-fitted nuisance estimation
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)
    for train_idx, val_idx in kf.split(X):
        # Fit outcome model on the other folds, predict on the held-out fold
        outcome_model = clone(base_learner).fit(X[train_idx], Y[train_idx])
        m_hat[val_idx] = outcome_model.predict(X[val_idx])
        # Fit propensity model (a regressor on 0/1 W gives a linear-probability estimate)
        propensity_model = clone(base_learner).fit(X[train_idx], W[train_idx])
        e_hat[val_idx] = propensity_model.predict(X[val_idx])
    # Compute residuals
    Y_tilde = Y - m_hat  # outcome residual
    W_tilde = W - e_hat  # treatment residual
    # Step 2: Minimize R-loss
    # τ(x) = argmin Σ(Ỹᵢ - W̃ᵢ·τ(Xᵢ))²
    # Since (Ỹ - W̃τ)² = W̃²(Ỹ/W̃ - τ)², this is weighted least squares:
    # Ỹᵢ/W̃ᵢ ≈ τ(Xᵢ) with weight W̃ᵢ²
    # Keep the sign of W̃ (negative for control units); only push |W̃| away from 0
    W_safe = np.where(W_tilde >= 0, np.maximum(W_tilde, eps), np.minimum(W_tilde, -eps))
    pseudo_outcome = Y_tilde / W_safe
    weights = W_tilde ** 2
    # Fit CATE model (weighted regression) on a fresh copy of the learner
    tau_model = clone(base_learner).fit(X, pseudo_outcome, sample_weight=weights)
    return tau_model.predict
Comparison with Other Meta-Learners
| Aspect | S-Learner | T-Learner | X-Learner | R-Learner |
|---|---|---|---|---|
| Models | 1 | 2 | 4 | 3 (m, e, τ) |
| Data usage | All together | Split | Cross-group | All + cross-fitting |
| Targets | Response | Response | Imputed effects | Residualized outcome |
| Key feature | Simple | Separate | Imbalance handling | Orthogonality |
| Best when | CATE ≈ 0 | Different μ₀, μ₁ | Unbalanced groups | Simple CATE, complex nuisance |
When to Use
Good Scenarios
- When the CATE is simpler than the nuisance functions: orthogonality minimizes the impact of nuisance complexity
- When confounding is complex but the treatment effect is simple
- When cross-validation matters: hyperparameters can be tuned separately at each stage
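The R-loss is an ordinary loss in τ, so it can serve as a held-out criterion for comparing CATE models or tuning their hyperparameters. This is a minimal sketch that assumes a hypothetical randomized toy dataset and uses the true nuisances. In practice `m_hat` and `e_hat` would be cross-fitted estimates evaluated on a validation fold. The function name `r_loss` is illustrative.

```python
import numpy as np

def r_loss(y, w, m_hat, e_hat, tau_pred):
    """Average R-loss (no regularization): mean of [(y - m_hat) - (w - e_hat) * tau_pred]^2."""
    return float(np.mean(((y - m_hat) - (w - e_hat) * tau_pred) ** 2))

rng = np.random.default_rng(3)
n = 20_000
x = rng.uniform(-1, 1, size=n)
e = np.full(n, 0.5)                       # randomized assignment in this toy example
w = rng.binomial(1, e).astype(float)
tau_true = 1.0 + 2.0 * x
y = x + w * tau_true + rng.normal(size=n)
m = x + e * tau_true                      # true E[Y | X]

# Two candidate CATE models, compared by R-loss on the same data
candidates = {
    "constant (mean effect)": np.full(n, 1.0),
    "linear in x": 1.0 + 2.0 * x,
}
losses = {name: r_loss(y, w, m, e, tau_pred) for name, tau_pred in candidates.items()}
for name, loss in losses.items():
    print(f"{name:>24}: R-loss = {loss:.4f}")
```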
Bad Scenarios
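- Extreme propensity scores: unstable when $e(x) \approx 0$ or $e(x) \approx 1$
- Very small sample size: cross-fitting trains each nuisance model on only part of the data
- Hard nuisance estimation: no guarantee if the $o_P(n^{-1/4})$ rate is not achieved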
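Extreme propensities hurt for two reasons. For W ~ Bernoulli(e(x)), a unit's expected R-loss weight E[W̃² | X] = e(1 - e) vanishes as e approaches 0 or 1. In the pseudo-outcome form, the outcome residual is also divided by W̃, which can amplify it sharply. The sketch below shows both effects. The trimming rule `overlap_mask` and its [0.05, 0.95] thresholds are a hypothetical, commonly used heuristic, not part of the R-learner itself.

```python
import numpy as np

def expected_r_weight(e):
    """E[(W - e)^2 | X] for W ~ Bernoulli(e): a unit's expected weight in the R-loss."""
    e = np.asarray(e, dtype=float)
    return e * (1.0 - e)

def pseudo_outcome_scale(w, e):
    """Factor 1 / (W - e) that multiplies the outcome residual in the pseudo-outcome."""
    return 1.0 / (float(w) - float(e))

def overlap_mask(e_hat, lo=0.05, hi=0.95):
    """Hypothetical trimming rule: keep units whose estimated propensity lies in [lo, hi]."""
    e_hat = np.asarray(e_hat, dtype=float)
    return (e_hat >= lo) & (e_hat <= hi)

e = np.array([0.5, 0.2, 0.05, 0.01])
print(expected_r_weight(e))            # shrinks toward 0 as e approaches 0 or 1
print(pseudo_outcome_scale(1, 0.99))   # treated unit with e = 0.99: residual amplified about 100x
print(overlap_mask(e))
```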
Rate Conditions
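Conditions for the quasi-oracle property
$$\mathbb{E}\big[(\hat{m}(X) - m^*(X))^2\big]^{1/2} = o_P(n^{-1/4}), \qquad \mathbb{E}\big[(\hat{e}(X) - e^*(X))^2\big]^{1/2} = o_P(n^{-1/4})$$
or equivalently:
$$\mathbb{E}\big[(\hat{m}(X) - m^*(X))^2\big] = o_P(n^{-1/2}), \qquad \mathbb{E}\big[(\hat{e}(X) - e^*(X))^2\big] = o_P(n^{-1/2})$$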
Convergence Rate
Under suitable regularity conditions:
- Penalized kernel regression: a nonparametric rate that depends on the smoothness of $\tau^*$
- Linear CATE: the parametric $n^{-1/2}$ rate is achievable
Simulation Results (from Paper)
Setup A (Complex nuisance, simple CATE):
- R-learner performs best
- Orthogonality removes the effect of confounding
Setup B (RCT, constant propensity):
- R-learner ≈ T-learner
- No particular advantage
Setup C (Easy propensity, complex baseline):
- R-learner is competitive
- Performs similarly to the X-learner
Setup D (Unrelated arms):
- The T-learner has the edge
- The R-learner gains nothing from sharing data across arms
Related Concepts
- Meta-learners - overall framework
- Robinson Transformation - theoretical foundation
- R-Loss - optimization objective
- Quasi-Oracle Property - core theoretical guarantee
- Cross-fitting - removes overfitting bias
- S-Learner, T-Learner, X-Learner - alternative methods
- DR-Learner - doubly robust approach
- CATE - estimation target
- Propensity Score - nuisance component
Implementation
Python (econml):
from econml.dml import NonParamDML
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
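# R-learner is closely related to DML: NonParamDML residualizes Y and T
# with cross-fitting, then fits model_final on the residuals
r_learner = NonParamDML(
    model_y=RandomForestRegressor(),
    model_t=RandomForestClassifier(),
    model_final=RandomForestRegressor(),
    discrete_treatment=True,  # binary T, so model_t is used as a classifier
    cv=5  # cross-fitting folds
)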
r_learner.fit(Y, T, X=X)
cate = r_learner.effect(X_test)
R (rlearner package):
library(rlearner)
# Lasso-based R-learner (rboost and rkern are other variants)
r_fit <- rlasso(X, W, Y)
cate <- predict(r_fit, X_test)
References
- nieQuasiOracleEstimationHeterogeneous2020 - original R-learner paper
- chernozhukovDoubleDebiasedMachine2018 - related DML theory