Policy Trees
Definition
Policy Trees, proposed by Athey & Wager (2021), are an interpretable policy-learning method.
They take individual-level treatment-effect estimates as input and learn an explicit policy rule in the form of a decision tree.
By expressing the policy as a tree, they can explain “why is this action taken for this customer?”
Intuitive Understanding
Causal Forest estimates individual-level effects, but its output is a “black box.”
A policy tree instead provides explicit rules, such as “if income is at least 50k and loyalty is at least 0.7, use the high-price tier; otherwise, if age is under 30, use the low-price tier.”
This is easy to explain to business stakeholders and makes the rule easier to audit for regulatory compliance.
Key Properties
Relationship to Causal Forest
- Causal Forest: estimates $\hat{\tau}(x)$ (individual-level effect)
- Policy Tree: learns $\pi(x)$ (policy rule) based on $\hat{\tau}(x)$
Objective Function
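A policy tree maximizes the policy value, the average gain in outcome when treatment is assigned according to the rule. The value is highest when the rule manages to: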
- Assign treatment to people with a positive effect, and
- Withhold treatment from people with a negative effect.
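For a binary treatment, one common way to write the value of a rule $\pi$ relative to treating nobody is $V(\pi) = \mathbb{E}[\pi(X)\,\tau(X)]$ with $\pi(X) \in \{0, 1\}$. This notation is an assumption here, since the section does not spell out the formula. The sketch below evaluates that quantity for a few hand-made rules; the effect values and function names are purely illustrative.

```python
from statistics import mean

def policy_value(treat, effects):
    """Average gain per person: effects of the treated, averaged over everyone."""
    return mean(t * tau for t, tau in zip(treat, effects))

# Hypothetical individual-level effects (illustrative numbers only).
effects = [0.8, -0.3, 0.5, -0.6, 0.1]

treat_all = [1] * len(effects)
treat_none = [0] * len(effects)
# Rule described by the two bullets: treat exactly when the effect is positive.
treat_positive = [1 if tau > 0 else 0 for tau in effects]

value_all = policy_value(treat_all, effects)
value_none = policy_value(treat_none, effects)
value_positive = policy_value(treat_positive, effects)
```

With these numbers, treating only the positive-effect individuals gives a higher value than treating everyone or no one, which is the behavior the objective rewards.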
Interpretability vs. Performance Trade-off
| Method | Interpretability | Performance |
|---|---|---|
| Policy Tree (depth=2) | Very high | Medium |
| Policy Tree (depth=5) | Medium | High |
| Causal Forest used directly | Low | Very high |
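The trade-off in the table can be seen on a toy example: a depth-1 rule (one threshold on one feature) can only approximate the unrestricted rule that treats everyone with a positive effect. The sketch below is a brute-force search over single-threshold rules in plain Python. It is not how EconML searches for splits, and the synthetic data and helper names are made up for illustration.

```python
import random

random.seed(0)

# Synthetic data: two features in [0, 1]; the effect depends on both (illustrative).
n = 400
X = [(random.random(), random.random()) for _ in range(n)]
tau = [x0 + x1 - 1.0 for x0, x1 in X]  # positive only when x0 + x1 > 1

def policy_value(treat, effects):
    """Average gain per person under a 0/1 treatment assignment."""
    return sum(t * e for t, e in zip(treat, effects)) / len(effects)

def best_stump(X, effects):
    """Search all rules 'treat if feature j > c' and 'treat if feature j <= c'."""
    best = (0.0, None)  # treating nobody has value 0
    for j in range(len(X[0])):
        for c in sorted({x[j] for x in X}):
            for treat_above in (True, False):
                treat = [int((x[j] > c) == treat_above) for x in X]
                value = policy_value(treat, effects)
                if value > best[0]:
                    best = (value, (j, c, treat_above))
    return best

stump_value, stump_rule = best_stump(X, tau)
# Unrestricted benchmark: treat exactly the individuals with a positive effect.
oracle_value = policy_value([int(e > 0) for e in tau], tau)
treat_all_value = policy_value([1] * n, tau)
```

The single-threshold rule is trivially readable but cannot capture the diagonal boundary `x0 + x1 > 1`, so its value cannot exceed that of the unrestricted benchmark. Deeper trees narrow that gap at the cost of longer rules.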
Example
EconML Implementation
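import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from econml.dml import CausalForestDML
from econml.policy import PolicyTree

# Step 1: estimate the CATE (T is assumed here to be a binary 0/1 treatment)
forest = CausalForestDML(
    model_y=GradientBoostingRegressor(),
    model_t=GradientBoostingRegressor(),
    n_estimators=500
)
forest.fit(Y, T, X=X, W=W)

# CATE estimates, one per row of X
cate_estimates = np.ravel(forest.effect(X))

# Step 2: train the Policy Tree
# PolicyTree.fit expects one reward column per treatment option:
# column 0 = no treatment (reward 0), column 1 = treatment (reward = CATE)
rewards = np.column_stack([np.zeros_like(cate_estimates), cate_estimates])
policy_tree = PolicyTree(
    max_depth=3,          # depth limit -> interpretability
    min_samples_leaf=100  # minimum leaf size
)
policy_tree.fit(X, rewards)

# Print the learned rules as Graphviz DOT source
print(policy_tree.export_graphviz(feature_names=X.columns.tolist()))
Example of the resulting rules for a three-tier pricing policy, written out as an indented text tree: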
|--- income <= 50000.00
| |--- age <= 30.00
| | |--- class: low_price
| |--- age > 30.00
| | |--- class: medium_price
|--- income > 50000.00
| |--- loyalty <= 0.70
| | |--- class: medium_price
| |--- loyalty > 0.70
| | |--- class: high_price
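Because the tree is just nested threshold checks, it can be transcribed directly into code for review or deployment. The function below is a hand-written equivalent of the example tree above, not library output; the function name and the sample customers are hypothetical.

```python
def recommend_price_tier(income, age, loyalty):
    """Hand-written equivalent of the example tree (same thresholds, same leaves)."""
    if income <= 50000:
        return "low_price" if age <= 30 else "medium_price"
    return "medium_price" if loyalty <= 0.70 else "high_price"

# Hypothetical customers, for illustration only.
customers = [
    {"income": 42000, "age": 27, "loyalty": 0.9},
    {"income": 42000, "age": 45, "loyalty": 0.2},
    {"income": 80000, "age": 35, "loyalty": 0.5},
    {"income": 80000, "age": 35, "loyalty": 0.8},
]
tiers = [recommend_price_tier(**c) for c in customers]
```

Each recommendation can be traced to at most two comparisons, which is what makes the "why this action for this customer?" question answerable.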
Multi-Level Pricing Policy
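import numpy as np
import pandas as pd
from econml.dml import CausalForestDML
from econml.policy import PolicyTree

# Estimate the effect of each price level
price_levels = [10, 15, 20, 25, 30]
effects = {}
for price in price_levels:
    # Estimate each price's effect with a binarized treatment (this price vs. all others)
    T_binary = (T == price).astype(int)
    model = CausalForestDML(...)  # configuration elided in the source
    model.fit(Y, T_binary, X=X, W=W)
    effects[price] = np.ravel(model.effect(X))

# Unconstrained benchmark: the best price for each individual
effects_df = pd.DataFrame(effects)
best_prices = effects_df.idxmax(axis=1)

# Learn the rule with a Policy Tree
# fit takes the reward matrix (one column per price level), not the argmax labels
policy_tree = PolicyTree(max_depth=3)
policy_tree.fit(X, effects_df.to_numpy())

# predict returns a column index; map it back to a price
recommended_prices = np.asarray(price_levels)[policy_tree.predict(X)]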
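The per-individual step above (pick the price with the largest estimated effect) can be sketched without any libraries, which also shows the two extremes a policy tree sits between: one price for everyone and a separate price for each customer. The reward numbers and helper names below are made up for illustration.

```python
price_levels = [10, 15, 20]
# Hypothetical estimated rewards: one row per customer, one column per price level.
rewards = [
    [1.2, 0.9, 0.1],
    [0.3, 0.8, 0.7],
    [0.0, 0.4, 1.1],
]

def mean_reward(rewards, chosen_columns):
    """Average reward when customer i is assigned column chosen_columns[i]."""
    return sum(row[j] for row, j in zip(rewards, chosen_columns)) / len(rewards)

# Fully personalized: row-wise argmax over price levels.
best_columns = [max(range(len(row)), key=row.__getitem__) for row in rewards]
personalized_prices = [price_levels[j] for j in best_columns]
personalized_value = mean_reward(rewards, best_columns)

# Single price for everyone: the column with the highest average reward.
best_single = max(
    range(len(price_levels)),
    key=lambda j: mean_reward(rewards, [j] * len(rewards)),
)
single_value = mean_reward(rewards, [best_single] * len(rewards))
```

A depth-limited tree groups customers into a handful of leaves, so its value falls somewhere between `single_value` and `personalized_value`.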
Policy Evaluation
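import numpy as np

def evaluate_policy(policy_tree, X, Y, T, true_effects):
    """
    Evaluate the value of a learned policy.

    Y and T are unused here; they would be needed for an AIPW-style estimate.
    """
    # Policy recommendation (1 = treat, 0 = do not treat)
    recommended_treatment = np.asarray(policy_tree.predict(X))
    true_effects = np.asarray(true_effects)

    # Expected gain per person when following the policy
    # (true effects are counterfactual and unobserved in practice, so use AIPW or similar)
    policy_value = np.mean(true_effects * (recommended_treatment == 1))

    # Improvement over a random policy that treats 50% of people
    random_value = np.mean(true_effects) * 0.5
    lift = (policy_value - random_value) / abs(random_value)

    return {
        'policy_value': policy_value,
        'random_value': random_value,
        'lift': lift
    }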
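True effects are never observed in real data, so the value of a learned rule has to be estimated. One option is the AIPW (doubly robust) score that the code comment above points to. The sketch below assumes a binary treatment, propensities `e`, and outcome-model predictions `mu0` and `mu1`. All names are illustrative and the data is simulated, so read it as a sketch under those assumptions rather than a reference implementation.

```python
import random

def aipw_policy_value(policy, Y, T, e, mu0, mu1):
    """Doubly robust estimate of the mean outcome if everyone followed `policy` (0/1 per person)."""
    total = 0.0
    for pi, y, t, e_i, m0, m1 in zip(policy, Y, T, e, mu0, mu1):
        if pi == 1:
            # Treated-arm score: model prediction plus inverse-propensity-weighted residual.
            total += m1 + (t == 1) * (y - m1) / e_i
        else:
            # Control-arm score, weighted by the probability of being untreated.
            total += m0 + (t == 0) * (y - m0) / (1.0 - e_i)
    return total / len(Y)

# Simulated randomized experiment (illustrative): effect is +1 when x > 0.5, else -1.
random.seed(1)
n = 5000
x = [random.random() for _ in range(n)]
e = [0.5] * n
T = [int(random.random() < p) for p in e]
Y = [(1.0 if xi > 0.5 else -1.0) * ti + random.gauss(0, 1) for xi, ti in zip(x, T)]

# Deliberately crude outcome models (all zeros); the estimate remains
# consistent here because the propensities are correct.
mu0 = [0.0] * n
mu1 = [0.0] * n

targeted = [int(xi > 0.5) for xi in x]
value_targeted = aipw_policy_value(targeted, Y, T, e, mu0, mu1)
value_treat_all = aipw_policy_value([1] * n, Y, T, e, mu0, mu1)
value_treat_none = aipw_policy_value([0] * n, Y, T, e, mu0, mu1)
```

In this simulation the targeted rule has a true value of 0.5, while treating everyone or no one has a true value of 0, so the estimates should rank the targeted rule first. Better outcome models would mainly reduce the variance of the estimate.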
Visualization
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(20, 10))
# sklearn's plot_tree expects a scikit-learn estimator, so use the tree's own plot method
policy_tree.plot(
    ax=ax,
    feature_names=X.columns.tolist(),
    treatment_names=['No Treatment', 'Treatment'],
    filled=True,
    rounded=True,
    fontsize=10
)
plt.title('Pricing Policy Tree')
plt.tight_layout()
plt.savefig('policy_tree.png', dpi=150)
Related Concepts
- Causal Forest - CATE estimation (input)
- CATE - estimation target
- Contextual Bandits - online policy learning
- Thompson Sampling - exploration-exploitation balance
References
- Athey, S., & Wager, S. (2021). “Policy Learning with Observational Data.”
- Zhou, Z., Athey, S., & Wager, S. (2023). “Offline Multi-Action Policy Learning.”
- Comprehensive Personalized Pricing Guide, Part IX, §26.1