Representation Learning Overview
Overview
Methods that learn a representation Φ(X) of the covariates that is (approximately) independent of treatment assignment while remaining useful for outcome prediction.
Mermaid source:
flowchart LR
    X[Covariates X] --> Phi["Representation Φ(X)"]
    Phi --> H1["h₁: Y(1) prediction"]
    Phi --> H0["h₀: Y(0) prediction"]

    subgraph "Regularization"
        Phi --> D[Distribution Matching]
    end
Core Idea
Domain Adaptation Perspective
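Treat the treated and control groups as two different domains and learn a representation under which the two look alike:
$$\min_{\Phi,\, h} \; \mathcal{L}_{\text{pred}}(h, \Phi) + \alpha \cdot \text{IPM}\big(P(\Phi(X) \mid W = 1),\; P(\Phi(X) \mid W = 0)\big)$$
- $\mathcal{L}_{\text{pred}}$: prediction loss
- $\text{IPM}$: distributional mismatch between treated/control representations
- $\alpha$: balance-prediction trade-off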
Goals
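- Balance: $\Phi(X) \perp W$ (the representation carries no information about treatment assignment)
- Predictive power: $\Phi(X)$ is useful for predicting $Y$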
Main Methods
1. CFR (Counterfactual Regression)
Distribution matching via an IPM (Integral Probability Metric)
Details: CFR
2. CEVAE (Causal Effect VAE)
Infer latent confounders with a VAE
Details: CEVAE
3. BNN (Balancing Neural Network)
Shared representation + separate prediction heads
Details: BNN
4. GANITE
Generate counterfactuals with a GAN
- Generator: produces counterfactual outcomes
- Discriminator: distinguishes real from generated
Details: GANITE
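The generator/discriminator split described for GANITE can be sketched as below. This is a minimal illustration of the counterfactual-generation step only, assuming a binary treatment; the class names, layer sizes, and the `complete_outcomes` helper are hypothetical choices for this note, not the reference implementation, and the separate ITE-prediction stage of GANITE is omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CounterfactualGenerator(nn.Module):
    """Maps (x, t, factual y, noise) to a pair of potential outcomes [Y(0), Y(1)]."""

    def __init__(self, x_dim, noise_dim=1, hidden=32):
        super().__init__()
        self.noise_dim = noise_dim
        self.net = nn.Sequential(
            nn.Linear(x_dim + 2 + noise_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 2),
        )

    def forward(self, x, t, y_f):
        z = torch.randn(x.size(0), self.noise_dim, device=x.device)
        return self.net(torch.cat([x, t, y_f, z], dim=1))


class CounterfactualDiscriminator(nn.Module):
    """Given x and an outcome pair, returns a logit for 'the factual slot is Y(1)'."""

    def __init__(self, x_dim, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(x_dim + 2, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x, y_pair):
        return self.net(torch.cat([x, y_pair], dim=1))


def complete_outcomes(y_gen, t, y_f):
    """Keep the observed outcome in the factual slot, the generated one elsewhere."""
    mask = torch.cat([1 - t, t], dim=1)  # one-hot over (control, treated)
    return mask * y_f + (1 - mask) * y_gen


torch.manual_seed(0)
x = torch.randn(16, 3)
t = torch.bernoulli(torch.full((16, 1), 0.5))
y_f = torch.randn(16, 1)

gen = CounterfactualGenerator(x_dim=3)
disc = CounterfactualDiscriminator(x_dim=3)

y_bar = complete_outcomes(gen(x, t, y_f), t, y_f)
# Discriminator step: guess which slot holds the real (factual) outcome.
d_loss = F.binary_cross_entropy_with_logits(disc(x, y_bar.detach()), t)
```

The generator is then trained to make the discriminator's guess hard, typically alongside a supervised loss on the factual slot.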
Measuring Distributional Mismatch (IPM)
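MMD (Maximum Mean Discrepancy)
$$\text{MMD}(P, Q) = \sup_{\|f\|_{\mathcal{H}} \le 1} \big( \mathbb{E}_{x \sim P}[f(x)] - \mathbb{E}_{x \sim Q}[f(x)] \big)$$
- $\mathcal{H}$: Reproducing Kernel Hilbert Space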
- Requires a kernel choice (RBF is common)
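As a concrete illustration of MMD with an RBF kernel, the sketch below computes the biased (V-statistic) estimate of squared MMD between two batches of representations. The function names and the fixed bandwidth `sigma=1.0` are assumptions made for this example; in practice the bandwidth is a hyperparameter to tune.

```python
import torch


def rbf_kernel(a: torch.Tensor, b: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Kernel matrix k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 * sigma^2))."""
    sq_dist = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(dim=-1)
    return torch.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd2_rbf(phi_treated: torch.Tensor, phi_control: torch.Tensor,
             sigma: float = 1.0) -> torch.Tensor:
    """Biased estimate of squared MMD between two samples (differentiable)."""
    k_tt = rbf_kernel(phi_treated, phi_treated, sigma).mean()
    k_cc = rbf_kernel(phi_control, phi_control, sigma).mean()
    k_tc = rbf_kernel(phi_treated, phi_control, sigma).mean()
    return k_tt + k_cc - 2.0 * k_tc


torch.manual_seed(0)
sample_a = torch.randn(200, 4)
sample_b = torch.randn(200, 4)          # same distribution as sample_a
shifted = torch.randn(200, 4) + 2.0     # mean-shifted distribution

mmd_same = mmd2_rbf(sample_a, sample_b).item()
mmd_shifted = mmd2_rbf(sample_a, shifted).item()
```

Samples drawn from the same distribution should give a value near zero, while the shifted sample gives a clearly larger one.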
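Wasserstein Distance
$$\text{Wass}(P, Q) = \sup_{\|f\|_{L} \le 1} \big( \mathbb{E}_{x \sim P}[f(x)] - \mathbb{E}_{x \sim Q}[f(x)] \big)$$
- Requires a Lipschitz constraint: the supremum is over 1-Lipschitz functions $f$
- Commonly implemented with a critic network whose Lipschitz constraint is enforced via a gradient penalty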
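The gradient-penalty approach mentioned above can be sketched as follows, in the style of WGAN-GP. The critic architecture, the `gp_weight=10.0` default, and the function names are assumptions for illustration; the interpolated points are detached, so this penalty updates only the critic.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic: nn.Module, phi_treated: torch.Tensor,
                     phi_control: torch.Tensor) -> torch.Tensor:
    """Penalize deviation of the critic's input-gradient norm from 1."""
    n = min(phi_treated.size(0), phi_control.size(0))
    eps = torch.rand(n, 1, device=phi_treated.device)
    # Random interpolates between treated and control representations.
    mixed = eps * phi_treated[:n] + (1.0 - eps) * phi_control[:n]
    mixed = mixed.detach().requires_grad_(True)
    scores = critic(mixed)
    grads, = torch.autograd.grad(scores.sum(), mixed, create_graph=True)
    return ((grads.norm(2, dim=1) - 1.0) ** 2).mean()


def critic_loss(critic, phi_treated, phi_control, gp_weight=10.0):
    """Loss minimized by the critic; w_estimate approximates the Wasserstein distance."""
    w_estimate = critic(phi_treated).mean() - critic(phi_control).mean()
    penalty = gradient_penalty(critic, phi_treated, phi_control)
    return -w_estimate + gp_weight * penalty, w_estimate


torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
phi_t = torch.randn(32, 4) + 1.0
phi_c = torch.randn(40, 4)

loss, w_estimate = critic_loss(critic, phi_t, phi_c)
loss.backward()  # gradients flow to the critic parameters
```

In a full training loop the critic and the representation network are updated in alternation, with `w_estimate` serving as the imbalance term for the representation.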
Comparison of Methods
| Method | Approach | Characteristics | Suitable Setting |
|---|---|---|---|
| CFR | IPM regularization | Simple, scalable | Default choice |
| CEVAE | VAE + latent variables | Uncertainty, generative model | Complex DGP |
| BNN | Moment matching | Simple balancing | Quick experiments |
| GANITE | GAN | Direct ITE estimation | When personalization is needed |
Network Architecture
General Structure
         Input X
            │
            ▼
   ┌─────────────────┐
   │ Representation  │
   │     Network     │
   │      Φ(X)       │
   └─────────────────┘
            │
     ┌──────┴──────┐
     ▼             ▼
┌─────────┐   ┌─────────┐
│  h₀(Φ)  │   │  h₁(Φ)  │
│ Control │   │ Treated │
│  Head   │   │  Head   │
└─────────┘   └─────────┘
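Loss Function
With a squared-error factual loss, the training objective takes the form
$$\mathcal{L} = \frac{1}{n} \sum_{i=1}^{n} \big( y_i - h_{w_i}(\Phi(x_i)) \big)^2 + \alpha \cdot \text{IPM}\big( \{\Phi(x_i)\}_{w_i = 1},\; \{\Phi(x_i)\}_{w_i = 0} \big)$$
where $h_{w_i}$ is the head matching unit $i$'s observed treatment and $\alpha$ weights the balance penalty.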
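A minimal sketch of such a loss in PyTorch is given below: a factual prediction loss plus a weighted imbalance penalty. For brevity the penalty here is a linear MMD (squared distance between group means), which is an assumption of this example; `cfr_style_loss` and `linear_mmd` are hypothetical names, and any other IPM estimate could be substituted.

```python
import torch
import torch.nn.functional as F


def linear_mmd(phi: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Squared distance between treated and control representation means.

    Assumes the batch contains at least one treated and one control unit;
    otherwise the mean of an empty group is NaN.
    """
    treated = phi[w.view(-1) == 1]
    control = phi[w.view(-1) == 0]
    return (treated.mean(dim=0) - control.mean(dim=0)).pow(2).sum()


def cfr_style_loss(y_pred, y, phi, w, alpha=1.0):
    """Factual prediction loss plus alpha times the imbalance penalty."""
    pred_loss = F.mse_loss(y_pred.view(-1), y.view(-1))
    return pred_loss + alpha * linear_mmd(phi, w)


torch.manual_seed(0)
phi = torch.randn(8, 3)                       # representations Φ(x_i)
w = torch.tensor([1, 0, 1, 0, 1, 0, 1, 0])    # observed treatments
y = torch.randn(8)                            # observed (factual) outcomes
y_pred = torch.randn(8)                       # factual predictions

loss_no_balance = cfr_style_loss(y_pred, y, phi, w, alpha=0.0)
loss_balanced = cfr_style_loss(y_pred, y, phi, w, alpha=1.0)
```

Setting `alpha=0.0` recovers plain outcome regression on the factual data; larger values trade predictive fit for balance.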
Advantages and Disadvantages
Advantages
| Advantage | Description |
|---|---|
| End-to-end | Optimizes the entire pipeline |
| High-dimensional | Can handle images, text, etc. |
| Flexibility | Applicable across diverse network architectures |
| Automatic features | No manual feature engineering required |
Disadvantages
| Disadvantage | Description |
|---|---|
| Requires large data | Inherent to deep learning |
| Black box | Hard to interpret |
| Hyperparameters | Sensitive to choices such as the balance-prediction trade-off weight |
| Limited theory | Limited asymptotic guarantees |
Implementation
Python (PyTorch)
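import torch
import torch.nn as nn


class CFRNet(nn.Module):
    """Shared representation Φ(X) with separate control/treated outcome heads."""

    def __init__(self, input_dim, hidden_dim, repr_dim):
        super().__init__()
        # Representation network Φ: covariates -> shared representation
        self.representation = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, repr_dim)
        )
        # Outcome heads h₀ (control) and h₁ (treated)
        self.head_0 = nn.Linear(repr_dim, 1)
        self.head_1 = nn.Linear(repr_dim, 1)

    def forward(self, x, w):
        phi = self.representation(x)
        y0 = self.head_0(phi)
        y1 = self.head_1(phi)
        # Reshape w to (batch, 1) so it broadcasts row-wise against y0 and y1
        w = w.view(-1, 1).to(y0.dtype)
        # Factual prediction: select the head matching the observed treatment
        y = w * y1 + (1 - w) * y0
        return y, y0, y1, phi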
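To show how a network like the CFRNet class above might be trained and used, here is a hedged end-to-end sketch on synthetic data. The `TwoHeadNet` class is a compact stand-in that repeats the same architecture so the snippet runs on its own; the data-generating process, `alpha=0.1`, the learning rate, and the linear-MMD imbalance term are all illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoHeadNet(nn.Module):
    """Stand-in for CFRNet: shared representation plus two outcome heads."""

    def __init__(self, input_dim, hidden_dim=32, repr_dim=16):
        super().__init__()
        self.representation = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, repr_dim),
        )
        self.head_0 = nn.Linear(repr_dim, 1)
        self.head_1 = nn.Linear(repr_dim, 1)

    def forward(self, x, w):
        phi = self.representation(x)
        y0, y1 = self.head_0(phi), self.head_1(phi)
        return w * y1 + (1 - w) * y0, y0, y1, phi


torch.manual_seed(0)
n, d = 512, 5
x = torch.randn(n, d)
w = torch.bernoulli(torch.sigmoid(x[:, :1]))   # treatment depends on X, shape (n, 1)
y = x.sum(dim=1, keepdim=True) + 2.0 * w + 0.1 * torch.randn(n, 1)  # synthetic outcome

model = TwoHeadNet(d)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
alpha = 0.1      # balance-prediction trade-off (assumed value)
losses = []

for step in range(300):
    y_hat, _, _, phi = model(x, w)
    treated = phi[w.view(-1) == 1]
    control = phi[w.view(-1) == 0]
    imbalance = (treated.mean(dim=0) - control.mean(dim=0)).pow(2).sum()
    loss = F.mse_loss(y_hat, y) + alpha * imbalance
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

with torch.no_grad():
    _, y0_hat, y1_hat, _ = model(x, w)
    ite_hat = y1_hat - y0_hat           # per-unit treatment effect estimates
    ate_hat = ite_hat.mean().item()     # average treatment effect estimate
```

Only the factual prediction enters the loss, so each head is fitted on its own treatment group; the effect estimate comes from evaluating both heads on every unit.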
Related Concepts
- CFR - Counterfactual Regression
- CEVAE - Causal Effect VAE
- BNN - Balancing Neural Network
- GANITE - GAN for ITE
- Selection Bias - the problem being addressed
- HTE (Heterogeneous Treatment Effect) - the estimation target
Key Papers
- Johansson, F. D., Shalit, U., & Sontag, D. (2016). Learning representations for counterfactual inference. ICML
- Shalit, U., Johansson, F. D., & Sontag, D. (2017). Estimating individual treatment effect: Generalization bounds and algorithms. ICML
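- Yao, L., Chu, Z., Li, S., Li, Y., Gao, J., & Zhang, A. (2021). A survey on causal inference. ACM Transactions on Knowledge Discovery from Data. Section 3.5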