Tae Hyun Kim (Lowell)

CFR (Counterfactual Regression)

Definition

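A deep learning method for treatment effect estimation. It learns a representation Φ whose distribution is balanced between the treated (T) and control (C) groups, enforced through IPM (Integral Probability Metric) regularization: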

$$\mathcal{L} = \mathcal{L}_{factual} + \lambda \cdot \text{IPM}(P^T_\Phi, P^C_\Phi)$$

Proposed by Shalit et al. (2017).


Model Architecture

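TARNet (Treatment-Agnostic Representation Network)

CFR builds on the TARNet architecture: a shared representation network followed by one outcome head per treatment arm. TARNet itself is the special case λ = 0, without the IPM term.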

X → [Representation Network] → Φ(X)

                     ┌────────────┼────────────┐
                     ▼            │            ▼
              ┌──────────┐        │     ┌──────────┐
              │ h₀(Φ(X)) │        │     │ h₁(Φ(X)) │
              │ Control  │        │     │ Treated  │
              └──────────┘        │     └──────────┘

                     [IPM Regularization]

Loss Function

$$\mathcal{L} = \frac{1}{n}\sum_i (Y_i - h_{W_i}(\Phi(X_i)))^2 + \lambda \cdot \text{IPM}(\hat{P}^T_\Phi, \hat{P}^C_\Phi)$$

IPM Choices

1. CFR-MMD (Maximum Mean Discrepancy)

$$\text{MMD}^2 = \left\|\frac{1}{n_T}\sum_{i: W_i=1} k(\Phi(X_i), \cdot) - \frac{1}{n_C}\sum_{i: W_i=0} k(\Phi(X_i), \cdot)\right\|^2_{\mathcal{H}}$$

RBF kernel:

$$k(x, x') = \exp\left(-\frac{\|x - x'\|^2}{2\sigma^2}\right)$$
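
The RKHS norm above is not computed directly. Expanding the square with the reproducing property $\langle k(x, \cdot), k(x', \cdot) \rangle_{\mathcal{H}} = k(x, x')$ leaves only kernel evaluations. The result is the standard biased (V-statistic) estimator; the shorthand $\phi_i = \Phi(X_i)$ is introduced here only for readability.

```latex
\widehat{\text{MMD}}^2
  = \frac{1}{n_T^2} \sum_{i: W_i=1} \sum_{j: W_j=1} k(\phi_i, \phi_j)
  + \frac{1}{n_C^2} \sum_{i: W_i=0} \sum_{j: W_j=0} k(\phi_i, \phi_j)
  - \frac{2}{n_T n_C} \sum_{i: W_i=1} \sum_{j: W_j=0} k(\phi_i, \phi_j)
```

The three terms are the mean treated-treated, control-control, and treated-control kernel values, so the estimate costs $O((n_T + n_C)^2)$ kernel evaluations per batch.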

2. CFR-Wasserstein

$$W_1(P^T, P^C) = \sup_{\|f\|_L \leq 1} \left|E_{P^T}[f(\Phi)] - E_{P^C}[f(\Phi)]\right|$$

Implementation: the supremum over 1-Lipschitz functions is approximated with a discriminator (critic) network, WGAN style.
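
A critic network is not the only option. As a swapped-in alternative, the sketch below approximates the primal optimal transport cost with entropy-regularized Sinkhorn iterations in the log domain, using the Euclidean distance as ground cost. The function name `sinkhorn_wasserstein` and the defaults `epsilon=0.1` and `n_iters=50` are assumptions made for illustration, not values taken from this note.

```python
import math
import torch


def sinkhorn_wasserstein(phi_t, phi_c, epsilon=0.1, n_iters=50):
    """Entropy-regularized approximation of W1 between two samples.

    Illustrative sketch: epsilon and n_iters are assumed defaults.
    """
    cost = torch.cdist(phi_t, phi_c, p=2)              # (n_t, n_c) Euclidean costs
    n_t, n_c = cost.shape
    log_a = cost.new_full((n_t,), -math.log(n_t))      # uniform weights, treated
    log_b = cost.new_full((n_c,), -math.log(n_c))      # uniform weights, control
    f = torch.zeros_like(log_a)                        # dual potential, treated
    g = torch.zeros_like(log_b)                        # dual potential, control

    for _ in range(n_iters):
        # Alternating dual updates; logsumexp avoids underflow of exp(-cost/eps)
        f = -epsilon * torch.logsumexp(
            (g.unsqueeze(0) - cost) / epsilon + log_b.unsqueeze(0), dim=1)
        g = -epsilon * torch.logsumexp(
            (f.unsqueeze(1) - cost) / epsilon + log_a.unsqueeze(1), dim=0)

    # Transport plan implied by the potentials
    log_plan = ((f.unsqueeze(1) + g.unsqueeze(0) - cost) / epsilon
                + log_a.unsqueeze(1) + log_b.unsqueeze(0))
    return (torch.exp(log_plan) * cost).sum()


torch.manual_seed(0)
x = torch.randn(20, 3)
d_same = sinkhorn_wasserstein(x, x)         # identical samples
d_shift = sinkhorn_wasserstein(x, x + 5.0)  # same sample shifted by 5 per coordinate
print(d_same.item(), d_shift.item())
```

The result is differentiable with respect to the representations, so it could stand in for the IPM term in the loss. Smaller `epsilon` tracks the unregularized distance more closely but needs more iterations.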


Theoretical Guarantees

Generalization Bound

Shalit et al. (2017):

$$\epsilon_{PEHE} \leq \epsilon_{factual} + \alpha \cdot \text{IPM} + \text{complexity terms}$$
  • The PEHE error is bounded by the factual error and the IPM
  • Motivates IPM minimization
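
For reference, a minimal sketch of the empirical quantity on the left-hand side, assuming the usual definition of PEHE as the mean squared difference between estimated and true individual effects. It can only be evaluated when the true effects are known, for example on synthetic or semi-synthetic benchmarks. The function name `empirical_pehe` and the toy numbers are illustrative.

```python
import torch


def empirical_pehe(y0_hat, y1_hat, tau_true):
    """Mean squared error of the estimated individual treatment effects."""
    tau_hat = y1_hat - y0_hat            # estimated effect per unit
    return ((tau_hat - tau_true) ** 2).mean()


# Toy example with three units whose true effect is 1.0 (made-up numbers)
y0_hat = torch.tensor([1.0, 2.0, 0.5])
y1_hat = torch.tensor([2.0, 2.5, 2.5])
tau_true = torch.tensor([1.0, 1.0, 1.0])

pehe = empirical_pehe(y0_hat, y1_hat, tau_true)
print(pehe.item())  # (0 + 0.25 + 1.0) / 3
```

Many papers report the square root of this value instead.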

Trade-off

$$\lambda \uparrow \implies \text{Balance} \uparrow, \quad \text{Prediction} \downarrow$$

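The choice of $\lambda$ is critical: too small a value leaves the selection bias in the representation, while too large a value discards information needed to predict the outcome.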
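
One way to see the trade-off is to train the same model over a grid of $\lambda$ values and record the validation factual loss and the imbalance. The sketch below is an illustration under several assumptions: synthetic data, a deliberately tiny network, and a linear MMD (squared distance between group means) in place of the RBF version. The names `linear_mmd` and `fit_and_score` are hypothetical.

```python
import torch
import torch.nn as nn


def linear_mmd(phi_t, phi_c):
    """Squared distance between the group means of the representation."""
    return (phi_t.mean(dim=0) - phi_c.mean(dim=0)).pow(2).sum()


def fit_and_score(lam, train, valid, epochs=200):
    """Train a tiny two-head model with weight `lam`; return validation metrics."""
    torch.manual_seed(0)  # same initialization for every lambda
    X, W, Y = train
    rep = nn.Sequential(nn.Linear(X.shape[1], 16), nn.ELU())
    h0, h1 = nn.Linear(16, 1), nn.Linear(16, 1)
    params = [*rep.parameters(), *h0.parameters(), *h1.parameters()]
    opt = torch.optim.Adam(params, lr=1e-2)

    def evaluate(X, W, Y):
        phi = rep(X)
        # Pick the head that matches the observed treatment
        pred = torch.where(W == 1, h1(phi).squeeze(-1), h0(phi).squeeze(-1))
        factual = ((Y - pred) ** 2).mean()
        imbalance = linear_mmd(phi[W == 1], phi[W == 0])
        return factual, imbalance

    for _ in range(epochs):
        factual, imbalance = evaluate(X, W, Y)
        loss = factual + lam * imbalance
        opt.zero_grad()
        loss.backward()
        opt.step()

    with torch.no_grad():
        factual, imbalance = evaluate(*valid)
    return factual.item(), imbalance.item()


# Synthetic confounded data: treatment probability depends on the first covariate
torch.manual_seed(0)
n = 400
X = torch.randn(n, 5)
W = torch.bernoulli(torch.sigmoid(2 * X[:, 0]))
Y = X.sum(dim=1) + 2 * W + 0.1 * torch.randn(n)
train = (X[:300], W[:300], Y[:300])
valid = (X[300:], W[300:], Y[300:])

results = {lam: fit_and_score(lam, train, valid) for lam in [0.0, 0.1, 1.0, 10.0]}
for lam, (factual, imbalance) in results.items():
    print(f"lambda={lam:5.1f}  val factual={factual:.3f}  val imbalance={imbalance:.3f}")
```

The factual validation loss says nothing directly about counterfactual accuracy, so a grid like this shows the trade-off but does not by itself settle which $\lambda$ is best.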


Advantages and Disadvantages

Advantages

| Advantage | Description |
| --- | --- |
| Theoretical guarantee | Generalization bound |
| End-to-end | Joint learning of representation + prediction |
| Flexibility | Various network architectures |
| Scalability | Handles high-dimensional data |

Disadvantages

| Disadvantage | Description |
| --- | --- |
| Choice of $\lambda$ | Sensitive to the hyperparameter |
| Large-scale data | Requires large datasets (inherent to deep learning) |
| Uncertainty | Hard to provide confidence intervals |
| IPM computation | Costly, especially Wasserstein |

Implementation

Python (PyTorch)

```python
import torch
import torch.nn as nn


class CFRNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=100, repr_dim=50):
        super().__init__()
        # Representation network Phi
        self.repr_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, repr_dim),
        )
        # Outcome networks: h_0 (control) and h_1 (treated)
        self.head_0 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 1),
        )
        self.head_1 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x, w=None):
        # w is accepted for interface compatibility but not used:
        # both potential outcomes are predicted for every unit.
        phi = self.repr_net(x)
        # squeeze(-1) keeps the batch dimension even for a single-row batch
        y0 = self.head_0(phi).squeeze(-1)
        y1 = self.head_1(phi).squeeze(-1)
        return y0, y1, phi


def mmd_loss(phi_t, phi_c, sigma=1.0):
    """Biased (V-statistic) estimate of the squared MMD with an RBF kernel."""
    def rbf_kernel(x, y):
        diff = x.unsqueeze(1) - y.unsqueeze(0)  # (n_x, n_y, d)
        return torch.exp(-diff.pow(2).sum(-1) / (2 * sigma**2))

    k_tt = rbf_kernel(phi_t, phi_t).mean()
    k_cc = rbf_kernel(phi_c, phi_c).mean()
    k_tc = rbf_kernel(phi_t, phi_c).mean()

    return k_tt + k_cc - 2 * k_tc


# Training
def train_cfr(model, X, W, Y, lambda_mmd=1.0, n_epochs=100, lr=1e-3):
    """Full-batch training. W is a 0/1 treatment indicator and Y the outcome,
    both of shape (n,). Both groups must be non-empty."""
    # Create the optimizer once so Adam's moment estimates persist across steps
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    W = W.float()
    treated = W == 1

    for _ in range(n_epochs):
        y0, y1, phi = model(X, W)
        # Prediction of the head that matches the observed treatment
        y_pred = W * y1 + (1 - W) * y0

        # Factual loss
        loss_factual = ((Y - y_pred) ** 2).mean()

        # MMD regularization between treated and control representations
        loss_mmd = mmd_loss(phi[treated], phi[~treated])

        loss = loss_factual + lambda_mmd * loss_mmd

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return model
```

Extensions

DragonNet

Adds a propensity score head to the shared representation (built on the TARNet architecture)
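
A minimal sketch of what adding a propensity head can look like, assuming a shared representation, two outcome heads, and a third head trained with binary cross-entropy to predict the treatment. The class name `PropensityHeadNet`, the function `propensity_head_loss`, and the weight `alpha` are hypothetical; the published DragonNet also uses a targeted regularization term that is omitted here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PropensityHeadNet(nn.Module):
    """Shared representation, two outcome heads, one propensity head (sketch)."""

    def __init__(self, input_dim, hidden_dim=100, repr_dim=50):
        super().__init__()
        self.repr_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ELU(),
            nn.Linear(hidden_dim, repr_dim), nn.ELU(),
        )
        self.head_0 = nn.Linear(repr_dim, 1)      # outcome under control
        self.head_1 = nn.Linear(repr_dim, 1)      # outcome under treatment
        self.propensity = nn.Linear(repr_dim, 1)  # logit of P(W = 1 | X)

    def forward(self, x):
        phi = self.repr_net(x)
        return (self.head_0(phi).squeeze(-1),
                self.head_1(phi).squeeze(-1),
                self.propensity(phi).squeeze(-1))


def propensity_head_loss(y0, y1, w_logit, W, Y, alpha=1.0):
    """Factual squared error plus alpha times treatment cross-entropy."""
    y_pred = W * y1 + (1 - W) * y0
    loss_factual = ((Y - y_pred) ** 2).mean()
    loss_propensity = F.binary_cross_entropy_with_logits(w_logit, W)
    return loss_factual + alpha * loss_propensity


# Toy batch (random numbers, for shape checking only)
torch.manual_seed(0)
X = torch.randn(8, 4)
W = torch.tensor([0., 1., 0., 1., 1., 0., 1., 0.])
Y = torch.randn(8)

model = PropensityHeadNet(input_dim=4)
y0, y1, w_logit = model(X)
loss = propensity_head_loss(y0, y1, w_logit, W, Y)
loss.backward()
print(loss.item())
```

The propensity head pushes the representation to keep the covariate information that predicts treatment, which is a different mechanism from the IPM penalty used by CFR.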

Perfect Match

Nearest neighbor in representation space

SITE

Adds self-supervised representation learning

Related Topics

  • Representation Learning Overview - Unified view of representation learning methods
  • CEVAE - VAE-based alternative
  • BNN - Simple balancing method
  • Selection Bias - The problem being addressed
  • PEHE - Evaluation metric

Key Papers

  • Shalit, U., Johansson, F. D., & Sontag, D. (2017). Estimating individual treatment effect: Generalization bounds and algorithms. ICML.
  • Johansson, F. D., Shalit, U., & Sontag, D. (2016). Learning representations for counterfactual inference. ICML.
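  • Yao, L., Chu, Z., Li, S., Li, Y., Gao, J., & Zhang, A. (2021). A survey on causal inference. ACM Transactions on Knowledge Discovery from Data. Section 3.5.3.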
