CFR (Counterfactual Regression)
Definition
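A deep learning method for estimating individual treatment effects. It learns a representation Φ(X) in which the treated and control covariate distributions are balanced, enforced by an IPM (Integral Probability Metric) penalty between the two groups.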
Proposed by Shalit et al. (2017).
Model Architecture
TARNet (Treatment-Agnostic Representation Network) backbone with an IPM penalty on Φ(X); TARNet alone is CFR without the IPM term.
```
X → [Representation Network] → Φ(X)
                                │
                   ┌────────────┼────────────┐
                   ▼            │            ▼
             ┌──────────┐       │      ┌──────────┐
             │ h₀(Φ(X)) │       │      │ h₁(Φ(X)) │
             │ Control  │       │      │ Treated  │
             └──────────┘       │      └──────────┘
                                │
                      [IPM Regularization]
```
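Loss Function
CFR minimizes a reweighted factual loss plus an IPM penalty between the control and treated representation distributions (Shalit et al., 2017):
$$
\min_{h,\Phi}\; \frac{1}{n}\sum_{i=1}^{n} w_i\, L\big(h_{t_i}(\Phi(x_i)),\, y_i\big) \;+\; \alpha\cdot \mathrm{IPM}_G\big(\{\Phi(x_i)\}_{i:t_i=0},\, \{\Phi(x_i)\}_{i:t_i=1}\big) \;+\; \lambda\,\mathcal{R}(h)
$$
Here $t_i \in \{0,1\}$ is the treatment indicator, $w_i = \frac{t_i}{2u} + \frac{1-t_i}{2(1-u)}$ with $u = \frac{1}{n}\sum_i t_i$, $\alpha$ is the IPM weight, and $\lambda\,\mathcal{R}(h)$ is a model-complexity (e.g. L2) penalty on the outcome heads.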
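The CFR objective of Shalit et al. (2017) combines two terms. The first is a reweighted factual loss with weights w_i = t_i/(2u) + (1 - t_i)/(2(1 - u)), where u is the treated fraction. The second is α times an IPM between the control and treated representations. The following is a minimal PyTorch sketch, not the reference implementation. The names `cfr_sample_weights` and `cfr_objective` are hypothetical, and the IPM is passed in as any callable. The complexity term λR(h) is assumed to be handled as weight decay in the optimizer.

```python
import torch


def cfr_sample_weights(t):
    """w_i = t_i / (2u) + (1 - t_i) / (2(1 - u)), with u = mean(t).
    Each treatment group receives half of the total weight."""
    t = t.float()
    u = t.mean()
    return t / (2 * u) + (1 - t) / (2 * (1 - u))


def cfr_objective(y0, y1, phi, t, y, alpha, ipm):
    """Weighted factual MSE + alpha * IPM(control reps, treated reps).

    `ipm` is any callable taking two (n_group, d) tensors (e.g. an MMD or
    Wasserstein estimate). The lambda * R(h) term is assumed to be applied
    by the optimizer as weight decay, so it is not computed here."""
    t_bool = t.bool()
    w = cfr_sample_weights(t)
    # Factual prediction: use the head matching each unit's observed treatment
    y_pred = torch.where(t_bool, y1, y0)
    factual = (w * (y - y_pred) ** 2).mean()
    return factual + alpha * ipm(phi[~t_bool], phi[t_bool])
```

Because the weights sum to n, the reweighted loss stays on the same scale as an unweighted mean. A small treated group is simply not drowned out by a large control group.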
IPM Choices
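Both variants below are instances of an integral probability metric. For a function family $G$ and distributions $p, q$ over representations $r$:

$$
\mathrm{IPM}_G(p, q) = \sup_{g \in G} \left| \int g(r)\,\big(p(r) - q(r)\big)\,dr \right|
$$

MMD takes $G$ to be the unit ball of a reproducing kernel Hilbert space. The Wasserstein-1 distance takes $G$ to be the 1-Lipschitz functions, by Kantorovich-Rubinstein duality.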
1. CFR-MMD (Maximum Mean Discrepancy)
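RBF kernel $k(a, b) = \exp\big(-\lVert a - b\rVert^2 / (2\sigma^2)\big)$. The (biased) squared MMD between treated representations $\{\phi^t_i\}_{i=1}^{n_t}$ and control representations $\{\phi^c_j\}_{j=1}^{n_c}$ is estimated as:
$$
\widehat{\mathrm{MMD}}^2 = \frac{1}{n_t^2}\sum_{i,i'} k(\phi^t_i, \phi^t_{i'}) + \frac{1}{n_c^2}\sum_{j,j'} k(\phi^c_j, \phi^c_{j'}) - \frac{2}{n_t n_c}\sum_{i,j} k(\phi^t_i, \phi^c_j)
$$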
2. CFR-Wasserstein
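Implementation: Shalit et al. (2017) approximate the Wasserstein distance between minibatch representations with an entropy-regularized (Sinkhorn) algorithm. A WGAN-style discriminator (critic) is an alternative way to approximate it.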
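An entropy-regularized approximation of the Wasserstein-1 distance between a control and a treated batch can be sketched as below. This is a minimal version under stated assumptions, not the authors' code. It assumes uniform weights on the samples, a Euclidean cost, and log-domain Sinkhorn iterations for numerical stability. The function name `sinkhorn_wasserstein` and the defaults for `eps` and `n_iters` are hypothetical choices. Gradients flow through the unrolled iterations.

```python
import math

import torch


def sinkhorn_wasserstein(phi_c, phi_t, eps=0.05, n_iters=200):
    """Entropy-regularized approximation of the Wasserstein-1 distance
    between two representation batches (uniform sample weights).
    eps and n_iters are illustrative defaults, not tuned values."""
    # Euclidean cost; clamp before sqrt so gradients stay finite at distance 0
    sq = ((phi_c[:, None, :] - phi_t[None, :, :]) ** 2).sum(-1)
    C = sq.clamp_min(1e-12).sqrt()
    n, m = C.shape
    log_a = torch.full((n,), -math.log(n), dtype=C.dtype, device=C.device)
    log_b = torch.full((m,), -math.log(m), dtype=C.dtype, device=C.device)
    f = torch.zeros_like(log_a)  # dual potential for control points
    g = torch.zeros_like(log_b)  # dual potential for treated points
    for _ in range(n_iters):
        # Alternately match the row (control) and column (treated) marginals
        f = eps * (log_a - torch.logsumexp((g[None, :] - C) / eps, dim=1))
        g = eps * (log_b - torch.logsumexp((f[:, None] - C) / eps, dim=0))
    # Approximate transport plan and its total transport cost
    P = torch.exp((f[:, None] + g[None, :] - C) / eps)
    return (P * C).sum()
```

Smaller `eps` moves the result closer to the exact Wasserstein distance but needs more iterations to converge.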
Theoretical Guarantees
Generalization Bound
Shalit et al. (2017):
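- The expected PEHE is upper-bounded by a quantity built from the factual (observed-outcome) losses in the treated and control groups and the IPM between the treated and control representation distributions
- Minimizing factual loss plus IPM therefore minimizes an upper bound on PEHE, which motivates the CFR objective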
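For reference, the bound (Theorem 1 in Shalit et al., 2017) can be written roughly as follows. Here $\epsilon_F^{t}$ is the factual loss in treatment group $t$ and $p_\Phi^{t}$ is the distribution of representations in group $t$. $B_\Phi$ is a constant that depends on $\Phi$ and the loss, and $\sigma_Y^2$ is a variance term of the outcomes. Consult the paper for the exact conditions:

$$
\epsilon_{\mathrm{PEHE}}(h, \Phi) \le 2\Big(\epsilon_F^{t=0}(h, \Phi) + \epsilon_F^{t=1}(h, \Phi) + B_\Phi\, \mathrm{IPM}_G\big(p_\Phi^{t=1}, p_\Phi^{t=0}\big) - 2\sigma_Y^2\Big)
$$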
Trade-off
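The choice of the IPM weight α is critical. With α = 0, CFR reduces to TARNet (no balancing). A very large α pushes Φ toward representations that are balanced but discard outcome-relevant information, which raises the factual error. Because counterfactual outcomes are never observed, α cannot be tuned directly on PEHE.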
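One workaround follows the spirit of the nearest-neighbour PEHE proxy that Shalit et al. (2017) use for model selection. It imputes each unit's counterfactual with the factual outcome of its nearest neighbour in the other treatment group. The following is a minimal NumPy sketch. It assumes Euclidean matching on raw covariates and a small n, since it builds an n×n distance matrix. The function name `pehe_nn` is hypothetical.

```python
import numpy as np


def pehe_nn(X, t, y, tau_hat):
    """Root-mean-squared nearest-neighbour PEHE proxy (sketch).

    Each unit's effect is imputed as (1 - 2 t_i) * (y_j - y_i), where j is the
    nearest neighbour of unit i (Euclidean, covariate space) in the opposite
    treatment group. Requires both groups to be non-empty."""
    X = np.asarray(X, dtype=float)
    t = np.asarray(t).astype(bool)
    y = np.asarray(y, dtype=float)
    tau_hat = np.asarray(tau_hat, dtype=float)
    # Pairwise squared distances, shape (n, n)
    d = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
    # Forbid matches within the same treatment group (including self)
    d[t[:, None] == t[None, :]] = np.inf
    j = d.argmin(axis=1)
    # Treated units: y_i - y_j; control units: y_j - y_i
    tau_nn = np.where(t, y - y[j], y[j] - y)
    return float(np.sqrt(np.mean((tau_nn - tau_hat) ** 2)))
```

In use, one would train CFR for several candidate α values and keep the model whose predicted effects have the lowest `pehe_nn` on a validation split.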
Advantages and Disadvantages
Advantages
| Advantage | Description |
|---|---|
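| Theoretical guarantee | Generalization bound on PEHE (Shalit et al., 2017) |
| End-to-end | Joint learning of representation + prediction |
| Flexibility | Any differentiable architecture can be used for Φ and the outcome heads |
| Scalability | Handles high-dimensional covariates; trained with minibatch stochastic gradient methods |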
Disadvantages
| Disadvantage | Description |
|---|---|
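| Choice of α | Results are sensitive to the IPM weight, which cannot be tuned on PEHE directly |
| Data requirements | Like most deep learning methods, needs large samples |
| Uncertainty | No built-in confidence intervals for effect estimates |
| IPM computation | Costly per minibatch (quadratic in batch size for MMD), especially Wasserstein, which needs an iterative approximation |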
Implementation
Python (PyTorch)
```python
import torch
import torch.nn as nn


class CFRNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=100, repr_dim=50):
        super().__init__()
        # Representation network Φ, shared by both treatment arms
        self.repr_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, repr_dim),
        )
        # Outcome networks: h0 (control) and h1 (treated)
        self.head_0 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 1),
        )
        self.head_1 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        phi = self.repr_net(x)
        # squeeze(-1) keeps the batch dimension even when the batch size is 1
        y0 = self.head_0(phi).squeeze(-1)
        y1 = self.head_1(phi).squeeze(-1)
        return y0, y1, phi


def mmd_loss(phi_t, phi_c, sigma=1.0):
    """Biased estimate of squared MMD with an RBF kernel."""
    def rbf_kernel(a, b):
        # Pairwise squared Euclidean distances, shape (len(a), len(b))
        sq_dist = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)
        return torch.exp(-sq_dist / (2 * sigma ** 2))

    k_tt = rbf_kernel(phi_t, phi_t).mean()
    k_cc = rbf_kernel(phi_c, phi_c).mean()
    k_tc = rbf_kernel(phi_t, phi_c).mean()
    return k_tt + k_cc - 2 * k_tc


# Training (full-batch for brevity; use minibatches for large data)
def train_cfr(model, X, W, Y, lambda_mmd=1.0, epochs=500, lr=1e-3):
    # lambda_mmd is the weight of the IPM (here MMD) penalty
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # create once
    treated = W == 1
    W = W.float()
    for _ in range(epochs):
        y0, y1, phi = model(X)
        # Factual prediction: pick the head matching the observed treatment
        y_pred = W * y1 + (1 - W) * y0
        loss_factual = ((Y - y_pred) ** 2).mean()
        # MMD regularization between treated and control representations
        loss_mmd = mmd_loss(phi[treated], phi[~treated])
        loss = loss_factual + lambda_mmd * loss_mmd
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model
```
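After training, the two heads give both potential outcomes for every unit. The individual effect is therefore estimated as τ̂(x) = h₁(Φ(x)) - h₀(Φ(x)). On semi-synthetic benchmarks where the true potential outcomes are known, this estimate can be scored with √PEHE and the absolute ATE error. The following is a minimal sketch. The function names `pehe` and `ate_error` are hypothetical, and `y0_hat`, `y1_hat` stand for the two head outputs of a trained model.

```python
import torch


def pehe(y0_hat, y1_hat, mu0, mu1):
    """Root PEHE: RMS error between estimated and true individual effects.
    Needs the true potential outcomes mu0, mu1 (semi-synthetic data only)."""
    tau_hat = y1_hat - y0_hat
    tau_true = mu1 - mu0
    return torch.sqrt(((tau_hat - tau_true) ** 2).mean())


def ate_error(y0_hat, y1_hat, mu0, mu1):
    """Absolute error of the average treatment effect estimate."""
    return ((y1_hat - y0_hat).mean() - (mu1 - mu0).mean()).abs()
```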
Extensions
DragonNet
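TARNet-style shared representation with a third head that predicts the propensity score (plus targeted regularization), used instead of an IPM penalty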
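The three-headed layout can be sketched in PyTorch as follows. This is a simplified, hypothetical version that omits targeted regularization. It trains the outcome heads with MSE and the propensity head with binary cross-entropy on the observed treatment. The class name `DragonNetSketch`, the layer sizes, and the loss weight `beta` are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DragonNetSketch(nn.Module):
    """Shared representation + two outcome heads + a propensity head."""

    def __init__(self, input_dim, repr_dim=200, head_dim=100):
        super().__init__()
        self.repr_net = nn.Sequential(
            nn.Linear(input_dim, repr_dim), nn.ELU(),
            nn.Linear(repr_dim, repr_dim), nn.ELU(),
        )

        def head():
            return nn.Sequential(
                nn.Linear(repr_dim, head_dim), nn.ELU(), nn.Linear(head_dim, 1)
            )

        self.head_0 = head()  # outcome under control
        self.head_1 = head()  # outcome under treatment
        # Propensity head: logit of P(T = 1 | representation)
        self.prop_head = nn.Linear(repr_dim, 1)

    def forward(self, x):
        z = self.repr_net(x)
        y0 = self.head_0(z).squeeze(-1)
        y1 = self.head_1(z).squeeze(-1)
        t_logit = self.prop_head(z).squeeze(-1)
        return y0, y1, t_logit


def dragonnet_loss(y0, y1, t_logit, t, y, beta=1.0):
    """Factual MSE + beta * propensity cross-entropy (no targeted regularization)."""
    t = t.float()
    y_pred = t * y1 + (1 - t) * y0
    outcome = ((y - y_pred) ** 2).mean()
    propensity = F.binary_cross_entropy_with_logits(t_logit, t)
    return outcome + beta * propensity
```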
Perfect Match
Augments each minibatch with propensity-score-matched nearest neighbours from the other treatment group(s)
SITE
Preserves local similarity between units in representation space while balancing treated and control distributions
Related Concepts
- Representation Learning Overview - Unified view of representation learning methods
- CEVAE - VAE-based alternative
- BNN - Balancing Neural Network (Johansson et al., 2016), the predecessor of CFR
- Selection Bias - The problem being addressed
- PEHE - Evaluation metric
Key Papers
- Shalit, U., Johansson, F. D., & Sontag, D. (2017). Estimating individual treatment effect: Generalization bounds and algorithms. ICML.
- Johansson, F. D., Shalit, U., & Sontag, D. (2016). Learning representations for counterfactual inference. ICML.
- yaoSurveyCausalInference2021 - Section 3.5.3