Tae Hyun Kim (Lowell)

CFR (Counterfactual Regression)

Definition

A deep learning method that learns balanced representations of the treated and control groups by regularizing with an IPM (Integral Probability Metric).

$$\mathcal{L} = \mathcal{L}_{factual} + \lambda \cdot \text{IPM}(P^T_\Phi, P^C_\Phi)$$

Proposed by Shalit et al. (2017).


Model Architecture

TARNet (Treatment-Agnostic Representation Network)

X → [Representation Network] → Φ(X)

                     ┌────────────┼────────────┐
                     ▼            │            ▼
              ┌──────────┐        │     ┌──────────┐
              │ h₀(Φ(X)) │        │     │ h₁(Φ(X)) │
              │ Control  │        │     │ Treated  │
              └──────────┘        │     └──────────┘

                     [IPM Regularization]

Loss Function

$$\mathcal{L} = \frac{1}{n}\sum_i (Y_i - h_{W_i}(\Phi(X_i)))^2 + \lambda \cdot \text{IPM}(\hat{P}^T_\Phi, \hat{P}^C_\Phi)$$
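The first term uses only the observed (factual) outcome of each unit: the head $h_{W_i}$ is chosen by the treatment indicator. A minimal sketch in plain Python, with hypothetical toy numbers and a helper name (`factual_loss`) that is not from the original note:

```python
def factual_loss(y, w, y0_hat, y1_hat):
    """Mean squared error on the observed (factual) outcomes only."""
    total = 0.0
    for y_i, w_i, y0_i, y1_i in zip(y, w, y0_hat, y1_hat):
        # h_{W_i}: use the treated head when w_i == 1, otherwise the control head.
        pred = y1_i if w_i == 1 else y0_i
        total += (y_i - pred) ** 2
    return total / len(y)


# Hypothetical toy values, chosen only for illustration.
y = [1.0, 2.0, 3.0, 4.0]
w = [0, 1, 0, 1]
y0_hat = [1.0, 0.0, 2.0, 0.0]  # control-head predictions
y1_hat = [9.0, 2.0, 9.0, 3.0]  # treated-head predictions

loss = factual_loss(y, w, y0_hat, y1_hat)
print(loss)  # 0.5
```

The poor counterfactual predictions (for example `y1_hat[0] = 9.0` for a control unit) do not affect this term at all, which is why the IPM penalty is needed to constrain the representation.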

Choice of IPM

1. CFR-MMD (Maximum Mean Discrepancy)

$$\text{MMD}^2 = \left\|\frac{1}{n_T}\sum_{i: W_i=1} k(\Phi(X_i), \cdot) - \frac{1}{n_C}\sum_{i: W_i=0} k(\Phi(X_i), \cdot)\right\|^2_{\mathcal{H}}$$

RBF kernel:

$$k(x, x') = \exp\left(-\frac{\|x - x'\|^2}{2\sigma^2}\right)$$
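For reference, expanding the squared RKHS norm with the reproducing property $\langle k(x,\cdot), k(x',\cdot)\rangle_{\mathcal{H}} = k(x, x')$ gives a form that can be computed directly from kernel evaluations. This is the biased (V-statistic) estimate, since the $i = j$ terms are included; its three sums correspond to `k_tt`, `k_cc` and `k_tc` in the `mmd_loss` function of the implementation section.

```latex
\widehat{\text{MMD}}^2
  = \frac{1}{n_T^2} \sum_{i,j:\, W_i = W_j = 1} k(\Phi(X_i), \Phi(X_j))
  + \frac{1}{n_C^2} \sum_{i,j:\, W_i = W_j = 0} k(\Phi(X_i), \Phi(X_j))
  - \frac{2}{n_T n_C} \sum_{i:\, W_i = 1} \; \sum_{j:\, W_j = 0} k(\Phi(X_i), \Phi(X_j))
```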

2. CFR-Wasserstein

$$W_1(P^T, P^C) = \sup_{\|f\|_L \leq 1} \left|E_{P^T}[f(\Phi)] - E_{P^C}[f(\Phi)]\right|$$

Implementation: approximated with a discriminator (WGAN-style).
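A learned discriminator only approximates $W_1$, so it helps to have an exact value to compare against on small cases. The sketch below is not the discriminator approach: it uses the fact that, for two one-dimensional empirical distributions with the same number of samples, $W_1$ equals the mean absolute difference between the sorted samples. The function name `wasserstein_1d` and the toy values are assumptions made for illustration.

```python
def wasserstein_1d(xs, ys):
    """Exact W1 between two equal-size 1-D empirical distributions."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    # In 1-D the optimal transport plan matches the sorted samples in order.
    xs_sorted, ys_sorted = sorted(xs), sorted(ys)
    return sum(abs(a - b) for a, b in zip(xs_sorted, ys_sorted)) / len(xs)


# Hypothetical 1-D representations for treated and control units.
phi_t = [0.0, 1.0, 2.0]
phi_c = [1.0, 2.0, 3.0]

w1 = wasserstein_1d(phi_t, phi_c)
print(w1)  # 1.0, since the control sample is the treated sample shifted by 1
```

For representations with more than one dimension there is no such closed form, which is where an approximation such as the discriminator becomes necessary.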


Theoretical Guarantees

Generalization Bound

Shalit et al. (2017):

$$\epsilon_{PEHE} \leq \epsilon_{factual} + \alpha \cdot \text{IPM} + \text{complexity terms}$$
  • The PEHE error is bounded by the factual error plus the IPM term
  • This is the motivation for minimizing the IPM
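PEHE (Precision in Estimation of Heterogeneous Effect) is the mean squared error between the true and the estimated individual treatment effects. A minimal sketch of how it is computed from the two head outputs, assuming the true effects are known (which is only the case on synthetic or semi-synthetic benchmarks); the function name `pehe` and all values are hypothetical:

```python
import math


def pehe(tau_true, tau_hat):
    """Mean squared error between true and estimated individual effects."""
    return sum((t - h) ** 2 for t, h in zip(tau_true, tau_hat)) / len(tau_true)


# Hypothetical head outputs h0(Phi(x)) and h1(Phi(x)) for three units.
y0_hat = [1.0, 2.0, 0.5]
y1_hat = [2.0, 2.0, 2.5]

# Estimated individual treatment effect: difference of the two heads.
tau_hat = [y1 - y0 for y0, y1 in zip(y0_hat, y1_hat)]

# Hypothetical ground-truth effects (unavailable in real observational data).
tau_true = [1.0, 1.0, 1.0]

eps_pehe = pehe(tau_true, tau_hat)  # (0 + 1 + 1) / 3
root_pehe = math.sqrt(eps_pehe)     # often reported in this root form
```

Because the true effects are never observed in practice, PEHE cannot be minimized directly; the bound replaces it with quantities that can be estimated from data.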

Trade-off

$$\lambda \uparrow \implies \text{Balance} \uparrow, \quad \text{Prediction} \downarrow$$

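The choice of $\lambda$ is therefore important: too small a value leaves the two groups imbalanced in representation space, while too large a value trades away predictive accuracy.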


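Pros and Cons

Pros

| Advantage | Description |
|---|---|
| Theoretical guarantee | Generalization bound |
| End-to-end | Representation and prediction are learned jointly |
| Flexibility | Works with a variety of network architectures |
| Scalability | Handles high-dimensional data |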

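Cons

| Disadvantage | Description |
|---|---|
| Choice of $\lambda$ | Results are sensitive to this hyperparameter |
| Large data requirement | Inherent to deep learning |
| Uncertainty | Confidence intervals are hard to provide |
| IPM computation | Costly, especially for Wasserstein |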

Implementation

Python (PyTorch)

import torch
import torch.nn as nn

class CFRNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=100, repr_dim=50):
        super().__init__()
        # Representation network
        self.repr_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, repr_dim)
        )
        # Outcome networks
        self.head_0 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 1)
        )
        self.head_1 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 1)
        )

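    def forward(self, x, w=None):
        # w is accepted for call compatibility but unused: both heads are always evaluated
        phi = self.repr_net(x)
        # squeeze(-1) drops only the output dimension, so a batch of size 1 keeps shape (1,)
        y0 = self.head_0(phi).squeeze(-1)
        y1 = self.head_1(phi).squeeze(-1)
        return y0, y1, phi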

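def mmd_loss(phi_t, phi_c, sigma=1.0):
    """Biased (V-statistic) estimate of squared MMD with an RBF kernel."""
    def rbf_kernel(x, y):
        # Pairwise differences via broadcasting: (n, 1, d) - (1, m, d) -> (n, m, d)
        diff = x.unsqueeze(1) - y.unsqueeze(0)
        return torch.exp(-diff.pow(2).sum(-1) / (2 * sigma**2))

    k_tt = rbf_kernel(phi_t, phi_t).mean()  # treated vs. treated
    k_cc = rbf_kernel(phi_c, phi_c).mean()  # control vs. control
    k_tc = rbf_kernel(phi_t, phi_c).mean()  # treated vs. control

    return k_tt + k_cc - 2 * k_tc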

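# Training (full batch). n_epochs and lr are assumed defaults, not tuned values.
def train_cfr(model, X, W, Y, lambda_mmd=1.0, n_epochs=100, lr=1e-3):
    """W is a 0/1 tensor of shape (n,); Y has shape (n,). Both groups must be non-empty."""
    # Create the optimizer once so Adam's moment estimates persist across steps
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    W = W.float()

    for _ in range(n_epochs):
        y0, y1, phi = model(X, W)
        # Select the factual prediction: treated head if W == 1, control head otherwise
        y_pred = W * y1 + (1 - W) * y0

        # Factual loss
        loss_factual = ((Y - y_pred) ** 2).mean()

        # MMD regularization between treated and control representations
        phi_t = phi[W == 1]
        phi_c = phi[W == 0]
        loss_mmd = mmd_loss(phi_t, phi_c)

        loss = loss_factual + lambda_mmd * loss_mmd

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return model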

Extensions

DragonNet

CFR + Propensity score head

Perfect Match

Nearest neighbor in representation space

SITE

Adds self-supervised representation learning


Related Concepts


References

  • Shalit, U., Johansson, F. D., & Sontag, D. (2017). Estimating individual treatment effect: Generalization bounds and algorithms. ICML
  • Johansson, F. D., Shalit, U., & Sontag, D. (2016). Learning representations for counterfactual inference. ICML
  • yaoSurveyCausalInference2021 - Section 3.5.3
