Tae Hyun Kim (Lowell)

Policy Trees

#decision-making #ope #causal-forest

Definition

Policy Trees, proposed by Athey & Wager (2021), are an interpretable policy-learning method.

They take individual-level treatment-effect estimates as input and learn an explicit policy rule in the form of a shallow decision tree. In Athey & Wager's formulation, these inputs are doubly robust scores.

$$\pi^*(x) = \arg\max_a E[Y(a) \mid X = x]$$

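Here $\pi^*$ is the best policy with no restriction on its form. A policy tree instead searches for the best policy among depth-limited trees. Because the result is a tree, it can answer the question “why is this action taken for this customer?”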
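With a finite set of actions, the optimal rule above reduces to a row-wise argmax once the conditional expected outcomes have been estimated. Here is a minimal NumPy sketch. It assumes a hypothetical matrix `mu_hat` of estimates of $E[Y(a) \mid X = x_i]$, with one row per individual and one column per action:

```python
import numpy as np

# Hypothetical estimates of E[Y(a) | X = x_i]: 4 individuals x 3 actions
mu_hat = np.array([
    [1.0, 2.5, 0.3],
    [0.2, 0.1, 0.9],
    [3.0, 3.0, 1.0],   # tie: argmax returns the first maximizing action
    [0.5, 0.7, 0.6],
])

# Unrestricted optimal policy: the best action for each individual separately
pi_star = np.argmax(mu_hat, axis=1)
print(pi_star)  # [1 2 0 1]
```

A policy tree does not permit this free per-person choice. It must assign a single action to every individual who lands in the same leaf.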

Intuitive Understanding

Causal Forest estimates individual-level effects, but its output is a “black box.”

A policy tree instead provides explicit rules, such as “if income is at least 50k and loyalty is at least 0.7, use the high-price tier; otherwise, if age is under 30, use the low-price tier.”

This is easy to explain to business stakeholders. It also simplifies regulatory-compliance checks, because every rule can be read and audited directly.

Key Properties

Relationship to Causal Forest

  1. Causal Forest: estimates $\hat{\tau}(x)$ (individual-level effect)
  2. Policy Tree: learns $\pi(x)$ (policy rule) based on $\hat{\tau}(x)$

Objective Function

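For a binary treatment, a policy tree maximizes the policy value over trees of a given depth. Here the policy value is the expected gain over treating no one:

$$V(\pi) = E[\tau(X) \cdot \mathbf{1}[\pi(X) = 1]]$$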

  • Assign treatment to people with a positive effect, and
  • Withhold treatment from people with a negative effect.
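The objective can be checked numerically. The sketch below uses made-up effect values `tau` for illustration. It compares three policies: treat everyone, treat no one, and the sign rule $\pi(x) = \mathbf{1}[\tau(x) > 0]$. The sign rule maximizes $V(\pi)$ when the policy's form is unrestricted.

```python
import numpy as np

def policy_value(tau, treat):
    """Sample analogue of V(pi) = E[tau(X) * 1[pi(X) = 1]]."""
    return np.mean(tau * (treat == 1))

# Hypothetical individual effects tau(x_i)
tau = np.array([2.0, -1.0, 0.5, -3.0, 1.5])

v_all = policy_value(tau, np.ones_like(tau, dtype=int))    # treat everyone
v_none = policy_value(tau, np.zeros_like(tau, dtype=int))  # treat no one
v_sign = policy_value(tau, (tau > 0).astype(int))          # treat iff tau > 0

print(v_all, v_none, v_sign)
```

Treating everyone averages the gains and losses together. The sign rule keeps only the positive terms.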

Interpretability vs. Performance Trade-off

| Method | Interpretability | Performance |
|---|---|---|
| Policy Tree (depth=2) | Very high | Medium |
| Policy Tree (depth=5) | Medium | High |
| Causal Forest used directly | Low | Very high |
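To see why depth matters, the sketch below runs an exhaustive search over depth-1 trees (a single split) for a binary treatment. It then compares the result with the unrestricted rule $\mathbf{1}[\tau(x) > 0]$. This is a simplified illustration on synthetic data, not EconML's implementation, and the function name `best_depth1_policy` is hypothetical. Each side of a candidate split is treated only if its summed effect is positive, and the split with the highest sample value of $V(\pi)$ wins.

```python
import numpy as np

def best_depth1_policy(X, tau):
    """Exhaustively search single-split policies for a binary treatment.

    Each side of the split is treated iff its summed effect is positive.
    Returns (feature, threshold, treat_left, treat_right, value), where value
    is the sample analogue of V(pi) = E[tau(X) * 1[pi(X) = 1]].
    """
    n = len(tau)
    best = None
    for j in range(X.shape[1]):
        for thr in np.unique(X[:, j]):
            left = X[:, j] <= thr
            s_left, s_right = tau[left].sum(), tau[~left].sum()
            treat_left, treat_right = s_left > 0, s_right > 0
            value = (s_left * treat_left + s_right * treat_right) / n
            if best is None or value > best[-1]:
                best = (j, thr, treat_left, treat_right, value)
    return best

rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(500, 2))
# Synthetic effects: treatment helps only when both features exceed 0.5
tau = np.where((X[:, 0] > 0.5) & (X[:, 1] > 0.5), 1.0, -0.2)

j, thr, tl, tr, v_tree = best_depth1_policy(X, tau)
v_oracle = np.mean(tau * (tau > 0))  # unrestricted sign rule
print(j, round(thr, 2), tl, tr, round(v_tree, 3), round(v_oracle, 3))
```

The synthetic benefit is confined to one quadrant, and a single split cannot isolate it. The depth-1 value therefore falls short of the unrestricted value, whereas a depth-2 tree could match it. The table above describes the same trade-off.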

Example

EconML Implementation

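import numpy as np
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from econml.dml import CausalForestDML
from econml.policy import PolicyTree

# Y (outcome), T (binary 0/1 treatment), X (pandas DataFrame), W (controls) assumed defined

# Step 1: estimate the CATE
forest = CausalForestDML(
    model_y=GradientBoostingRegressor(),
    model_t=GradientBoostingClassifier(),  # classifier for a discrete treatment
    discrete_treatment=True,
    n_estimators=500
)
forest.fit(Y, T, X=X, W=W)

# CATE estimates tau_hat(x) for T=1 vs T=0
cate_estimates = forest.effect(X)

# Step 2: learn the policy tree
# PolicyTree takes one reward column per action: column 0 = no treatment (reward 0),
# column 1 = treatment (reward tau_hat)
rewards = np.column_stack([np.zeros_like(cate_estimates), cate_estimates])
policy_tree = PolicyTree(
    max_depth=3,           # depth limit -> interpretability
    min_samples_leaf=100   # minimum leaf size
)
policy_tree.fit(X, rewards)

# Inspect the learned rules (returned as a Graphviz DOT string)
print(policy_tree.export_graphviz(feature_names=X.columns.tolist()))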

Example output (illustrative, for a three-tier pricing problem; the binary tree above has only two actions):

|--- income <= 50000.00
|   |--- age <= 30.00
|   |   |--- class: low_price
|   |--- age > 30.00
|   |   |--- class: medium_price
|--- income > 50000.00
|   |--- loyalty <= 0.70
|   |   |--- class: medium_price
|   |--- loyalty > 0.70
|   |   |--- class: high_price
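Each leaf is an explicit rule, so the printed tree can be transcribed directly into application code and reviewed line by line. Here is a sketch of the illustrative tree above as a Python function. The function name and signature are hypothetical.

```python
def pricing_policy(income: float, age: float, loyalty: float) -> str:
    """The illustrative pricing tree above, written out as plain if/else rules."""
    if income <= 50000:
        return "low_price" if age <= 30 else "medium_price"
    return "medium_price" if loyalty <= 0.70 else "high_price"

print(pricing_policy(income=42000, age=25, loyalty=0.9))   # low_price
print(pricing_policy(income=80000, age=45, loyalty=0.85))  # high_price
```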

Multi-Level Pricing Policy

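import numpy as np
from econml.dml import CausalForestDML
from econml.policy import PolicyTree

# Estimate effects for several price levels with one multi-valued treatment model
# (Y should be the quantity to maximize, e.g. profit; T holds the observed price)
price_levels = [10, 15, 20, 25, 30]
baseline = price_levels[0]

model = CausalForestDML(
    model_y=...,               # nuisance models elided, e.g. as in Step 1
    model_t=...,
    discrete_treatment=True,
    categories=price_levels    # first category is the baseline
)
model.fit(Y, T, X=X, W=W)

# Reward matrix: one column per price, effect relative to the baseline price
rewards = np.column_stack(
    [np.zeros(len(X))] +
    [model.effect(X, T0=baseline, T1=p) for p in price_levels[1:]]
)

# Learn the rules from the rewards (not from per-person best-price labels)
policy_tree = PolicyTree(max_depth=3)
policy_tree.fit(X, rewards)

# predict returns column indices; map them back to prices
recommended_price = np.array(price_levels)[policy_tree.predict(X)]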
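Why fit the tree on rewards rather than on each individual's best price? Within a leaf, a welfare-maximizing tree picks the action with the largest total reward. That need not be the most common per-person best action, especially when a few individuals would incur large losses. Here is a toy NumPy example with made-up rewards for a single leaf:

```python
import numpy as np

# Hypothetical rewards within one leaf: rows = individuals, columns = actions
rewards = np.array([
    [1.0, 0.0],
    [1.0, 0.0],
    [-10.0, 0.0],
])

# Majority vote over each individual's best action
per_person_best = rewards.argmax(axis=1)          # [0, 0, 1]
majority_action = np.bincount(per_person_best).argmax()

# Reward-maximizing action for the leaf as a whole
leaf_action = rewards.sum(axis=0).argmax()

print(majority_action, leaf_action)  # 0 1
```

The majority vote picks action 0, even though applying it to the whole leaf loses 8 units of reward relative to action 1.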

Policy Evaluation

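import numpy as np

def evaluate_policy(policy_tree, X, true_effects):
    """
    Evaluate the value of a learned policy.
    Uses the true effects, so this only works on simulated data.
    """
    # Recommended action per individual (1 = treat)
    recommended_treatment = policy_tree.predict(X)

    # Expected gain from following the policy: mean of tau(x) * 1[pi(x) = 1]
    # (with real data the effects are counterfactual, so use AIPW or similar)
    policy_value = np.mean(true_effects * (recommended_treatment == 1))

    # Improvement over a random policy that treats 50% of people
    random_value = np.mean(true_effects) * 0.5
    lift = (policy_value - random_value) / abs(random_value)

    return {
        'policy_value': policy_value,
        'random_value': random_value,
        'lift': lift
    }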
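With observational data the true effects are unknown, which is why the function above notes that AIPW is used in practice. Below is a minimal sketch of the doubly robust (AIPW) score and the resulting estimate of $V(\pi)$. It assumes you already have outcome-model predictions `mu0` and `mu1` and propensity scores `e` for each individual; all names here are hypothetical. The synthetic check plugs in the true conditional means and a known 50% propensity purely to validate the estimator. In practice these nuisances are fitted, usually with cross-fitting.

```python
import numpy as np

def aipw_scores(Y, T, mu0, mu1, e):
    """Doubly robust (AIPW) scores for the effect of T=1 vs T=0."""
    return (mu1 - mu0
            + T * (Y - mu1) / e
            - (1 - T) * (Y - mu0) / (1 - e))

def policy_value_aipw(Y, T, mu0, mu1, e, treat):
    """Estimate V(pi) = E[tau(X) * 1[pi(X) = 1]] by averaging AIPW scores."""
    gamma = aipw_scores(Y, T, mu0, mu1, e)
    return np.mean(gamma * (treat == 1))

# Synthetic check: randomized treatment with known effects
rng = np.random.default_rng(1)
n = 20000
x = rng.normal(size=n)
tau = x                          # true effect: positive iff x > 0
e = np.full(n, 0.5)              # known propensity (randomized experiment)
T = rng.binomial(1, e)
mu0 = np.zeros(n)                # oracle outcome models, for validation only
mu1 = tau
Y = np.where(T == 1, mu1, mu0) + rng.normal(scale=0.5, size=n)

treat = (x > 0).astype(int)      # policy to evaluate
v_hat = policy_value_aipw(Y, T, mu0, mu1, e, treat)
v_true = np.mean(tau * treat)
print(round(v_hat, 3), round(v_true, 3))
```

The score stays consistent if either the outcome models or the propensity model is correct, which is what makes it "doubly robust".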

Visualization

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(20, 10))
# PolicyTree has its own plot method; sklearn's plot_tree expects an sklearn estimator
policy_tree.plot(
    ax=ax,
    feature_names=X.columns.tolist(),
    treatment_names=['No Treatment', 'Treatment'],
    filled=True,
    rounded=True,
    fontsize=10
)
plt.title('Pricing Policy Tree')
plt.tight_layout()
plt.savefig('policy_tree.png', dpi=150)

References

  • Athey, S., & Wager, S. (2021). “Policy Learning with Observational Data.”
  • Zhou, Z., Athey, S., & Wager, S. (2023). “Offline Multi-Action Policy Learning.”
  • Comprehensive Personalized Pricing Guide, Part IX, §26.1
