Tae Hyun Kim (Lowell)

X-Learner

#causal-inference #hte #meta-learner

Definition

The X-Learner is a three-stage meta-learner built on imputed treatment effects. It exploits imbalance between group sizes and the structural properties of the CATE.

Algorithm:

Stage 1: Estimate response functions

$$\hat{\mu}_0(x) = \hat{E}[Y \mid X = x, W = 0]$$
$$\hat{\mu}_1(x) = \hat{E}[Y \mid X = x, W = 1]$$

Stage 2: Compute imputed treatment effects and estimate the CATE

$$\tilde{D}_{1i} := Y_{1i} - \hat{\mu}_0(X_{1i}) \quad \text{(treatment group)}$$
$$\tilde{D}_{0i} := \hat{\mu}_1(X_{0i}) - Y_{0i} \quad \text{(control group)}$$

Estimate $\tau(x)$ in each group:

  • $\hat{\tau}_1(x)$: Learned from the imputed effects of the treatment group
  • $\hat{\tau}_0(x)$: Learned from the imputed effects of the control group

Stage 3: Weighted combination

$$\hat{\tau}_X(x) = g(x)\hat{\tau}_0(x) + (1 - g(x))\hat{\tau}_1(x)$$

where $g(x) \in [0, 1]$ is a weight function (typically the estimated propensity score $\hat{e}(x)$ is used).

Intuitive Understanding

Core idea:

Using the observed outcome and the predicted value from the opposite group, create “imputed” treatment effects, estimate the CATE from the perspective of each group, and then combine them.

Stage 1: Estimate response functions (like T-learner)
           μ̂₀(x), μ̂₁(x)

Stage 2: Impute treatment effects
  Treatment group: D̃₁ᵢ = Y₁ᵢ - μ̂₀(X₁ᵢ)  (observed - predicted control)
  Control group:   D̃₀ᵢ = μ̂₁(X₀ᵢ) - Y₀ᵢ  (predicted treatment - observed)

         Train τ̂₁(x) on D̃₁, τ̂₀(x) on D̃₀

Stage 3: Weighted combination
         τ̂(x) = g(x)·τ̂₀(x) + (1-g(x))·τ̂₁(x)
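To make Stage 2 concrete, here is a tiny numeric sketch. The observations and the two stand-in response functions (`mu0_hat`, `mu1_hat`) are made up for illustration; in practice they would be the fitted Stage 1 models.

```python
import numpy as np

# Hypothetical fitted response functions (stand-ins for Stage 1 models).
def mu0_hat(x):
    return 1.0 + 0.5 * x   # predicted outcome under control

def mu1_hat(x):
    return 2.0 + 1.0 * x   # predicted outcome under treatment

# Made-up observations for each group.
x_treat = np.array([0.0, 2.0])
y_treat = np.array([2.5, 4.0])
x_ctrl = np.array([1.0, 3.0])
y_ctrl = np.array([1.0, 3.0])

# Treatment group: observed outcome minus predicted control outcome.
d_tilde_1 = y_treat - mu0_hat(x_treat)
# Control group: predicted treated outcome minus observed outcome.
d_tilde_0 = mu1_hat(x_ctrl) - y_ctrl

print("imputed effects, treatment group:", d_tilde_1)
print("imputed effects, control group:  ", d_tilde_0)
```

Each group's imputed effects then serve as regression targets for that group's CATE model.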

Why “X”?

  • The information from the treatment group is used to estimate the CATE of the control group, and vice versa
  • Information is transmitted by “crossing”

Key Properties

Exploits CATE Structure

  • When the CATE is simple (e.g., linear), this structure can be leveraged
  • Even if the response function is complex, a fast rate is achieved if the CATE is simple

Handles Imbalanced Groups

  • The weight function $g(x)$ adjusts for group size imbalance
  • Makes greater use of the information from the larger group

Convergence Rates (Conjecture 1)

Under conditions:

  • $\hat{\tau}_0$: $O(m^{-a_\tau} + n^{-a_\mu})$
  • $\hat{\tau}_1$: $O(m^{-a_\mu} + n^{-a_\tau})$

where:

  • $a_\mu$: Smoothness of the response function
  • $a_\tau$: Smoothness of the CATE function
  • $m$: Control group size, $n$: Treatment group size
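The two rate expressions can be compared by plugging in numbers. The sketch below ignores the constants hidden in the big-O notation, and the values of `m`, `n`, `a_mu`, and `a_tau` are illustrative assumptions, so it only indicates which term dominates.

```python
def rate_bounds(m, n, a_mu, a_tau):
    """Conjectured rate terms (constants ignored) for tau_hat_0 and tau_hat_1."""
    bound_tau0 = m ** (-a_tau) + n ** (-a_mu)
    bound_tau1 = m ** (-a_mu) + n ** (-a_tau)
    return bound_tau0, bound_tau1

# Illustrative setting: many controls, few treated, CATE easier than the response.
b0, b1 = rate_bounds(m=10_000, n=100, a_mu=0.25, a_tau=0.5)
print(f"tau_hat_0 term: {b0:.3f}")
print(f"tau_hat_1 term: {b1:.3f}")
```

In this setting the term for $\hat{\tau}_1$ is the smaller one, which is in line with putting most of the weight on $\hat{\tau}_1$ (small $g$) when the control group is much larger than the treatment group.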

Parametric Rate (Theorem 2)

Conditions:

  • Linear CATE: $\tau(x) = x^T \beta$
  • Lipschitz response functions
  • $m \geq c_3 n^{1/a}$

Result: X-learner with $g \equiv 0$ achieves parametric rate $O(n^{-1})$ in treatment group size.

Algorithm Detail

from sklearn.base import clone

def x_learner(X, W, Y, base_learner, propensity_model=None):
    """Fit an X-learner and return a function that predicts the CATE.

    X, W, Y are NumPy arrays; base_learner is an unfitted
    scikit-learn-style regressor that is cloned for every model.
    """
    # Split data by treatment
    X_ctrl, Y_ctrl = X[W == 0], Y[W == 0]
    X_treat, Y_treat = X[W == 1], Y[W == 1]

    # Stage 1: Estimate response functions.
    # clone() gives each model its own estimator; refitting one shared
    # object would overwrite the earlier fit.
    model_0 = clone(base_learner).fit(X_ctrl, Y_ctrl)
    model_1 = clone(base_learner).fit(X_treat, Y_treat)

    # Stage 2: Compute imputed treatment effects
    # For treatment group: observed - predicted control
    D_tilde_1 = Y_treat - model_0.predict(X_treat)
    # For control group: predicted treatment - observed
    D_tilde_0 = model_1.predict(X_ctrl) - Y_ctrl

    # Fit CATE models on imputed effects
    tau_model_1 = clone(base_learner).fit(X_treat, D_tilde_1)
    tau_model_0 = clone(base_learner).fit(X_ctrl, D_tilde_0)

    # Stage 3: Weight function g(x), an estimate of the propensity score
    if propensity_model is None:
        # Constant fallback: overall treated fraction n / (m + n)
        treated_fraction = len(X_treat) / len(X)
        g = lambda x: treated_fraction
    else:
        propensity_model.fit(X, W)
        g = lambda x: propensity_model.predict_proba(x)[:, 1]

    # Weighted combination: g(x) * tau_0(x) + (1 - g(x)) * tau_1(x)
    def predict_cate(X_new):
        tau_0 = tau_model_0.predict(X_new)
        tau_1 = tau_model_1.predict(X_new)
        weights = g(X_new)
        return weights * tau_0 + (1 - weights) * tau_1

    return predict_cate
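As a sanity check, the three stages can be run end to end on synthetic data where the true CATE is known. The sketch below is self-contained and makes several assumptions that are not part of the note: a small NumPy-only `OLS` class stands in for the base learner, the data-generating process is invented, and Stage 3 uses a constant weight equal to the treated fraction (a constant propensity estimate).

```python
import numpy as np

class OLS:
    """Minimal least-squares learner with an intercept (illustrative stand-in)."""
    def fit(self, X, y):
        A = np.column_stack([np.ones(len(X)), X])
        self.coef_, *_ = np.linalg.lstsq(A, y, rcond=None)
        return self

    def predict(self, X):
        return np.column_stack([np.ones(len(X)), X]) @ self.coef_

# Synthetic data (assumed): imbalanced groups and a linear CATE.
rng = np.random.default_rng(0)
n_obs = 2000
X = rng.normal(size=(n_obs, 2))
W = rng.binomial(1, 0.2, size=n_obs)              # roughly 20% treated
tau_true = 1.0 + 2.0 * X[:, 0]
Y = X @ np.array([0.5, -1.0]) + W * tau_true + rng.normal(scale=0.1, size=n_obs)

Xc, Yc = X[W == 0], Y[W == 0]
Xt, Yt = X[W == 1], Y[W == 1]

# Stage 1: response functions
mu0, mu1 = OLS().fit(Xc, Yc), OLS().fit(Xt, Yt)
# Stage 2: imputed effects and one CATE model per group
tau1 = OLS().fit(Xt, Yt - mu0.predict(Xt))
tau0 = OLS().fit(Xc, mu1.predict(Xc) - Yc)
# Stage 3: constant weight equal to the treated fraction
g = W.mean()
tau_hat = g * tau0.predict(X) + (1 - g) * tau1.predict(X)

rmse = np.sqrt(np.mean((tau_hat - tau_true) ** 2))
print(f"RMSE of the CATE estimate: {rmse:.3f}")
```

Because both response functions and the CATE are linear here, the linear base learner is well specified and the error should be small; with misspecified base learners the picture can look very different.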

Comparison with S/T-Learners

| Aspect | S-Learner | T-Learner | X-Learner |
|---|---|---|---|
| Models | 1 | 2 | 4 (2 + 2) |
| Data sharing | Full | None | Cross-group |
| Best when | CATE ≈ 0 | Different $\mu_0, \mu_1$ | Imbalanced groups, smooth CATE |
| Rate depends on | $a_\mu$ | $a_\mu$ | Can depend on $a_\tau$ |
| Complexity | Low | Medium | High |

When to Use

Good Scenarios

  • When group sizes are imbalanced: Adjustable via weights
  • When the CATE is simpler than the response: $a_\tau > a_\mu$
  • When there are structural assumptions: Linear CATE, smoothness, etc.

Bad Scenarios

  • When the two groups are balanced and the response functions are simple: The T-learner suffices
  • When the CATE is close to 0: The S-learner is more suitable
  • When computational cost is an issue: Requires training 4 models

Weight Function Choice

The choice of $g(x)$:

  1. Propensity score: $g(x) = \hat{e}(x)$

    • When there are many controls (small $e(x)$), rely more on $\hat{\tau}_1$, which builds on $\hat{\mu}_0$ fitted on the larger control group
  2. Constant: $g(x) = n / (m + n)$

    • Uses the global treated fraction
  3. Optimal choice: depending on conditions, $g \equiv 0$ or $g \equiv 1$ is optimal

    • When one group is overwhelmingly large
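Whatever weight is chosen, Stage 3 is a pointwise convex combination of the two CATE estimates. The sketch below uses a hypothetical helper `combine` and made-up prediction values to show the two extreme choices and an x-dependent weight.

```python
import numpy as np

def combine(tau0, tau1, g):
    """Stage 3: g * tau_hat_0 + (1 - g) * tau_hat_1, elementwise."""
    g = np.clip(np.asarray(g, dtype=float), 0.0, 1.0)  # keep weights in [0, 1]
    return g * np.asarray(tau0) + (1.0 - g) * np.asarray(tau1)

tau0 = np.array([1.0, 3.0])   # made-up predictions from tau_hat_0
tau1 = np.array([2.0, 5.0])   # made-up predictions from tau_hat_1

only_tau1 = combine(tau0, tau1, 0.0)                     # g = 0 everywhere
only_tau0 = combine(tau0, tau1, 1.0)                     # g = 1 everywhere
pointwise = combine(tau0, tau1, np.array([0.25, 0.5]))   # x-dependent weights

print(only_tau1, only_tau0, pointwise)
```

A scalar `g` broadcasts over all points, so constant and x-dependent weight functions go through the same code path.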

Counterfactual Outcome Estimation

Important: Counterfactual outcome estimation uses only the results of Stage 1:

  • $\hat{Y}_i(0) = \hat{\mu}_0(X_i)$ (if treated)
  • $\hat{Y}_i(1) = \hat{\mu}_1(X_i)$ (if control)

$\hat{\tau}_0, \hat{\tau}_1$ and $g(x)$ are used only for CATE estimation.
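A minimal sketch of this imputation, assuming the Stage 1 models are available as plain prediction functions. The helper name `impute_counterfactuals`, the toy data, and the choice to keep the observed outcome for the arm a unit actually received are assumptions made for illustration.

```python
import numpy as np

def impute_counterfactuals(X, W, Y, mu0_predict, mu1_predict):
    """Return (Y0_hat, Y1_hat): observed outcome where available,
    Stage 1 prediction for the unobserved arm."""
    W = np.asarray(W)
    y0 = np.where(W == 1, mu0_predict(X), Y)  # treated units: predict control outcome
    y1 = np.where(W == 0, mu1_predict(X), Y)  # control units: predict treated outcome
    return y0, y1

# Toy data and stand-in response functions (made up).
X = np.array([0.0, 1.0, 2.0])
W = np.array([1, 0, 1])
Y = np.array([3.0, 1.5, 5.0])
y0, y1 = impute_counterfactuals(
    X, W, Y,
    mu0_predict=lambda x: 1.0 + 0.5 * x,
    mu1_predict=lambda x: 2.0 + 1.0 * x,
)
print(y0, y1)
```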

  • Meta-learners - The overall framework
  • S-Learner - Alternative: single model
  • T-Learner - Alternative: two separate models
  • R-Learner - Alternative: residualized regression
  • DR-Learner - Alternative: doubly robust pseudo-outcome
  • CATE - Estimation target
  • Propensity Score - Used in the weight function

Implementation

Python (econml):

from econml.metalearners import XLearner
from sklearn.ensemble import RandomForestRegressor

x_learner = XLearner(models=RandomForestRegressor())
x_learner.fit(Y, T, X=X)
cate = x_learner.effect(X_test)

R (causalToolbox):

library(causalToolbox)
x_rf <- X_RF(feat = X, tr = W, yobs = Y)
cate <- EstimateCate(x_rf, X_test)

References

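  • Künzel, S. R., Sekhon, J. S., Bickel, P. J., and Yu, B. (2019). Metalearners for estimating heterogeneous treatment effects using machine learning. PNAS 116(10), 4156-4165. (kunzelMetalearnersEstimatingHeterogeneous2019) - X-learner proposal and theoretical analysis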
