Causal Forest
Definition
Causal Forest is the causal-inference application of the Generalized Random Forest (GRF) framework proposed by Athey, Tibshirani, and Wager (2019). Its trees choose splits that maximize the heterogeneity of treatment effects across the resulting child nodes.
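Local estimate:
$$\hat{\tau}(x) = \frac{\sum_{i=1}^{n} \alpha_i(x)\,\bigl(Y_i - \hat{m}(X_i)\bigr)\bigl(T_i - \hat{e}(X_i)\bigr)}{\sum_{i=1}^{n} \alpha_i(x)\,\bigl(T_i - \hat{e}(X_i)\bigr)^2}$$
where:
- $\alpha_i(x)$: weights computed from the tree ensemble
- $\hat{m}(X_i)$: outcome prediction model
- $\hat{e}(X_i)$: propensity score estimate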
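To make the local estimate concrete, here is a minimal sketch that assumes the standard GRF residual-on-residual form: a weighted regression of outcome residuals on treatment residuals, with the forest weights as regression weights. The function name `local_effect` and the toy data are illustrative assumptions. The weights here are random stand-ins, not weights from a fitted forest.

```python
import numpy as np

def local_effect(alpha, y, t, m_hat, e_hat):
    """Weighted residual-on-residual estimate of the effect at one query point."""
    y_res = y - m_hat   # outcome residuals
    t_res = t - e_hat   # treatment residuals
    return np.sum(alpha * t_res * y_res) / np.sum(alpha * t_res ** 2)

# Toy data (hypothetical) with a known, noise-free effect of 2.0
rng = np.random.default_rng(0)
n = 200
e_hat = rng.uniform(0.2, 0.8, size=n)      # assumed propensity estimates
t = rng.binomial(1, e_hat).astype(float)   # binary treatment
m_hat = rng.normal(size=n)                 # assumed outcome predictions
y = m_hat + 2.0 * (t - e_hat)              # outcome built so the effect is 2.0
alpha = rng.dirichlet(np.ones(n))          # stand-in weights, non-negative, sum to 1

tau_hat = local_effect(alpha, y, t, m_hat, e_hat)
print(f"local effect estimate: {tau_hat:.3f}")
```

In a real forest, the weight of observation i at x reflects how often i falls in the same leaf as x across trees, so observations similar to x dominate the estimate.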
Intuitive Understanding
An ordinary random forest splits to predict outcomes. A Causal Forest splits to find subgroups with differing treatment effects.
In pricing, a Causal Forest answers: “Which customer characteristics determine price sensitivity?”
Key Properties
Honesty
An honest Causal Forest uses different data for deciding the tree structure and for estimation within a leaf:
- Structure decision: learn the splitting rules on one half of the data
- Effect estimation: estimate the within-leaf effect on the other half
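Because the within-leaf estimates are not computed from the same observations that chose the splits, they are not biased by the split search. This is what makes valid confidence intervals possible.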
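The two-step procedure above can be sketched with a single tree. This is a simplified illustration, not the GRF algorithm. It assumes a randomized binary treatment with probability 0.5 and uses an ordinary regression tree on a transformed outcome as a stand-in for a causal splitting rule. All data and names (`honest_effect`, `leaf_effect`) are hypothetical.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
n = 4000
X = rng.uniform(-1, 1, size=(n, 2))
T = rng.binomial(1, 0.5, size=n)          # randomized binary treatment
tau = np.where(X[:, 0] > 0, 2.0, 0.0)     # true effect depends on the sign of X[:, 0]
Y = tau * T + rng.normal(scale=0.5, size=n)

# Randomly split the sample into a structure half and an estimation half.
idx = rng.permutation(n)
struct, est = idx[: n // 2], idx[n // 2:]

# Structure half: learn the splitting rules. With P(T=1) = 0.5, the transformed
# outcome 4 * Y * (T - 0.5) has conditional mean equal to the treatment effect.
y_star = 4.0 * Y[struct] * (T[struct] - 0.5)
tree = DecisionTreeRegressor(max_depth=2, min_samples_leaf=100, random_state=0)
tree.fit(X[struct], y_star)

# Estimation half: treated-minus-control difference in mean outcome per leaf.
leaf_of_est = tree.apply(X[est])
leaf_effect = {}
for leaf in np.unique(leaf_of_est):
    in_leaf = leaf_of_est == leaf
    y_leaf, t_leaf = Y[est][in_leaf], T[est][in_leaf]
    leaf_effect[leaf] = y_leaf[t_leaf == 1].mean() - y_leaf[t_leaf == 0].mean()

def honest_effect(X_new):
    """Look up the honest leaf estimate for each query point."""
    return np.array([leaf_effect[leaf] for leaf in tree.apply(X_new)])

effects = honest_effect(np.array([[0.5, 0.0], [-0.5, 0.0]]))
print(effects)
```

The estimation half never influences where the splits fall, so each leaf estimate is a plain difference in means on fresh data.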
Asymptotic Normality
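Under honesty and suitable regularity conditions, Causal Forest estimates are asymptotically normal (Wager & Athey, 2018). As a result, pointwise confidence intervals and hypothesis tests based on a normal approximation are asymptotically valid.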
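As a small illustration of what asymptotic normality provides, the sketch below builds a normal-approximation interval and a two-sided p-value from a point estimate and its standard error. The numbers (an elasticity of -1.2 with standard error 0.3) are hypothetical, and the helper names are not from any library.

```python
from statistics import NormalDist

def normal_ci(estimate, std_error, alpha=0.05):
    """Two-sided (1 - alpha) interval under a normal approximation."""
    z = NormalDist().inv_cdf(1 - alpha / 2)
    return estimate - z * std_error, estimate + z * std_error

def two_sided_p_value(estimate, std_error, null_value=0.0):
    """p-value for H0: effect == null_value under a normal approximation."""
    z = abs(estimate - null_value) / std_error
    return 2 * (1 - NormalDist().cdf(z))

# Hypothetical point estimate and standard error for one customer
ci_lower, ci_upper = normal_ci(-1.2, 0.3)
p_value = two_sided_p_value(-1.2, 0.3)
print(f"95% CI: [{ci_lower:.3f}, {ci_upper:.3f}], p = {p_value:.5f}")
```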
Combination with DML
CausalForestDML (from EconML) combines a Causal Forest with Double Machine Learning (DML) to:
- use flexible ML for the nuisance models (outcome, treatment)
- prevent overfitting via cross-fitting
- support continuous treatments (price)
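The residualization idea behind the points above can be sketched with scikit-learn alone. This is a minimal sketch assuming a constant effect and synthetic data. It shows only the cross-fitted partialling-out step, and it is not a description of EconML's internals. A forest would replace the single coefficient at the end with a locally weighted one.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import cross_val_predict

# Synthetic data (hypothetical): W confounds both treatment and outcome.
rng = np.random.default_rng(0)
n = 1000
W = rng.normal(size=(n, 3))
T = W[:, 0] + rng.normal(scale=0.5, size=n)                    # continuous treatment
Y = -1.5 * T + 2.0 * W[:, 0] + rng.normal(scale=0.5, size=n)   # true effect is -1.5

# Cross-fitting: each observation's nuisance prediction comes from a model
# trained on the other folds, never on the observation itself.
y_hat = cross_val_predict(GradientBoostingRegressor(random_state=0), W, Y, cv=5)
t_hat = cross_val_predict(GradientBoostingRegressor(random_state=0), W, T, cv=5)

# Residual-on-residual regression recovers the treatment effect.
y_res, t_res = Y - y_hat, T - t_hat
theta_hat = np.sum(t_res * y_res) / np.sum(t_res ** 2)
print(f"estimated effect: {theta_hat:.3f}")
```

A naive regression of Y on T alone would be biased here because W[:, 0] drives both variables. Removing the part of each that W predicts leaves variation in T that is unrelated to the confounders.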
Example
Code Example
import numpy as np
from econml.dml import CausalForestDML
from sklearn.ensemble import GradientBoostingRegressor

# Assumes `data` is a pandas DataFrame, and `heterogeneity_vars` and
# `confounders` are lists of its column names.
forest_dml = CausalForestDML(
    model_y=GradientBoostingRegressor(n_estimators=200),
    model_t=GradientBoostingRegressor(n_estimators=200),
    discrete_treatment=False,  # continuous price
    n_estimators=1000,
    min_samples_leaf=20,
    honest=True
)
forest_dml.fit(
    Y=np.log1p(data['quantity']),  # log quantity
    T=np.log(data['price']),       # log price → elasticity
    X=data[heterogeneity_vars],
    W=data[confounders]
)
# Individual-level elasticities with 95% confidence intervals
individual_elasticities = forest_dml.effect(data[heterogeneity_vars])
lower, upper = forest_dml.effect_interval(data[heterogeneity_vars], alpha=0.05)
print(f"Mean elasticity: {individual_elasticities.mean():.3f}")
print(f"Elasticity range: [{individual_elasticities.min():.3f}, {individual_elasticities.max():.3f}]")
Segment Discovery
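The estimated individual elasticities can be clustered to discover segments. The features are standardized first so that no single variable dominates the K-means distances:
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

data['elasticity'] = forest_dml.effect(data[heterogeneity_vars])
# Cluster on the elasticity together with the heterogeneity variables
features = StandardScaler().fit_transform(
    np.column_stack([data['elasticity'], data[heterogeneity_vars]])
)
kmeans = KMeans(n_clusters=4, random_state=42, n_init=10)
data['segment'] = kmeans.fit_predict(features)
# Profile each segment (assumes 'income' and 'age' are columns of `data`)
segment_profile = data.groupby('segment').agg({
    'elasticity': ['mean', 'std'],
    'income': 'mean',
    'age': 'mean'
})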
Feature Importance
# Importance of each heterogeneity variable in the forest's splits
importance = forest_dml.feature_importances_
for feat, imp in zip(heterogeneity_vars, importance):
    print(f"{feat}: {imp:.3f}")
Related Concepts
- CATE - estimation target
- Double-Debiased ML - theoretical foundation
- Cross-fitting - overfitting prevention
- Meta-learners - alternative approach
- Policy Trees - interpretable policy learning
References
- Athey, S., Tibshirani, J., & Wager, S. (2019). “Generalized Random Forests.” Annals of Statistics.
- Wager, S., & Athey, S. (2018). “Estimation and Inference of Heterogeneous Treatment Effects using Random Forests.” Journal of the American Statistical Association.
- Comprehensive Personalized Pricing Guide, Part III, §9