Tae Hyun Kim (Lowell)

CEVAE (Causal Effect Variational Autoencoder)

정의

VAE를 활용하여 잠재 교란변수를 추론하고 인과 효과를 추정하는 방법

Louizos et al. (2017)이 제안.


그래프 모델

생성 모델

    Z (Latent Confounder)
   /|\
  / | \
 ↓  ↓  ↓
 X  W  Y
Zp(Z)Xp(XZ)Wp(WZ)Yp(YW,Z)\begin{align} Z &\sim p(Z) \\ X &\sim p(X \mid Z) \\ W &\sim p(W \mid Z) \\ Y &\sim p(Y \mid W, Z) \end{align}

추론 모델

q(ZX,W,Y)p(ZX,W,Y)q(Z \mid X, W, Y) \approx p(Z \mid X, W, Y)

VAE 프레임워크

Evidence Lower Bound (ELBO)

logp(X,W,Y)Eq(ZX,W,Y)[logp(X,W,YZ)]KL(q(ZX,W,Y)p(Z))\log p(X, W, Y) \geq E_{q(Z|X,W,Y)}[\log p(X, W, Y \mid Z)] - \text{KL}(q(Z \mid X, W, Y) \| p(Z))

네트워크 구조

Encoder: q(Z | X, W, Y)
    (X, W, Y) → μ_z, σ_z → Z ~ N(μ_z, σ_z²)

Decoder:
    Z → p(X | Z): Reconstruction
    Z → p(W | Z): Treatment model
    (Z, W) → p(Y | Z, W): Outcome model

인과 효과 추정

CATE 추정

τ^(x)=Eq(ZX=x)[Y^(Z,W=1)Y^(Z,W=0)]\hat{\tau}(x) = E_{q(Z|X=x)}[\hat{Y}(Z, W=1) - \hat{Y}(Z, W=0)]

알고리즘

  1. 관측 XX에 대해 ZZ 샘플링: zq(ZX)z \sim q(Z \mid X)
  2. 처치/대조 결과 예측: y^1=f(z,1)\hat{y}_1 = f(z, 1), y^0=f(z,0)\hat{y}_0 = f(z, 0)
  3. CATE: τ^=y^1y^0\hat{\tau} = \hat{y}_1 - \hat{y}_0

가정

Hidden Confounder 처리

CEVAE는 잠재 교란변수 ZZ추론하여 Ignorability 만족 시도:

W ⁣ ⁣ ⁣(Y(0),Y(1))ZW \perp\!\!\!\perp (Y(0), Y(1)) \mid Z

주의

  • ZZ가 실제로 모든 교란을 포착하는지 보장 없음
  • 모델 오특정 시 편향 가능

장단점

장점

장점설명
불확실성 정량화Posterior에서 샘플링
Hidden confounder잠재 교란변수 추론 시도
생성 모델데이터 생성 메커니즘 모델링
유연성다양한 데이터 타입 처리

단점

단점설명
모델 가정그래프 구조 가정 필요
식별 가능성ZZ 복구 보장 없음
학습 불안정VAE 학습 어려움
계산 비용복잡한 네트워크

구현

Python (PyTorch)

import torch
import torch.nn as nn
from torch.distributions import Normal, Bernoulli

class CEVAE(nn.Module):
    def __init__(self, x_dim, z_dim=32, hidden_dim=64):
        super().__init__()

        # Encoder q(Z | X, W, Y)
        self.encoder = nn.Sequential(
            nn.Linear(x_dim + 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.z_mean = nn.Linear(hidden_dim, z_dim)
        self.z_logvar = nn.Linear(hidden_dim, z_dim)

        # Decoder p(X | Z)
        self.decoder_x = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, x_dim)
        )

        # Treatment model p(W | Z)
        self.decoder_w = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # Outcome model p(Y | Z, W)
        self.decoder_y = nn.Sequential(
            nn.Linear(z_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def encode(self, x, w, y):
        h = self.encoder(torch.cat([x, w.unsqueeze(-1), y.unsqueeze(-1)], dim=-1))
        return self.z_mean(h), self.z_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, w):
        x_recon = self.decoder_x(z)
        w_prob = self.decoder_w(z).squeeze()
        y_pred = self.decoder_y(torch.cat([z, w.unsqueeze(-1)], dim=-1)).squeeze()
        return x_recon, w_prob, y_pred

    def forward(self, x, w, y):
        mu, logvar = self.encode(x, w, y)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, w), mu, logvar

    def estimate_cate(self, x):
        """CATE estimation at test time"""
        # Encode without Y (approximate)
        mu, _ = self.encode(x, torch.zeros(len(x)), torch.zeros(len(x)))
        z = mu  # Use mean

        y1 = self.decoder_y(torch.cat([z, torch.ones(len(x), 1)], dim=-1)).squeeze()
        y0 = self.decoder_y(torch.cat([z, torch.zeros(len(x), 1)], dim=-1)).squeeze()

        return y1 - y0

관련 개념

  • Representation Learning Overview - 표현 학습 방법 통합
  • CFR - 분포 매칭 기반 대안
  • Hidden Confounders - CEVAE가 해결하려는 문제
  • Deconfounder - 관련 잠재 변수 추론

참고 논문

  • Louizos, C., Shalit, U., Mooij, J. M., Sontag, D., Zemel, R., & Welling, M. (2017). Causal Effect Inference with Deep Latent-Variable Models. NeurIPS
  • yaoSurveyCausalInference2021 - Section 3.5.4

연결 그래프