Tae Hyun Kim (Lowell)

Causal Forest

#causal-inference #hte #causal-forest

Definition

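Causal Forest is the treatment-effect application of the Generalized Random Forest (GRF) framework of Athey, Tibshirani, and Wager (2019), which builds on the causal forests of Wager and Athey (2018). Its trees choose splits that maximize the heterogeneity of treatment effects across the resulting child nodes.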

Local estimate:

$$\hat{\tau}(x) = \frac{\sum_i \alpha_i(x)\,(Y_i - \hat{\mu}(X_i))\,(T_i - \hat{e}(X_i))}{\sum_i \alpha_i(x)\,(T_i - \hat{e}(X_i))^2}$$

where:

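  • $\alpha_i(x)$: forest weights, i.e. how often training point $i$ falls in the same leaf as $x$ across the trees (normalized by leaf size and averaged over trees)
  • $\hat{\mu}(X_i)$: outcome model, an estimate of $E[Y \mid X = X_i]$
  • $\hat{e}(X_i)$: treatment model, an estimate of $E[T \mid X = X_i]$ (the propensity score when the treatment is binary)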
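To make the formula concrete, here is a minimal NumPy sketch. It assumes each training point's leaf id in each tree is already known; the arrays `train_leaves` and `query_leaves` and both function names are hypothetical, not part of any library. It builds $\alpha_i(x)$ as the leaf-size-normalized share of trees in which $X_i$ lands in the same leaf as $x$, then computes $\hat{\tau}(x)$ as a weighted regression of outcome residuals on treatment residuals.

```python
import numpy as np


def forest_weights(train_leaves: np.ndarray, query_leaves: np.ndarray) -> np.ndarray:
    """alpha_i(x) = (1/B) * sum_b 1{X_i in leaf_b(x)} / |leaf_b(x)|.

    train_leaves: (n, B) leaf id of each training point in each of B trees (hypothetical input)
    query_leaves: (B,) leaf id of the query point x in each tree (hypothetical input)
    """
    n, n_trees = train_leaves.shape
    alpha = np.zeros(n)
    for b in range(n_trees):
        same_leaf = train_leaves[:, b] == query_leaves[b]
        leaf_size = same_leaf.sum()
        if leaf_size > 0:
            alpha[same_leaf] += 1.0 / leaf_size  # each tree spreads total weight 1 over its leaf
    return alpha / n_trees


def local_tau(alpha, y, t, mu_hat, e_hat):
    """Weighted residual-on-residual estimate of tau(x)."""
    y_res = y - mu_hat  # Y_i - mu_hat(X_i)
    t_res = t - e_hat   # T_i - e_hat(X_i)
    return np.sum(alpha * y_res * t_res) / np.sum(alpha * t_res ** 2)


# Toy check: random leaf assignments and outcomes generated with a true effect of 2.0
rng = np.random.default_rng(0)
train_leaves = rng.integers(0, 5, size=(200, 50))
query_leaves = train_leaves[0]  # treat training point 0 as the query x
alpha = forest_weights(train_leaves, query_leaves)

mu_hat = rng.normal(size=200)
e_hat = 0.1 * rng.normal(size=200)
t = rng.normal(size=200)
y = mu_hat + 2.0 * (t - e_hat)
tau_x = local_tau(alpha, y, t, mu_hat, e_hat)
print(round(alpha.sum(), 6), round(tau_x, 6))
```

Because the toy outcome residuals are exactly twice the treatment residuals, any choice of weights returns 2.0; in a real forest the weights decide which neighbors of $x$ contribute.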

Intuitive Understanding

An ordinary random forest splits to predict outcomes. A Causal Forest splits to find subgroups with differing treatment effects.

In pricing, a Causal Forest answers: “Which customer characteristics determine price sensitivity?”
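The difference in splitting objective can be sketched for a single feature. This is a simplified stand-in, not GRF's actual procedure (GRF uses a gradient-based approximation of the same target): given outcome and treatment residuals, it scans thresholds and keeps the one maximizing $\frac{n_L n_R}{n^2}(\hat{\tau}_L - \hat{\tau}_R)^2$, the size-weighted squared gap between the child effects. The function names and toy data are illustrative.

```python
import numpy as np


def residual_tau(y_res, t_res):
    """Residual-on-residual effect estimate within one node."""
    return np.sum(y_res * t_res) / np.sum(t_res ** 2)


def best_effect_split(x, y_res, t_res, min_leaf=20):
    """Threshold on x maximizing n_L * n_R / n^2 * (tau_L - tau_R)^2."""
    order = np.argsort(x)
    xs, ys, ts = x[order], y_res[order], t_res[order]
    n = len(xs)
    best_gain, best_threshold = -np.inf, None
    for k in range(min_leaf, n - min_leaf + 1):  # both children keep >= min_leaf points
        if xs[k - 1] == xs[k]:
            continue  # cannot split between identical feature values
        tau_left = residual_tau(ys[:k], ts[:k])
        tau_right = residual_tau(ys[k:], ts[k:])
        gain = k * (n - k) / n**2 * (tau_left - tau_right) ** 2
        if gain > best_gain:
            best_gain, best_threshold = gain, (xs[k - 1] + xs[k]) / 2
    return best_threshold, best_gain


# Toy data: customers with x < 0.5 are much more price sensitive
rng = np.random.default_rng(0)
x = rng.uniform(0, 1, size=1000)
true_tau = np.where(x < 0.5, -2.0, -0.5)
t_res = rng.normal(size=1000)
y_res = true_tau * t_res + 0.1 * rng.normal(size=1000)
threshold, gain = best_effect_split(x, y_res, t_res)
print(f"split at x = {threshold:.3f}")
```

Note that `y_res` has mean zero in both groups here, so a split chosen to predict `y_res` itself has no systematic signal to find; only the effect-based criterion separates the two segments.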

Key Properties

Honesty

An honest Causal Forest uses different data for deciding the tree structure and for estimation within a leaf:

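  1. Structure decision: learn the splitting rules on one half of the data (in GRF, one half of each tree's subsample)
  2. Effect estimation: estimate the within-leaf effect on the other half

Because the observations used to estimate a leaf's effect played no role in choosing that leaf's boundaries, the estimate is not biased toward spuriously large effects. This is what makes valid confidence intervals possible.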
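A minimal sketch of the two honest steps for a single tree, assuming the residuals $\tilde{Y}_i = Y_i - \hat{\mu}(X_i)$ and $\tilde{T}_i = T_i - \hat{e}(X_i)$ are already available. As a simplification (not EconML's or GRF's implementation), the structure is learned by scikit-learn's `DecisionTreeRegressor` fit to `y_res / t_res` with sample weights `t_res**2`, which minimizes $\sum_i (\tilde{Y}_i - \tau\,\tilde{T}_i)^2$ within each leaf; the leaf effects are then re-estimated only on the held-out half. `honest_tree`, `honest_effect`, and the toy data are hypothetical.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def honest_tree(X, y_res, t_res, max_depth=2, min_samples_leaf=50, seed=0):
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(y_res))
    split_half, est_half = idx[: len(idx) // 2], idx[len(idx) // 2 :]

    # 1. Structure: splits are learned on one half only.
    #    Fitting y_res/t_res with weights t_res**2 minimizes sum (y_res - tau * t_res)^2.
    tree = DecisionTreeRegressor(max_depth=max_depth, min_samples_leaf=min_samples_leaf, random_state=seed)
    tree.fit(X[split_half], y_res[split_half] / t_res[split_half], sample_weight=t_res[split_half] ** 2)

    # 2. Estimation: leaf effects use only the other half.
    leaves = tree.apply(X[est_half])
    y_e, t_e = y_res[est_half], t_res[est_half]
    effects = {}
    for leaf in np.unique(leaves):
        in_leaf = leaves == leaf
        effects[int(leaf)] = np.sum(y_e[in_leaf] * t_e[in_leaf]) / np.sum(t_e[in_leaf] ** 2)
    return tree, effects


def honest_effect(tree, effects, x_row):
    """Look up the honest leaf estimate for a single feature vector."""
    leaf = int(tree.apply(np.asarray(x_row, dtype=float).reshape(1, -1))[0])
    return effects[leaf]


# Toy data: the effect depends only on the first feature
rng = np.random.default_rng(1)
X = rng.uniform(0, 1, size=(2000, 2))
true_tau = np.where(X[:, 0] < 0.5, -2.0, -0.5)
t_res = rng.normal(size=2000)
y_res = true_tau * t_res + 0.1 * rng.normal(size=2000)

tree, effects = honest_tree(X, y_res, t_res)
print(honest_effect(tree, effects, [0.1, 0.5]), honest_effect(tree, effects, [0.9, 0.5]))
```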

Asymptotic Normality

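Under regularity conditions (Wager and Athey, 2018; Athey, Tibshirani, and Wager, 2019), the estimate is asymptotically Gaussian and centered on the true effect:

$$\frac{\hat{\tau}(x) - \tau(x)}{\sigma_n(x)} \xrightarrow{d} N(0, 1)$$

where $\sigma_n^2(x)$ is the variance of $\hat{\tau}(x)$; it shrinks more slowly than the parametric rate $1/n$. Because $\sigma_n(x)$ can be estimated consistently (GRF uses the bootstrap of little bags, Wager and Athey the infinitesimal jackknife), intervals of the form $\hat{\tau}(x) \pm 1.96\,\hat{\sigma}_n(x)$ and the corresponding hypothesis tests are asymptotically valid.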
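Once a point estimate and its standard error are available, the normal approximation turns them into an interval and a test of $H_0: \tau(x) = 0$. A minimal standard-library sketch; `normal_ci` and the numbers in the usage line are illustrative only.

```python
from statistics import NormalDist


def normal_ci(tau_hat: float, se: float, alpha: float = 0.05):
    """Two-sided (1 - alpha) interval and p-value for H0: tau = 0."""
    z = NormalDist().inv_cdf(1 - alpha / 2)  # about 1.96 for alpha = 0.05
    lower, upper = tau_hat - z * se, tau_hat + z * se
    p_value = 2 * NormalDist().cdf(-abs(tau_hat) / se)
    return lower, upper, p_value


lower, upper, p = normal_ci(-1.2, 0.5)
print(f"[{lower:.3f}, {upper:.3f}], p = {p:.4f}")
```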

Combination with DML

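EconML's CausalForestDML combines the Causal Forest with Double Machine Learning (DML). This lets you:

  • fit the nuisance models (outcome and treatment) with any flexible ML regressor
  • avoid overfitting bias in the residuals via cross-fitting, so each nuisance prediction comes from a model that did not see that observation
  • handle continuous treatments such as price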
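The cross-fitting step can be sketched with scikit-learn: `cross_val_predict` returns, for each observation, a prediction from a model trained on the other folds, so the residuals are out-of-fold. As a simplification, this sketch then fits a single constant effect by residual-on-residual regression; CausalForestDML instead fits a causal forest to these residuals to estimate $\tau(x)$. `cross_fit_residuals` and the simulated data are hypothetical.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import cross_val_predict


def cross_fit_residuals(Z, y, t, cv=5, seed=0):
    """Out-of-fold residuals y - E[y|Z] and t - E[t|Z]."""
    y_hat = cross_val_predict(GradientBoostingRegressor(random_state=seed), Z, y, cv=cv)
    t_hat = cross_val_predict(GradientBoostingRegressor(random_state=seed), Z, t, cv=cv)
    return y - y_hat, t - t_hat


# Simulated confounding: w raises both price (t) and demand (y); true effect = -1.5
rng = np.random.default_rng(0)
n = 1000
w = rng.normal(size=n)
t = w + rng.normal(size=n)
y = -1.5 * t + 2.0 * w + rng.normal(size=n)

naive = np.polyfit(t, y, 1)[0]  # slope of y on t, ignoring the confounder w
y_res, t_res = cross_fit_residuals(w.reshape(-1, 1), y, t)
dml = np.sum(y_res * t_res) / np.sum(t_res ** 2)
print(f"naive slope: {naive:.2f}, cross-fitted DML: {dml:.2f}")
```

In this simulation the naive slope converges to Cov(t, y) / Var(t) = -1/2 rather than -1.5, because `w` pushes price and demand up together; residualizing both on `w` removes that path.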

Example

Code Example

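```python
import numpy as np
from econml.dml import CausalForestDML
from sklearn.ensemble import GradientBoostingRegressor

# Assumes `data` is a pandas DataFrame with 'quantity' and 'price' columns,
# `heterogeneity_vars` lists the columns that may modify the effect (X), and
# `confounders` lists the additional control columns (W).

forest_dml = CausalForestDML(
    model_y=GradientBoostingRegressor(n_estimators=200),  # outcome nuisance model
    model_t=GradientBoostingRegressor(n_estimators=200),  # treatment nuisance model
    discrete_treatment=False,  # price is a continuous treatment
    n_estimators=1000,         # trees in the final causal forest
    min_samples_leaf=20,
    honest=True,               # separate samples for splitting and leaf estimation
)

forest_dml.fit(
    Y=np.log1p(data['quantity']),  # log quantity (log1p tolerates zero sales)
    T=np.log(data['price']),       # log price, so the effect reads as a price elasticity
    X=data[heterogeneity_vars],
    W=data[confounders],
)

# Per-customer elasticities and 95% confidence intervals
individual_elasticities = forest_dml.effect(data[heterogeneity_vars])
lower, upper = forest_dml.effect_interval(data[heterogeneity_vars], alpha=0.05)

print(f"Mean elasticity: {individual_elasticities.mean():.3f}")
print(f"Elasticity range: [{individual_elasticities.min():.3f}, {individual_elasticities.max():.3f}]")
```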

Segment Discovery

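Because the forest yields an elasticity estimate for every customer, segments can be discovered by clustering those estimates together with the customer features:

```python
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

data['elasticity'] = forest_dml.effect(data[heterogeneity_vars])

# Standardize so elasticity and large-scale features (e.g. income) carry comparable weight
features = StandardScaler().fit_transform(data[['elasticity'] + list(heterogeneity_vars)])

kmeans = KMeans(n_clusters=4, n_init=10, random_state=42)
data['segment'] = kmeans.fit_predict(features)

# Assumes 'income' and 'age' are columns of `data`
segment_profile = data.groupby('segment').agg({
    'elasticity': ['mean', 'std'],
    'income': 'mean',
    'age': 'mean',
})
```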
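An alternative that yields readable segment rules instead of cluster ids is to fit a shallow regression tree to the estimated elasticities (EconML ships a related tool, `SingleTreeCateInterpreter`). The sketch below uses scikit-learn only; `elasticity_rules` and the simulated `income`/`age` data are hypothetical stand-ins for `forest_dml.effect(...)` output.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeRegressor, export_text


def elasticity_rules(features: pd.DataFrame, elasticity, max_depth=2, min_samples_leaf=50):
    """Fit a shallow tree to per-customer elasticities and return it with its text rules."""
    tree = DecisionTreeRegressor(max_depth=max_depth, min_samples_leaf=min_samples_leaf, random_state=0)
    tree.fit(features, elasticity)
    return tree, export_text(tree, feature_names=list(features.columns))


# Simulated elasticities: lower-income customers are more price sensitive
rng = np.random.default_rng(0)
features = pd.DataFrame({
    "income": rng.uniform(20_000, 120_000, size=1000),
    "age": rng.uniform(18, 80, size=1000),
})
elasticity = np.where(features["income"] < 50_000, -2.0, -0.8) + 0.05 * rng.normal(size=1000)

tree, rules = elasticity_rules(features, elasticity)
print(rules)
```

Each leaf of the printed tree is a segment defined by explicit thresholds, which is often easier to act on in pricing than a KMeans label.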

Feature Importance

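```python
# Higher values mean the forest's splits on that feature created more
# treatment-effect heterogeneity; the order matches the X columns (heterogeneity_vars).
importance = forest_dml.feature_importances_
for feat, imp in zip(heterogeneity_vars, importance):
    print(f"{feat}: {imp:.3f}")
```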

References

  • Athey, S., Tibshirani, J., & Wager, S. (2019). “Generalized Random Forests.” Annals of Statistics.
  • Wager, S., & Athey, S. (2018). “Estimation and Inference of Heterogeneous Treatment Effects using Random Forests.” Journal of the American Statistical Association.
  • Comprehensive Personalized Pricing Guide, Part III, §9
