Tae Hyun Kim (Lowell)

Representation Learning Overview

Overview

Representation learning methods for causal inference learn a representation Φ(X) of the covariates that is approximately independent of the treatment W while remaining useful for predicting the outcome Y.

Mermaid source:
> flowchart LR
>     X[Covariates X] --> Phi[Representation Φ(X)]
>     Phi --> H1[h₁: Y(1) prediction]
>     Phi --> H0[h₀: Y(0) prediction]
> 
>     subgraph "Regularization"
>         Phi --> D[Distribution Matching]
>     end
>

Core Idea

Domain Adaptation Perspective

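Treat the treated and control groups as two different domains, so that counterfactual prediction becomes a domain adaptation problem: a model fit on one group must generalize to the other. The objective trades prediction accuracy against the mismatch between the two groups in representation space: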

$$\min_{\Phi, h} \; \mathcal{L}_{pred} + \lambda \cdot \text{disc}(\Phi)$$
  • $\mathcal{L}_{pred}$: prediction loss
  • $\text{disc}(\Phi)$: distributional mismatch between treated/control representations
  • $\lambda$: balance-prediction trade-off

Goals

  1. Balance: $P(\Phi(X) \mid W=1) \approx P(\Phi(X) \mid W=0)$
  2. Predictive power: $\Phi(X)$ is useful for predicting $Y$

Main Methods

1. CFR (Counterfactual Regression)

Distribution matching via an IPM (Integral Probability Metric)

$$\mathcal{L} = \mathcal{L}_{factual} + \lambda \cdot \text{IPM}(P^T_\Phi, P^C_\Phi)$$

Details: CFR

2. CEVAE (Causal Effect VAE)

Infer latent confounders with a VAE

$$Z \to X, \quad Z \to W, \quad Z, W \to Y$$

Details: CEVAE
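
To make the causal graph above concrete, here is a minimal simulation sketch, not CEVAE itself. It assumes a toy linear-Gaussian data-generating process (the coefficients, the logistic assignment rule, and the names `simulate` and `naive_difference` are all illustrative choices, not from the source). It shows why an unobserved Z that drives both W and Y biases a naive treated-versus-control comparison, which is the problem latent-variable methods try to address by inferring Z from its proxy X.

```python
import math
import random

def simulate(n=5000, true_effect=1.0, seed=0):
    """Draw (x, w, y) from a toy version of the graph Z -> X, Z -> W, (Z, W) -> Y."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)                      # latent confounder (never observed)
        x = z + rng.gauss(0.0, 0.5)                  # noisy proxy of z
        p_treat = 1.0 / (1.0 + math.exp(-2.0 * z))   # z drives treatment assignment
        w = 1 if rng.random() < p_treat else 0
        y = true_effect * w + 2.0 * z + rng.gauss(0.0, 0.5)  # z also drives the outcome
        rows.append((x, w, y))
    return rows

def naive_difference(rows):
    """Difference in mean outcomes between treated and control units."""
    treated = [y for _, w, y in rows if w == 1]
    control = [y for _, w, y in rows if w == 0]
    return sum(treated) / len(treated) - sum(control) / len(control)

rows = simulate()
naive = naive_difference(rows)
print(f"true effect = 1.0, naive estimate = {naive:.2f}")
```

Because treated units tend to have larger Z, the naive estimate lands well above the true effect of 1.0 in this setup.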

3. BNN (Balancing Neural Network)

Shared representation + separate prediction heads

$$\mathcal{L} = \mathcal{L}_0 + \mathcal{L}_1 + \lambda \|E[\Phi(X) \mid W=1] - E[\Phi(X) \mid W=0]\|^2$$

Details: BNN
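
The mean-matching penalty in the loss above can be sketched in a few lines of PyTorch. This is an illustrative helper under stated assumptions, not the reference BNN implementation: the function name `mean_discrepancy` is hypothetical, `phi` is assumed to be an `(n, d)` batch of representations, and each batch is assumed to contain at least one treated and one control unit.

```python
import torch

def mean_discrepancy(phi: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Squared Euclidean distance between treated and control mean representations.

    phi: (n, d) representations; w: (n,) binary treatment indicator.
    Assumes the batch contains at least one unit from each group.
    """
    treated = w.bool()
    mean_t = phi[treated].mean(dim=0)    # empirical E[Phi(X) | W=1]
    mean_c = phi[~treated].mean(dim=0)   # empirical E[Phi(X) | W=0]
    return ((mean_t - mean_c) ** 2).sum()

# Tiny hand-checkable batch: treated mean is (2, 0), control mean is (0, 1).
phi = torch.tensor([[1.0, 0.0], [3.0, 0.0], [0.0, 0.0], [0.0, 2.0]])
w = torch.tensor([1, 1, 0, 0])
penalty = mean_discrepancy(phi, w)   # (2 - 0)^2 + (0 - 1)^2 = 5
```

The result is a scalar tensor, so it can be multiplied by λ and added to the prediction loss before calling `backward()`.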

4. GANITE

Generate counterfactuals with a GAN

  • Generator: produces counterfactual outcomes
  • Discriminator: distinguishes real from generated

Details: GANITE


Measuring Distributional Mismatch (IPM)

MMD (Maximum Mean Discrepancy)

$$\text{MMD}^2 = \|E[\phi(X) \mid W=1] - E[\phi(X) \mid W=0]\|^2_{\mathcal{H}}$$
  • $\mathcal{H}$: Reproducing Kernel Hilbert Space
  • Requires a kernel choice (RBF is common)
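
With a kernel k, the squared MMD expands into three kernel averages: within the treated group, within the control group, and across groups. The sketch below is one possible estimator, assuming an RBF kernel with a fixed bandwidth `sigma` and the biased (V-statistic) form that includes diagonal terms. The function names are hypothetical, and the bandwidth would normally need tuning.

```python
import torch

def rbf_kernel(a: torch.Tensor, b: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Pairwise RBF kernel matrix between rows of a (n, d) and b (m, d)."""
    sq_dist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(dim=-1)
    return torch.exp(-sq_dist / (2.0 * sigma ** 2))

def mmd2_rbf(phi_t: torch.Tensor, phi_c: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Biased (V-statistic) estimate of squared MMD between two samples."""
    k_tt = rbf_kernel(phi_t, phi_t, sigma).mean()
    k_cc = rbf_kernel(phi_c, phi_c, sigma).mean()
    k_tc = rbf_kernel(phi_t, phi_c, sigma).mean()
    return k_tt + k_cc - 2.0 * k_tc

torch.manual_seed(0)
same = torch.randn(50, 2)
shifted = torch.randn(50, 2) + 3.0         # same shape, mean moved away from the origin
mmd_same = mmd2_rbf(same, same)            # identical samples: zero
mmd_shifted = mmd2_rbf(same, shifted)      # clearly separated samples: positive
```

The squared distances are computed by broadcasting rather than by taking a square root and squaring again, which keeps gradients well defined when two points coincide.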

Wasserstein Distance

$$W_1(P, Q) = \sup_{\|f\|_L \leq 1} |E_P[f] - E_Q[f]|$$
  • Requires the critic $f$ to be 1-Lipschitz
  • In practice the constraint is typically enforced with a gradient penalty on the critic
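
For intuition about what the distance measures, the one-dimensional case has a closed form that needs no critic. For two samples of equal size, the empirical W₁ is the mean absolute difference between their sorted values. The sketch below uses that closed form in place of the gradient-penalty approach, so it is only a sanity check for scalar representations. The function name `wasserstein1_1d` and the sample values are illustrative.

```python
def wasserstein1_1d(xs, ys):
    """Empirical W1 between two equal-size 1-D samples via sorted matching."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    pairs = zip(sorted(xs), sorted(ys))
    return sum(abs(a - b) for a, b in pairs) / len(xs)

treated = [0.9, 1.4, 2.1, 2.6]
control = [0.1, 0.4, 1.1, 1.6]
dist = wasserstein1_1d(treated, control)

# A pure shift by a constant c gives a distance of exactly c.
shift_dist = wasserstein1_1d([0, 1, 2], [1, 2, 3])
```

A helper like this can serve as a reference value when testing a learned critic on scalar data.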

Comparison of Methods

| Method | Approach | Characteristics | Suitable Setting |
|---|---|---|---|
| CFR | IPM regularization | Simple, scalable | Default choice |
| CEVAE | VAE + latent variables | Uncertainty, generative model | Complex DGP |
| BNN | Moment matching | Simple balancing | Quick experiments |
| GANITE | GAN | Direct ITE estimation | When personalization is needed |

Network Architecture

General Structure

Input X
    │
    ▼
┌────────────────┐
│ Representation │
│    Network     │
│      Φ(X)      │
└────────────────┘
    │
    ├─────────────┐
    ▼             ▼
┌─────────┐  ┌─────────┐
│ h₀(Φ)   │  │ h₁(Φ)   │
│ Control │  │ Treated │
│ Head    │  │ Head    │
└─────────┘  └─────────┘

Loss Function

$$\mathcal{L} = \underbrace{\sum_{i: W_i=0} (Y_i - h_0(\Phi(X_i)))^2 + \sum_{i: W_i=1} (Y_i - h_1(\Phi(X_i)))^2}_{\text{Factual loss}} + \lambda \cdot \underbrace{\text{disc}(\Phi)}_{\text{Balance}}$$
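
The loss above can be sketched as a small PyTorch function. This is an illustrative helper, not a reference implementation: the name `balanced_factual_loss` is hypothetical, the balance term is assumed to be computed elsewhere and passed in as a scalar `disc_value`, and the per-group sums are implemented by masking with the treatment indicator.

```python
import torch

def balanced_factual_loss(y, y0_hat, y1_hat, w, disc_value, lam=1.0):
    """Factual squared error plus lam times a precomputed balance penalty.

    y, y0_hat, y1_hat, w: tensors of shape (n,); w is a 0/1 treatment indicator.
    disc_value: scalar tensor, e.g. a mean-matching or IPM penalty on Phi.
    """
    w = w.float()
    control_loss = ((1.0 - w) * (y - y0_hat) ** 2).sum()   # sum over units with W_i = 0
    treated_loss = (w * (y - y1_hat) ** 2).sum()           # sum over units with W_i = 1
    return control_loss + treated_loss + lam * disc_value

y = torch.tensor([1.0, 2.0, 0.0, 1.0])
w = torch.tensor([1, 1, 0, 0])
y1_hat = torch.tensor([1.5, 2.0, 9.0, 9.0])   # h1 outputs; control entries are masked out
y0_hat = torch.tensor([9.0, 9.0, 0.0, 0.0])   # h0 outputs; treated entries are masked out
loss = balanced_factual_loss(y, y0_hat, y1_hat, w,
                             disc_value=torch.tensor(0.5), lam=2.0)
# control: 0 + 1 = 1, treated: 0.25 + 0 = 0.25, balance: 2 * 0.5 = 1, total 2.25
```

Only the factual head contributes for each unit, so the counterfactual predictions (the 9.0 entries here) never enter the loss; they are shaped only indirectly, through the shared representation and the balance term.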

Advantages and Disadvantages

Advantages

| Advantage | Description |
|---|---|
| End-to-end | Optimizes the entire pipeline |
| High-dimensional | Can handle images, text, etc. |
| Flexibility | Applicable across diverse network architectures |
| Automatic features | No manual feature engineering required |

Disadvantages

| Disadvantage | Description |
|---|---|
| Requires large data | Inherent to deep learning |
| Black box | Hard to interpret |
| Hyperparameters | Sensitive to choices such as $\lambda$ |
| Limited theory | Limited asymptotic guarantees |

Implementation

Python (PyTorch)

import torch
import torch.nn as nn

class CFRNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, repr_dim):
        super().__init__()
        self.representation = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, repr_dim)
        )
        self.head_0 = nn.Linear(repr_dim, 1)
        self.head_1 = nn.Linear(repr_dim, 1)

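    def forward(self, x, w):
        # w: 0/1 treatment indicator; reshape to (batch, 1) so it broadcasts
        # element-wise against the (batch, 1) head outputs
        w = w.view(-1, 1).float()
        phi = self.representation(x)   # shared representation Φ(x)
        y0 = self.head_0(phi)          # predicted outcome under control
        y1 = self.head_1(phi)          # predicted outcome under treatment
        y = w * y1 + (1 - w) * y0      # factual prediction for the observed arm
        return y, y0, y1, phi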

Related Topics

  • CFR - Counterfactual Regression
  • CEVAE - Causal Effect VAE
  • BNN - Balancing Neural Network
  • GANITE - GAN for ITE
  • Selection Bias - the problem being addressed
  • HTE (heterogeneous treatment effects) - the estimation target

Key Papers

  • Johansson, F. D., Shalit, U., & Sontag, D. (2016). Learning representations for counterfactual inference. ICML
  • Shalit, U., Johansson, F. D., & Sontag, D. (2017). Estimating individual treatment effect: Generalization bounds and algorithms. ICML
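  • Yao, L., Chu, Z., Li, S., Li, Y., Gao, J., & Zhang, A. (2021). A survey on causal inference. ACM Transactions on Knowledge Discovery from Data. Section 3.5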
