CFR (Counterfactual Regression)
정의
IPM(Integral Probability Metric) 정규화로 균형 잡힌 표현을 학습하는 딥러닝 방법
Shalit et al. (2017)이 제안.
모델 구조
TARNet (Treatment-Agnostic Representation Network)
X → [Representation Network] → Φ(X)
│
┌────────────┼────────────┐
▼ │ ▼
┌──────────┐ │ ┌──────────┐
│ h₀(Φ(X)) │ │ │ h₁(Φ(X)) │
│ Control │ │ │ Treated │
└──────────┘ │ └──────────┘
│
[IPM Regularization]
손실 함수
IPM 선택
1. CFR-MMD (Maximum Mean Discrepancy)
RBF 커널:
2. CFR-Wasserstein
구현: Discriminator로 근사 (WGAN 스타일)
이론적 보장
Generalization Bound
Shalit et al. (2017):
- PEHE 오류가 factual 오류와 IPM으로 bounded
- IPM 최소화의 동기
Trade-off
선택이 중요.
장단점
장점
| 장점 | 설명 |
|---|---|
| 이론적 보장 | Generalization bound |
| End-to-end | 표현 + 예측 동시 학습 |
| 유연성 | 다양한 네트워크 구조 |
| 확장성 | 고차원 데이터 처리 |
단점
| 단점 | 설명 |
|---|---|
| 선택 | 하이퍼파라미터 민감 |
| 대규모 데이터 | 딥러닝 특성 |
| 불확실성 | 신뢰구간 제공 어려움 |
| IPM 계산 | 특히 Wasserstein |
구현
Python (PyTorch)
import torch
import torch.nn as nn
class CFRNet(nn.Module):
def __init__(self, input_dim, hidden_dim=100, repr_dim=50):
super().__init__()
# Representation network
self.repr_net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ELU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ELU(),
nn.Linear(hidden_dim, repr_dim)
)
# Outcome networks
self.head_0 = nn.Sequential(
nn.Linear(repr_dim, hidden_dim),
nn.ELU(),
nn.Linear(hidden_dim, 1)
)
self.head_1 = nn.Sequential(
nn.Linear(repr_dim, hidden_dim),
nn.ELU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, x, w):
phi = self.repr_net(x)
y0 = self.head_0(phi).squeeze()
y1 = self.head_1(phi).squeeze()
return y0, y1, phi
def mmd_loss(phi_t, phi_c, sigma=1.0):
"""RBF kernel MMD"""
def rbf_kernel(x, y):
diff = x.unsqueeze(1) - y.unsqueeze(0)
return torch.exp(-diff.pow(2).sum(-1) / (2 * sigma**2))
k_tt = rbf_kernel(phi_t, phi_t).mean()
k_cc = rbf_kernel(phi_c, phi_c).mean()
k_tc = rbf_kernel(phi_t, phi_c).mean()
return k_tt + k_cc - 2 * k_tc
# Training
def train_cfr(model, X, W, Y, lambda_mmd=1.0):
optimizer = torch.optim.Adam(model.parameters())
y0, y1, phi = model(X, W)
y_pred = W * y1 + (1 - W) * y0
# Factual loss
loss_factual = ((Y - y_pred) ** 2).mean()
# MMD regularization
phi_t = phi[W == 1]
phi_c = phi[W == 0]
loss_mmd = mmd_loss(phi_t, phi_c)
loss = loss_factual + lambda_mmd * loss_mmd
optimizer.zero_grad()
loss.backward()
optimizer.step()
확장
DragonNet
CFR + Propensity score head
Perfect Match
Nearest neighbor in representation space
SITE
Self-supervised 표현 학습 추가
관련 개념
- Representation Learning Overview - 표현 학습 방법 통합
- CEVAE - VAE 기반 대안
- BNN - 간단한 균형 방법
- Selection Bias - 해결 대상
- PEHE - 평가 지표
참고 논문
- Shalit, U., Johansson, F. D., & Sontag, D. (2017). Estimating individual treatment effect: Generalization bounds and algorithms. ICML
- Johansson, F. D., Shalit, U., & Sontag, D. (2016). Learning representations for counterfactual inference. ICML
- yaoSurveyCausalInference2021 - Section 3.5.3