X-Learner
Definition
The X-Learner is a three-stage meta-learner (see Meta-learners) that uses imputed treatment effects to exploit group-size imbalance and the structural properties of the CATE.
Algorithm:
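Stage 1: Estimate response functions
$$\hat{\mu}_0(x) \approx E[Y \mid X = x, W = 0], \qquad \hat{\mu}_1(x) \approx E[Y \mid X = x, W = 1]$$
Stage 2: Compute imputed treatment effects and estimate the CATE
$$\tilde{D}_{1i} = Y_{1i} - \hat{\mu}_0(X_{1i}), \qquad \tilde{D}_{0i} = \hat{\mu}_1(X_{0i}) - Y_{0i}$$
Estimate the CATE in each group:
- $\hat{\tau}_1(x)$: Learned from the imputed effects of the treatment group
- $\hat{\tau}_0(x)$: Learned from the imputed effects of the control group
Stage 3: Weighted combination
$$\hat{\tau}(x) = g(x)\,\hat{\tau}_0(x) + (1 - g(x))\,\hat{\tau}_1(x)$$
where $g(x) \in [0, 1]$ is a weight function (typically an estimate of the propensity score is used).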
Intuitive Understanding
Core idea:
Combine each unit's observed outcome with the outcome predicted for it by the opposite group's model to create “imputed” treatment effects, estimate the CATE from the perspective of each group, and then combine the two estimates.
Stage 1: Estimate response functions (like T-learner)
μ̂₀(x), μ̂₁(x)
↓
Stage 2: Impute treatment effects
Treatment group: D̃₁ᵢ = Y₁ᵢ - μ̂₀(X₁ᵢ) (observed - predicted control)
Control group: D̃₀ᵢ = μ̂₁(X₀ᵢ) - Y₀ᵢ (predicted treatment - observed)
↓
Train τ̂₁(x) on D̃₁, τ̂₀(x) on D̃₀
↓
Stage 3: Weighted combination
τ̂(x) = g(x)·τ̂₀(x) + (1-g(x))·τ̂₁(x)
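To make the crossing concrete, the sketch below works through Stage 2 by hand on a few made-up units. The functions `mu0_hat` and `mu1_hat` are hypothetical stand-ins for fitted response functions (simple closed forms, not the output of a real Stage 1 fit), and all numbers are illustrative.

```python
# Hypothetical Stage 1 fits: stand-ins for learned response functions.
def mu0_hat(x):
    return 1.0 + 2.0 * x        # predicted outcome under control

def mu1_hat(x):
    return 1.5 + 2.5 * x        # predicted outcome under treatment

# Illustrative observed data: (covariate, outcome) pairs per group.
treated = [(0.0, 1.6), (1.0, 4.1), (2.0, 6.4)]
control = [(0.0, 0.9), (1.0, 3.2)]

# Treatment group: observed outcome minus predicted control outcome.
d_tilde_1 = [y - mu0_hat(x) for x, y in treated]

# Control group: predicted treated outcome minus observed outcome.
d_tilde_0 = [mu1_hat(x) - y for x, y in control]

print([round(d, 2) for d in d_tilde_1])   # [0.6, 1.1, 1.4]
print([round(d, 2) for d in d_tilde_0])   # [0.6, 0.8]
```

Each imputed effect pairs one observed outcome with one prediction from the other group's model, which is the information flow the diagram describes.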
Why “X”?
- The information from the treatment group is used to estimate the CATE of the control group, and vice versa
- Information is transmitted by “crossing”
Key Properties
Exploits CATE Structure
- When the CATE is simple (e.g., linear), this structure can be leveraged
- Even if the response functions are complex, a fast convergence rate can still be achieved when the CATE is simple
Handles Imbalanced Groups
- The weight function adjusts for group size imbalance
- Makes greater use of the information from the larger group
Convergence Rates (Conjecture 1)
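Under conditions:
- $\hat{\tau}_0$: $\mathcal{O}(m^{-a_\tau} + n^{-a_\mu})$
- $\hat{\tau}_1$: $\mathcal{O}(m^{-a_\mu} + n^{-a_\tau})$
where:
- $a_\mu$: Smoothness of the response function
- $a_\tau$: Smoothness of the CATE function
- $m$: Control group size, $n$: Treatment group size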
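As a rough numerical illustration of why the two exponents matter, the sketch below plugs group sizes into rate expressions with constants ignored. It assumes a T-learner bound of the form m^(-a_mu) + n^(-a_mu) and a conjectured bound for the treatment-side estimator of the form m^(-a_mu) + n^(-a_tau), with m control units and n treated units. Both forms, the function names, and the chosen numbers are assumptions made for this example.

```python
def t_learner_rate(m, n, a_mu):
    """Assumed T-learner rate: both groups pay the response-function exponent."""
    return m ** (-a_mu) + n ** (-a_mu)

def x_learner_tau1_rate(m, n, a_mu, a_tau):
    """Assumed rate for tau_1: controls pay a_mu, treated units pay a_tau."""
    return m ** (-a_mu) + n ** (-a_tau)

# Illustrative setting: many controls, few treated units,
# and a CATE that is much simpler than the response functions.
m, n = 10_000, 100
a_mu, a_tau = 0.2, 1.0

t_rate = t_learner_rate(m, n, a_mu)
x_rate = x_learner_tau1_rate(m, n, a_mu, a_tau)
print(f"T-learner: {t_rate:.3f}, X-learner (tau_1): {x_rate:.3f}")
```

Under these assumptions the small treated group only enters through the faster CATE exponent, which is where the advantage over the T-learner would come from.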
Parametric Rate (Theorem 2)
Conditions:
- Linear CATE: $\tau(x) = x^\top \beta$
- Lipschitz response functions
Result: X-learner with $g \equiv 0$ achieves the parametric rate in the treatment group size $n$.
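A short derivation may help show why a simple CATE is enough here. Write the treated outcome as $Y_{1i} = \mu_1(X_{1i}) + \varepsilon_{1i}$ with mean-zero noise (the noise term $\varepsilon_{1i}$ is notation introduced for this sketch). The imputed effect of a treated unit then decomposes as:

```latex
\begin{aligned}
\tilde{D}_{1i}
  &= Y_{1i} - \hat{\mu}_0(X_{1i}) \\
  &= \underbrace{\mu_1(X_{1i}) - \mu_0(X_{1i})}_{\tau(X_{1i})}
     + \underbrace{\mu_0(X_{1i}) - \hat{\mu}_0(X_{1i})}_{\text{Stage 1 error}}
     + \varepsilon_{1i}
\end{aligned}
```

Regressing $\tilde{D}_{1i}$ on $X_{1i}$ therefore targets $\tau$ directly. If $\tau(x)$ is linear, the second stage is a linear regression on the treated units, and the remaining error depends on how well $\hat{\mu}_0$ was learned from the control group.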
Algorithm Detail
import numpy as np
from sklearn.base import clone


def x_learner(X, W, Y, base_learner, propensity_model=None):
    """Fit an X-learner and return a function that predicts the CATE."""
    X, W, Y = np.asarray(X), np.asarray(W), np.asarray(Y)

    # Split data by treatment
    X_ctrl, Y_ctrl = X[W == 0], Y[W == 0]
    X_treat, Y_treat = X[W == 1], Y[W == 1]

    # Stage 1: Estimate response functions.
    # clone() gives each stage its own unfitted copy; fit() returns the
    # estimator itself, so reusing base_learner would overwrite earlier fits.
    model_0 = clone(base_learner).fit(X_ctrl, Y_ctrl)
    model_1 = clone(base_learner).fit(X_treat, Y_treat)

    # Stage 2: Compute imputed treatment effects
    # For treatment group: observed - predicted control
    D_tilde_1 = Y_treat - model_0.predict(X_treat)
    # For control group: predicted treatment - observed
    D_tilde_0 = model_1.predict(X_ctrl) - Y_ctrl

    # Fit CATE models on imputed effects
    tau_model_1 = clone(base_learner).fit(X_treat, D_tilde_1)
    tau_model_0 = clone(base_learner).fit(X_ctrl, D_tilde_0)

    # Stage 3: Weight function g(x), an estimate of P(W = 1 | X = x).
    # g multiplies tau_0, so many controls (small g) shift weight to tau_1.
    if propensity_model is None:
        # Constant fallback: overall treated fraction
        treated_fraction = len(X_treat) / len(X)
        g = lambda x: np.full(len(x), treated_fraction)
    else:
        e_model = clone(propensity_model).fit(X, W)
        g = lambda x: e_model.predict_proba(x)[:, 1]

    # Weighted combination: tau(x) = g(x) * tau_0(x) + (1 - g(x)) * tau_1(x)
    def predict_cate(X_new):
        tau_0 = tau_model_0.predict(X_new)
        tau_1 = tau_model_1.predict(X_new)
        weights = g(X_new)
        return weights * tau_0 + (1 - weights) * tau_1

    return predict_cate
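A toy end-to-end run can make the three stages easier to follow. The sketch below assumes a one-dimensional covariate, noiseless outcomes with a constant treatment effect of 3, and a hypothetical `OLS1D` base learner written in plain Python. None of these choices come from the algorithm description above, and a constant weight (the treated fraction) stands in for a fitted propensity model.

```python
class OLS1D:
    """Hypothetical base learner: least-squares line through 1-D data."""

    def fit(self, xs, ys):
        n = len(xs)
        x_bar = sum(xs) / n
        y_bar = sum(ys) / n
        sxx = sum((x - x_bar) ** 2 for x in xs)
        sxy = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys))
        self.slope = sxy / sxx
        self.intercept = y_bar - self.slope * x_bar
        return self

    def predict(self, xs):
        return [self.intercept + self.slope * x for x in xs]


# Noiseless synthetic data with a constant treatment effect of 3 (assumed setup).
x_ctrl = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]   # 6 control units
x_treat = [0.5, 2.5]                       # 2 treated units
y_ctrl = [1.0 + 2.0 * x for x in x_ctrl]
y_treat = [1.0 + 2.0 * x + 3.0 for x in x_treat]

# Stage 1: a fresh learner per response function.
mu0 = OLS1D().fit(x_ctrl, y_ctrl)
mu1 = OLS1D().fit(x_treat, y_treat)

# Stage 2: imputed effects, then one CATE model per group.
d1 = [y - p for y, p in zip(y_treat, mu0.predict(x_treat))]
d0 = [p - y for y, p in zip(y_ctrl, mu1.predict(x_ctrl))]
tau1 = OLS1D().fit(x_treat, d1)
tau0 = OLS1D().fit(x_ctrl, d0)

# Stage 3: constant weight g = treated fraction.
g = len(x_treat) / (len(x_treat) + len(x_ctrl))
x_new = [1.5, 4.5]
cate = [g * t0 + (1 - g) * t1
        for t0, t1 in zip(tau0.predict(x_new), tau1.predict(x_new))]
print([round(c, 6) for c in cate])   # [3.0, 3.0]
```

Each stage constructs its own `OLS1D()` instance. With scikit-learn style estimators, `fit` returns the estimator itself, so reusing a single object across stages would silently overwrite the earlier fits.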
Comparison with S/T-Learners
| Aspect | S-Learner | T-Learner | X-Learner |
|---|---|---|---|
| Models | 1 | 2 | 4 (2 + 2) |
| Data sharing | Full | None | Cross-group |
| Best when | CATE ≈ 0 | Response functions differ between groups | Imbalanced groups, smooth CATE |
| Rate depends on | Can depend on | ||
| Complexity | Low | Medium | High |
When to Use
Good Scenarios
- When group sizes are imbalanced: Adjustable via weights
- When the CATE is simpler than the response functions: $a_\tau > a_\mu$
- When there are structural assumptions: Linear CATE, smoothness, etc.
Bad Scenarios
- When the two groups are balanced and the problem is simple: The T-learner suffices
- When the CATE is close to 0: The S-learner is more suitable
- When computational cost is an issue: Requires training 4 models
Weight Function Choice
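The choice of $g(x)$:
- Propensity score: $g(x) = \hat{e}(x)$
  - When there are many controls (small $\hat{e}(x)$), rely more on $\hat{\tau}_1$
- Constant: $g(x) = n / (m + n)$, with $n$ treated and $m$ control units
  - Uses the global group ratio
- Optimal choice: depending on conditions, $g \equiv 0$ or $g \equiv 1$ is optimal
  - When one group is overwhelmingly large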
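The effect of each weight choice is easiest to see at a single point. The sketch below assumes the combination weights the control-side estimate by g and the treatment-side estimate by 1 - g, and takes the constant choice to be the overall treated fraction. The `combine` helper and all numbers are made up for illustration.

```python
def combine(tau0, tau1, g):
    """Stage 3 combination: g weights tau0, (1 - g) weights tau1."""
    return g * tau0 + (1 - g) * tau1

# Illustrative values at one point x (all numbers are made up).
tau0_x, tau1_x = 2.0, 3.0      # CATE estimates from the control / treatment side
e_x = 0.1                      # estimated propensity score: few treated units near x
n_treat, n_ctrl = 100, 900     # global group sizes

choices = {
    "propensity": e_x,
    "constant": n_treat / (n_treat + n_ctrl),
    "g = 0": 0.0,              # use tau1 only
    "g = 1": 1.0,              # use tau0 only
}
for name, g in choices.items():
    print(f"{name:>10}: {combine(tau0_x, tau1_x, g):.2f}")
```

With many controls, both the propensity and the constant choice put most of the weight on the treatment-side estimate, and `g = 0` is the limiting case.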
Counterfactual Outcome Estimation
Important: Counterfactual outcome estimation uses only the results of Stage 1:
- $\hat{Y}_i(0) = \hat{\mu}_0(X_i)$ (if treated)
- $\hat{Y}_i(1) = \hat{\mu}_1(X_i)$ (if control)
$\hat{\tau}_0$ and $\hat{\tau}_1$ are used only for CATE estimation.
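As a small illustration of this rule, the sketch below predicts the unobserved potential outcome of each unit from the Stage 1 model of the opposite arm. The functions `mu0_hat` and `mu1_hat` are hypothetical stand-ins for fitted response functions, and the `counterfactual` helper is a name invented for this example.

```python
# Hypothetical Stage 1 fits (stand-ins for learned response functions).
def mu0_hat(x):
    return 1.0 + 2.0 * x        # predicted outcome under control

def mu1_hat(x):
    return 1.5 + 2.5 * x        # predicted outcome under treatment

def counterfactual(x, w):
    """Predict the unobserved potential outcome for a unit with treatment w."""
    # Treated units (w == 1) are missing Y(0); control units are missing Y(1).
    return mu0_hat(x) if w == 1 else mu1_hat(x)

units = [(1.0, 1), (2.0, 0)]    # (covariate, treatment indicator)
print([counterfactual(x, w) for x, w in units])   # [3.0, 6.5]
```

The Stage 2 CATE models never appear here, since they estimate an effect rather than an outcome level.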
Related Concepts
- Meta-learners - The overall framework
- S-Learner - Alternative: single model
- T-Learner - Alternative: two separate models
- R-Learner - Alternative: residualized regression
- DR-Learner - Alternative: doubly robust pseudo-outcome
- CATE - Estimation target
- Propensity Score - Used in the weight function
Implementation
Python (econml):
from econml.metalearners import XLearner
from sklearn.ensemble import RandomForestRegressor
x_learner = XLearner(models=RandomForestRegressor())
x_learner.fit(Y, T, X=X)
cate = x_learner.effect(X_test)
R (causalToolbox):
library(causalToolbox)
x_rf <- X_RF(feat = X, tr = W, yobs = Y)
cate <- EstimateCate(x_rf, X_test)
References
- kunzelMetalearnersEstimatingHeterogeneous2019 - X-learner proposal and theoretical analysis