CEVAE (Causal Effect Variational Autoencoder)
Definition
A method that uses a variational autoencoder (VAE) to infer latent confounders and estimate causal effects from observational data.
Proposed by Louizos et al. (2017).
Graphical Model
Generative Model
      Z (Latent Confounder)
     /|\
   /  |  \
  ↓   ↓   ↓
  X   W → Y
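Reading the factors off the decoder terms listed under Network Architecture, the generative graph corresponds to the joint factorization below. This is a sketch of the implied structure; the outcome factor conditions on W as well as Z because the treatment also affects the outcome.

```latex
p(X, W, Y, Z) = p(Z)\, p(X \mid Z)\, p(W \mid Z)\, p(Y \mid Z, W)
```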
Inference Model
VAE Framework
Evidence Lower Bound (ELBO)
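With the encoder q(Z | X, W, Y) as the inference model and the three decoder factors as the generative model, the per-dataset ELBO can be written as follows. This is a sketch that assumes N i.i.d. observations indexed by i and an unspecified prior p(z); the objective in Louizos et al. (2017) additionally includes auxiliary terms for predicting W and Y from X at test time, which are omitted here.

```latex
\mathcal{L} = \sum_{i=1}^{N} \mathbb{E}_{q(z_i \mid x_i, w_i, y_i)} \Big[
    \log p(x_i \mid z_i) + \log p(w_i \mid z_i) + \log p(y_i \mid z_i, w_i)
    + \log p(z_i) - \log q(z_i \mid x_i, w_i, y_i)
\Big]
```

The last two terms together equal the negative KL divergence between the approximate posterior and the prior, which is the form usually implemented in code.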
Network Architecture
Encoder: q(Z | X, W, Y)
  (X, W, Y) → μ_z, σ_z → Z ~ N(μ_z, σ_z²)
Decoder:
  Z → p(X | Z): Reconstruction
  Z → p(W | Z): Treatment model
  (Z, W) → p(Y | Z, W): Outcome model
Causal Effect Estimation
CATE Estimation
Algorithm
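- Sample z from the approximate posterior for an observation x (W and Y are unavailable at test time, so the encoder inputs must be approximated)
- Predict treated/control outcomes: ŷ_1 = E[Y | z, W = 1], ŷ_0 = E[Y | z, W = 0]
- CATE: CATE(x) ≈ ŷ_1 − ŷ_0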
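The three steps above can be sketched as a Monte Carlo average over posterior samples of z rather than a single draw. This is a minimal illustration, not the authors' procedure: the function name `monte_carlo_cate` and the parameters `outcome_model` and `n_samples` are hypothetical, and it assumes the posterior mean and log-variance for each observation are already available.

```python
import torch


def monte_carlo_cate(outcome_model, mu, logvar, n_samples=100):
    """Average y1 - y0 over samples z ~ N(mu, exp(logvar)).

    outcome_model: callable mapping a (batch, z_dim + 1) tensor [z, w] to (batch, 1)
    mu, logvar:    (batch, z_dim) posterior parameters (assumed given)
    Returns (mean, std) of the sampled effects, each of shape (batch,).
    """
    std = torch.exp(0.5 * logvar)
    ones = torch.ones(mu.shape[0], 1, dtype=mu.dtype, device=mu.device)
    zeros = torch.zeros_like(ones)
    diffs = []
    with torch.no_grad():
        for _ in range(n_samples):
            z = mu + std * torch.randn_like(std)  # step 1: sample z
            y1 = outcome_model(torch.cat([z, ones], dim=-1)).squeeze(-1)   # step 2: treated
            y0 = outcome_model(torch.cat([z, zeros], dim=-1)).squeeze(-1)  # step 2: control
            diffs.append(y1 - y0)  # step 3: per-sample effect
    diffs = torch.stack(diffs)  # (n_samples, batch)
    return diffs.mean(dim=0), diffs.std(dim=0)


# Toy outcome model (hypothetical): Y = z_1 + 2 * W, so the true effect is 2
toy_model = lambda zw: zw[:, :1] + 2.0 * zw[:, -1:]
cate_mean, cate_std = monte_carlo_cate(toy_model, torch.zeros(4, 3), torch.zeros(4, 3))
```

The standard deviation across samples gives a rough sense of how much the estimate depends on the uncertainty in z; it does not account for uncertainty in the network weights.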
Assumptions
Handling Hidden Confounders
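CEVAE attempts to satisfy ignorability by inferring the latent confounder Z, so that treatment assignment is independent of the potential outcomes given Z: (Y(0), Y(1)) ⊥ W | Z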
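If Z really does block every backdoor path between W and Y in the assumed graph, the conditional effect can be written as an average of the outcome model over the posterior of Z given the covariates. The expression below is a sketch under that assumption, using p(z | x) for the true posterior that the encoder only approximates.

```latex
\tau(x) = \mathbb{E}[Y \mid do(W = 1), X = x] - \mathbb{E}[Y \mid do(W = 0), X = x]
        = \int \big( \mathbb{E}[Y \mid Z = z, W = 1] - \mathbb{E}[Y \mid Z = z, W = 0] \big)\, p(z \mid x)\, dz
```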
Caveats
- No guarantee that the inferred Z actually captures all confounding
- Bias is possible under model misspecification
Advantages and Disadvantages
Advantages
| Advantage | Description |
|---|---|
| Uncertainty quantification | Sampling from the posterior |
| Hidden confounder | Attempts to infer latent confounders |
| Generative model | Models the data-generating mechanism |
| Flexibility | Handles diverse data types |
Disadvantages
| Disadvantage | Description |
|---|---|
| Model assumptions | Requires assuming a graphical structure |
| Identifiability | No guarantee of recovering the true Z |
| Training instability | VAE training is difficult |
| Computational cost | An encoder and three decoder networks must be trained jointly |
Implementation
Python (PyTorch)
import torch
import torch.nn as nn


class CEVAE(nn.Module):
    """Minimal CEVAE for a binary treatment W and a continuous outcome Y."""

    def __init__(self, x_dim, z_dim=32, hidden_dim=64):
        super().__init__()
        # Encoder q(Z | X, W, Y): input is X concatenated with scalar W and Y
        self.encoder = nn.Sequential(
            nn.Linear(x_dim + 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),  # nonlinearity before the mean/log-variance heads
        )
        self.z_mean = nn.Linear(hidden_dim, z_dim)
        self.z_logvar = nn.Linear(hidden_dim, z_dim)
        # Decoder p(X | Z)
        self.decoder_x = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, x_dim),
        )
        # Treatment model p(W | Z): outputs P(W = 1 | Z)
        self.decoder_w = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )
        # Outcome model p(Y | Z, W): outputs the predicted mean of Y
        self.decoder_y = nn.Sequential(
            nn.Linear(z_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def encode(self, x, w, y):
        # x: (batch, x_dim); w, y: (batch,) float tensors
        h = self.encoder(torch.cat([x, w.unsqueeze(-1), y.unsqueeze(-1)], dim=-1))
        return self.z_mean(h), self.z_logvar(h)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps keeps the sample differentiable w.r.t. mu and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, w):
        x_recon = self.decoder_x(z)
        # squeeze(-1) drops only the feature axis, so a batch of one stays 1-D
        w_prob = self.decoder_w(z).squeeze(-1)
        y_pred = self.decoder_y(torch.cat([z, w.unsqueeze(-1)], dim=-1)).squeeze(-1)
        return x_recon, w_prob, y_pred

    def forward(self, x, w, y):
        mu, logvar = self.encode(x, w, y)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, w), mu, logvar

    @torch.no_grad()
    def estimate_cate(self, x):
        """CATE estimation at test time.

        W and Y are not observed here, so zeros are fed to the encoder in
        their place. This is a crude approximation; Louizos et al. (2017)
        instead use auxiliary networks that predict W and Y from X.
        """
        n = len(x)
        # new_zeros / new_ones keep dtype and device consistent with x
        mu, _ = self.encode(x, x.new_zeros(n), x.new_zeros(n))
        z = mu  # Use the posterior mean instead of sampling
        y1 = self.decoder_y(torch.cat([z, x.new_ones(n, 1)], dim=-1)).squeeze(-1)
        y0 = self.decoder_y(torch.cat([z, x.new_zeros(n, 1)], dim=-1)).squeeze(-1)
        return y1 - y0
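The class above defines the networks but no training objective. A negative-ELBO loss matching its outputs could look like the sketch below. The function name `cevae_loss` is hypothetical, and the likelihood choices are assumptions, not taken from the source: unit-variance Gaussians for X and Y (squared error up to a constant), a Bernoulli for W, and a standard normal prior on Z.

```python
import torch
import torch.nn.functional as F


def cevae_loss(x, w, y, x_recon, w_prob, y_pred, mu, logvar):
    """Negative ELBO averaged over the batch (additive constants dropped).

    x, x_recon: (batch, x_dim); w, w_prob, y, y_pred: (batch,);
    mu, logvar: (batch, z_dim). w must be a float tensor of 0s and 1s.
    """
    # -log p(X | Z) under a unit-variance Gaussian (assumption)
    nll_x = 0.5 * ((x - x_recon) ** 2).sum(dim=-1)
    # -log p(W | Z) under a Bernoulli with probability w_prob
    nll_w = F.binary_cross_entropy(w_prob, w, reduction="none")
    # -log p(Y | Z, W) under a unit-variance Gaussian (assumption)
    nll_y = 0.5 * (y - y_pred) ** 2
    # KL( N(mu, exp(logvar)) || N(0, I) ), closed form
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1)
    return (nll_x + nll_w + nll_y + kl).mean()


# Toy call with hand-made tensors: batch of 3, x_dim = 2, z_dim = 4
x = torch.randn(3, 2)
w = torch.tensor([0.0, 1.0, 1.0])
y = torch.randn(3)
loss = cevae_loss(x, w, y, torch.zeros(3, 2), torch.full((3,), 0.5),
                  torch.zeros(3), torch.zeros(3, 4), torch.zeros(3, 4))
```

In a training loop the arguments would come from the model's `forward`, which returns `(x_recon, w_prob, y_pred), mu, logvar`. If X contains binary or categorical columns, the Gaussian term for X would need to be replaced with a likelihood suited to those types.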
Related Concepts
- Representation Learning Overview - Unified view of representation learning methods
- CFR - Distribution-matching-based alternative
- Hidden Confounders - The problem CEVAE aims to solve
- Deconfounder - A related latent-variable inference approach
Key Papers
- Louizos, C., Shalit, U., Mooij, J. M., Sontag, D., Zemel, R., & Welling, M. (2017). Causal Effect Inference with Deep Latent-Variable Models. NeurIPS
- yaoSurveyCausalInference2021 - Section 3.5.4