Tae Hyun Kim (Lowell)

CEVAE (Causal Effect Variational Autoencoder)

Definition

A method that uses a variational autoencoder (VAE) to infer latent confounders from observed proxy variables and estimate causal effects.

Proposed by Louizos et al. (2017).


Graphical Model

Generative Model

    Z (Latent Confounder)
   /|\
  / | \
 ↓  ↓  ↓
 X  W → Y
$$
\begin{aligned}
Z &\sim p(Z) \\
X &\sim p(X \mid Z) \\
W &\sim p(W \mid Z) \\
Y &\sim p(Y \mid W, Z)
\end{aligned}
$$

Here $X$ denotes the observed covariates (proxies for $Z$), $W$ the binary treatment, and $Y$ the outcome.
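A minimal ancestral-sampling sketch of this generative model follows. Every concrete choice is hypothetical and does not come from Louizos et al.: a scalar $Z \sim N(0, 1)$, two Gaussian proxies for $X$, a logistic treatment model, and a linear outcome model with a constant effect of 1.5. The function name `sample_cevae_toy` is illustrative. Because the simulator knows $Z$ and the outcome noise, it can emit both potential outcomes. An observational dataset hides exactly this information.

```python
import math
import random


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def sample_cevae_toy(n: int, effect: float = 1.5, seed: int = 0):
    """Ancestral sampling Z -> (X, W) and (W, Z) -> Y for a hypothetical toy model.

    Returns dicts with the observed (x, w, y), the hidden z, and both
    potential outcomes y0, y1 (available only because this is a simulation).
    """
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)                          # Z ~ p(Z) = N(0, 1)
        x = [z + rng.gauss(0.0, 0.5) for _ in range(2)]  # X ~ p(X | Z): two noisy proxies
        w = 1 if rng.random() < sigmoid(2.0 * z) else 0  # W ~ p(W | Z) = Bern(sigmoid(2z))
        noise = rng.gauss(0.0, 0.1)
        y0 = z + noise                                   # potential outcome under W = 0
        y1 = effect + z + noise                          # potential outcome under W = 1
        y = y1 if w == 1 else y0                         # Y ~ p(Y | W, Z): only one is observed
        data.append({"x": x, "w": w, "y": y, "z": z, "y0": y0, "y1": y1})
    return data


rows = sample_cevae_toy(1000)
print(sum(r["w"] for r in rows) / len(rows))  # fraction treated (near 0.5 by symmetry of this toy model)
```

In this toy model $Z$ raises both the treatment probability and the outcome. That shared dependence is the confounding CEVAE tries to account for.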

Inference Model

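$$
q(Z \mid X, W, Y) \approx p(Z \mid X, W, Y)
$$

At test time, $W$ and $Y$ are unknown for a new subject. Louizos et al. therefore also train auxiliary networks $q(W \mid X)$ and $q(Y \mid X, W)$, which allow $Z$ to be inferred from $X$ alone.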
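Suppose auxiliary distributions $q(W \mid X)$ and $q(Y \mid X, W)$ are available; this is the approach Louizos et al. use for new subjects. A single draw from $q(Z \mid X = x)$ can then be obtained by sampling the chain $W \to Y \to Z$ in order. The three callables `q_w`, `q_y`, and `q_z` are hypothetical stand-ins for trained networks. $Z$ is scalar for simplicity. The Gaussian forms for $q(Y \mid X, W)$ and $q(Z \mid X, W, Y)$ are assumptions.

```python
import random
from typing import Callable, Tuple


def sample_z_given_x(
    x,
    q_w: Callable[[object], float],                               # hypothetical: P(W = 1 | x)
    q_y: Callable[[object, int], Tuple[float, float]],            # hypothetical: (mean, std) of q(Y | x, w)
    q_z: Callable[[object, int, float], Tuple[float, float]],     # hypothetical: (mean, std) of q(Z | x, w, y)
    rng: random.Random,
) -> float:
    """One draw from q(Z | X = x), marginalizing W and Y by ancestral sampling."""
    w = 1 if rng.random() < q_w(x) else 0   # w ~ q(W | x)
    y_mean, y_std = q_y(x, w)
    y = rng.gauss(y_mean, y_std)            # y ~ q(Y | x, w)
    z_mean, z_std = q_z(x, w, y)
    return rng.gauss(z_mean, z_std)         # z ~ q(Z | x, w, y)


# Usage with toy stand-ins (not trained networks):
rng = random.Random(0)
draws = [
    sample_z_given_x(
        [0.3, -0.1],
        q_w=lambda x: 0.7,
        q_y=lambda x, w: (1.0 + w, 0.5),
        q_z=lambda x, w, y: (0.5 * y, 1.0),
        rng=rng,
    )
    for _ in range(5)
]
print(draws)
```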

VAE Framework

Evidence Lower Bound (ELBO)

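$$
\log p(X, W, Y) \geq \mathbb{E}_{q(Z \mid X, W, Y)}\left[\log p(X, W, Y \mid Z)\right] - \mathrm{KL}\left(q(Z \mid X, W, Y) \,\|\, p(Z)\right)
$$

Under the generative model above, the reconstruction term factorizes as $\log p(X, W, Y \mid Z) = \log p(X \mid Z) + \log p(W \mid Z) + \log p(Y \mid W, Z)$.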
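To make the bound concrete, here is a single-sample negative-ELBO computation in plain Python. It rests on several modeling choices that the ELBO itself does not fix:

- a diagonal Gaussian $q(Z \mid X, W, Y) = N(\mu, \mathrm{diag}(e^{\text{logvar}}))$;
- a standard normal prior $p(Z)$;
- unit-variance Gaussian likelihoods for $X$ and $Y$;
- a Bernoulli likelihood for $W$.

The KL term uses the closed form for two Gaussians. The expectation is approximated by decoder outputs evaluated at one draw of $z$, which the caller supplies. All function names are illustrative.

```python
import math
from typing import Sequence


def gaussian_log_lik(obs: Sequence[float], mean: Sequence[float]) -> float:
    """log N(obs | mean, I): unit-variance Gaussian, summed over dimensions."""
    return sum(-0.5 * math.log(2 * math.pi) - 0.5 * (o - m) ** 2 for o, m in zip(obs, mean))


def bernoulli_log_lik(w: int, prob: float, eps: float = 1e-7) -> float:
    """log Bern(w | prob), with prob clipped away from 0 and 1 for numerical safety."""
    p = min(max(prob, eps), 1 - eps)
    return math.log(p) if w == 1 else math.log(1 - p)


def kl_to_standard_normal(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I))."""
    return -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, logvar))


def negative_elbo(x, w, y, x_mean, w_prob, y_mean, mu, logvar) -> float:
    """Single-sample estimate of -ELBO for one observation (x, w, y).

    x_mean, w_prob, y_mean: decoder outputs at one draw z ~ q(Z | x, w, y).
    mu, logvar: parameters of that q. The reconstruction term is the sum
    log p(x | z) + log p(w | z) + log p(y | w, z).
    """
    recon = (
        gaussian_log_lik(x, x_mean)
        + bernoulli_log_lik(w, w_prob)
        + gaussian_log_lik([y], [y_mean])
    )
    return -(recon - kl_to_standard_normal(mu, logvar))
```

Minimizing this quantity, averaged over a dataset, is the same as maximizing the ELBO. The KL term is zero only when $q$ equals the prior.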

Network Architecture

Encoder: q(Z | X, W, Y)
    (X, W, Y) → μ_z, σ_z → Z ~ N(μ_z, σ_z²)

Decoder:
    Z → p(X | Z): Reconstruction
    Z → p(W | Z): Treatment model
    (Z, W) → p(Y | Z, W): Outcome model

Causal Effect Estimation

CATE Estimation

$$
\hat{\tau}(x) = \mathbb{E}_{q(Z \mid X = x)}\left[\hat{Y}(Z, W=1) - \hat{Y}(Z, W=0)\right]
$$
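In practice the expectation is approximated by Monte Carlo: draw $S$ samples from $q(Z \mid X = x)$ and average the predicted differences. The number of draws $S$ is a free choice. Plugging in only the posterior mean of $Z$ is a cheaper point-estimate variant.

$$
\hat{\tau}(x) \approx \frac{1}{S} \sum_{s=1}^{S} \left[\hat{Y}\big(z^{(s)}, W=1\big) - \hat{Y}\big(z^{(s)}, W=0\big)\right], \qquad z^{(s)} \sim q(Z \mid X = x)
$$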

Algorithm

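  1. Sample $z$ for an observation $x$: $z \sim q(Z \mid X = x)$
  2. Predict treated and control outcomes with the outcome decoder $f$: $\hat{y}_1 = f(z, 1)$, $\hat{y}_0 = f(z, 0)$
  3. CATE: $\hat{\tau} = \hat{y}_1 - \hat{y}_0$
  4. Repeat steps 1 to 3 for several draws of $z$ and average the results to approximate the expectation (or use the posterior mean of $z$ as a single point estimate)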
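The steps above can be sketched as a small Monte Carlo routine. `sample_z` stands in for drawing from $q(Z \mid X = x)$, and `outcome` stands in for the outcome decoder $f(z, w)$. Both are hypothetical callables, not parts of any library. The usage example plugs in toy functions whose exact CATE is known.

```python
import random
from typing import Callable


def mc_cate(
    sample_z: Callable[[random.Random], float],
    outcome: Callable[[float, int], float],
    n_samples: int = 1000,
    seed: int = 0,
) -> float:
    """Monte Carlo CATE for one observation x."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = sample_z(rng)                      # step 1: z ~ q(Z | X = x)
        y1, y0 = outcome(z, 1), outcome(z, 0)  # step 2: predicted treated / control outcomes
        total += y1 - y0                       # step 3: per-draw CATE
    return total / n_samples                   # average over draws


# Hypothetical stand-ins: q(Z | X = x) = N(0.5, 1) and f(z, w) = z + w * (1 + z),
# so the exact CATE here is E[1 + z] = 1.5.
estimate = mc_cate(lambda rng: rng.gauss(0.5, 1.0), lambda z, w: z + w * (1.0 + z), n_samples=20000)
print(round(estimate, 2))
```

If the outcome model has no treatment-by-$Z$ interaction, every draw gives the same difference, and a single sample (or the posterior mean) suffices. With interactions, as in the stand-in above, averaging over draws matters.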

Assumptions

Handling Hidden Confounders

CEVAE attempts to satisfy ignorability by inferring the latent confounder $Z$:

$$
W \perp\!\!\!\perp \left(Y(0), Y(1)\right) \mid Z
$$

Caveats

  • No guarantee that the inferred $Z$ actually captures all confounding
  • Estimates can be biased if the model is misspecified (e.g., wrong likelihood families or a wrong graph structure)
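The toy simulation below illustrates both the ignorability condition and the first caveat. All distributions are hypothetical. Treatment and outcome share a confounder $Z$, and the true effect is 1.5. Three estimates are compared:

- a naive difference in means, which is biased;
- a linear regression adjustment for the true $Z$, which recovers the effect;
- an adjustment for a single noisy proxy $X = Z + \text{noise}$, which removes only part of the bias.

CEVAE must infer $Z$ from proxies like $X$, so how much confounding it can remove depends on how much $X$ reveals about $Z$. Linear regression here is just a transparent adjustment method, not a stand-in for CEVAE. The helper names are illustrative.

```python
import math
import random


def mean(v):
    return sum(v) / len(v)


def ols_slope(u, v):
    """Slope b of the least-squares line v ~ a + b * u."""
    mu, mv = mean(u), mean(v)
    cov = sum((a - mu) * (b - mv) for a, b in zip(u, v))
    var = sum((a - mu) ** 2 for a in u)
    return cov / var


def residualize(v, c):
    """Residuals of v after regressing it on c (with intercept)."""
    b = ols_slope(c, v)
    a = mean(v) - b * mean(c)
    return [vi - a - b * ci for vi, ci in zip(v, c)]


def adjusted_effect(w, y, c):
    """Coefficient on w in y ~ 1 + w + c, via Frisch-Waugh-Lovell (residuals on residuals)."""
    return ols_slope(residualize(w, c), residualize(y, c))


rng = random.Random(0)
n, true_effect = 5000, 1.5
z = [rng.gauss(0.0, 1.0) for _ in range(n)]                                   # hidden confounder
x = [zi + rng.gauss(0.0, 0.5) for zi in z]                                    # one noisy proxy of Z
w = [1.0 if rng.random() < 1 / (1 + math.exp(-2 * zi)) else 0.0 for zi in z]  # Z raises P(W = 1)
y = [true_effect * wi + zi + rng.gauss(0.0, 0.1) for wi, zi in zip(w, z)]     # Z also raises Y

naive = mean([yi for yi, wi in zip(y, w) if wi == 1]) - mean([yi for yi, wi in zip(y, w) if wi == 0])
print(f"naive difference in means: {naive:.2f}")
print(f"adjusted for true Z:       {adjusted_effect(w, y, z):.2f}")
print(f"adjusted for proxy X only: {adjusted_effect(w, y, x):.2f}")
```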

Advantages and Disadvantages

Advantages

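| Advantage | Description |
|---|---|
| Uncertainty quantification | Sampling from the approximate posterior over $Z$ yields a spread of effect estimates |
| Hidden confounders | Attempts to infer latent confounders from observed proxies |
| Generative model | Explicitly models the data-generating mechanism |
| Flexibility | Handles diverse data types by choosing suitable likelihoods (e.g., Bernoulli, Gaussian) |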

Disadvantages

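| Disadvantage | Description |
|---|---|
| Model assumptions | Requires assuming a graphical structure (all confounding flows through $Z$) |
| Identifiability | No guarantee of recovering the true $Z$ |
| Training instability | VAE training can be unstable and sensitive to hyperparameters |
| Computational cost | Several networks (an encoder and three decoders) must be trained jointly |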

Implementation

Python (PyTorch)

import torch
import torch.nn as nn
from torch.distributions import Normal, Bernoulli

class CEVAE(nn.Module):
    def __init__(self, x_dim, z_dim=32, hidden_dim=64):
        super().__init__()

        # Encoder q(Z | X, W, Y)
        self.encoder = nn.Sequential(
            nn.Linear(x_dim + 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()  # nonlinearity before the mean / log-variance heads
        )
        self.z_mean = nn.Linear(hidden_dim, z_dim)
        self.z_logvar = nn.Linear(hidden_dim, z_dim)

        # Decoder p(X | Z)
        self.decoder_x = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, x_dim)
        )

        # Treatment model p(W | Z)
        self.decoder_w = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # Outcome model p(Y | Z, W)
        self.decoder_y = nn.Sequential(
            nn.Linear(z_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def encode(self, x, w, y):
        h = self.encoder(torch.cat([x, w.unsqueeze(-1), y.unsqueeze(-1)], dim=-1))
        return self.z_mean(h), self.z_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

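    def decode(self, z, w):
        x_recon = self.decoder_x(z)  # mean of p(X | Z)
        # squeeze(-1) rather than squeeze() keeps the batch dimension when batch size is 1
        w_prob = self.decoder_w(z).squeeze(-1)  # P(W = 1 | Z)
        y_pred = self.decoder_y(torch.cat([z, w.unsqueeze(-1)], dim=-1)).squeeze(-1)  # mean of p(Y | Z, W)
        return x_recon, w_prob, y_pred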

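    def forward(self, x, w, y):
        mu, logvar = self.encode(x, w, y)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, w), mu, logvar

    def loss(self, x, w, y):
        """Negative ELBO averaged over the batch.

        x: (batch, x_dim) float tensor; w, y: (batch,) float tensors.
        Assumes unit-variance Gaussians for p(X | Z) and p(Y | Z, W) and a
        Bernoulli p(W | Z); uses one reparameterized sample of Z. The auxiliary
        q(W | X) and q(Y | X, W) terms of the original paper are omitted.
        """
        (x_recon, w_prob, y_pred), mu, logvar = self(x, w, y)
        log_px = Normal(x_recon, 1.0).log_prob(x).sum(dim=-1)
        log_pw = Bernoulli(probs=w_prob).log_prob(w)
        log_py = Normal(y_pred, 1.0).log_prob(y)
        # Closed-form KL(q(Z | X, W, Y) || N(0, I)) for a diagonal Gaussian q
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
        return -(log_px + log_pw + log_py - kl).mean()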

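    @torch.no_grad()
    def estimate_cate(self, x):
        """CATE estimation at test time (one estimate per row of x)."""
        n = x.size(0)
        # Crude heuristic: W and Y are unknown for a new subject, so they are
        # imputed with zeros. Louizos et al. (2017) instead infer them with
        # auxiliary networks q(W | X) and q(Y | X, W).
        zeros = torch.zeros(n, device=x.device, dtype=x.dtype)
        mu, _ = self.encode(x, zeros, zeros)
        z = mu  # Use the posterior mean instead of sampling

        ones_col = torch.ones(n, 1, device=x.device, dtype=x.dtype)
        zeros_col = torch.zeros(n, 1, device=x.device, dtype=x.dtype)
        y1 = self.decoder_y(torch.cat([z, ones_col], dim=-1)).squeeze(-1)
        y0 = self.decoder_y(torch.cat([z, zeros_col], dim=-1)).squeeze(-1)

        return y1 - y0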

Related

  • Representation Learning Overview - Unified view of representation learning methods
  • CFR - Distribution-matching-based alternative
  • Hidden Confounders - The problem CEVAE aims to solve
  • Deconfounder - A related latent-variable inference approach

Key Papers

  • Louizos, C., Shalit, U., Mooij, J. M., Sontag, D., Zemel, R., & Welling, M. (2017). Causal Effect Inference with Deep Latent-Variable Models. NeurIPS
  • yaoSurveyCausalInference2021 - Section 3.5.4
