CEVAE (Causal Effect Variational Autoencoder)
정의
VAE를 활용하여 잠재 교란변수를 추론하고 인과 효과를 추정하는 방법
Louizos et al. (2017)이 제안.
그래프 모델
생성 모델
Z (Latent Confounder)
/|\
/ | \
↓ ↓ ↓
X W Y
추론 모델
VAE 프레임워크
Evidence Lower Bound (ELBO)
네트워크 구조
Encoder: q(Z | X, W, Y)
(X, W, Y) → μ_z, σ_z → Z ~ N(μ_z, σ_z²)
Decoder:
Z → p(X | Z): Reconstruction
Z → p(W | Z): Treatment model
(Z, W) → p(Y | Z, W): Outcome model
인과 효과 추정
CATE 추정
알고리즘
- 관측 에 대해 샘플링:
- 처치/대조 결과 예측: ,
- CATE:
가정
Hidden Confounder 처리
CEVAE는 잠재 교란변수 를 추론하여 Ignorability 만족 시도:
주의
- 가 실제로 모든 교란을 포착하는지 보장 없음
- 모델 오특정 시 편향 가능
장단점
장점
| 장점 | 설명 |
|---|---|
| 불확실성 정량화 | Posterior에서 샘플링 |
| Hidden confounder | 잠재 교란변수 추론 시도 |
| 생성 모델 | 데이터 생성 메커니즘 모델링 |
| 유연성 | 다양한 데이터 타입 처리 |
단점
| 단점 | 설명 |
|---|---|
| 모델 가정 | 그래프 구조 가정 필요 |
| 식별 가능성 | 복구 보장 없음 |
| 학습 불안정 | VAE 학습 어려움 |
| 계산 비용 | 복잡한 네트워크 |
구현
Python (PyTorch)
import torch
import torch.nn as nn
from torch.distributions import Normal, Bernoulli
class CEVAE(nn.Module):
def __init__(self, x_dim, z_dim=32, hidden_dim=64):
super().__init__()
# Encoder q(Z | X, W, Y)
self.encoder = nn.Sequential(
nn.Linear(x_dim + 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
self.z_mean = nn.Linear(hidden_dim, z_dim)
self.z_logvar = nn.Linear(hidden_dim, z_dim)
# Decoder p(X | Z)
self.decoder_x = nn.Sequential(
nn.Linear(z_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, x_dim)
)
# Treatment model p(W | Z)
self.decoder_w = nn.Sequential(
nn.Linear(z_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
nn.Sigmoid()
)
# Outcome model p(Y | Z, W)
self.decoder_y = nn.Sequential(
nn.Linear(z_dim + 1, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def encode(self, x, w, y):
h = self.encoder(torch.cat([x, w.unsqueeze(-1), y.unsqueeze(-1)], dim=-1))
return self.z_mean(h), self.z_logvar(h)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z, w):
x_recon = self.decoder_x(z)
w_prob = self.decoder_w(z).squeeze()
y_pred = self.decoder_y(torch.cat([z, w.unsqueeze(-1)], dim=-1)).squeeze()
return x_recon, w_prob, y_pred
def forward(self, x, w, y):
mu, logvar = self.encode(x, w, y)
z = self.reparameterize(mu, logvar)
return self.decode(z, w), mu, logvar
def estimate_cate(self, x):
"""CATE estimation at test time"""
# Encode without Y (approximate)
mu, _ = self.encode(x, torch.zeros(len(x)), torch.zeros(len(x)))
z = mu # Use mean
y1 = self.decoder_y(torch.cat([z, torch.ones(len(x), 1)], dim=-1)).squeeze()
y0 = self.decoder_y(torch.cat([z, torch.zeros(len(x), 1)], dim=-1)).squeeze()
return y1 - y0
관련 개념
- Representation Learning Overview - 표현 학습 방법 통합
- CFR - 분포 매칭 기반 대안
- Hidden Confounders - CEVAE가 해결하려는 문제
- Deconfounder - 관련 잠재 변수 추론
참고 논문
- Louizos, C., Shalit, U., Mooij, J. M., Sontag, D., Zemel, R., & Welling, M. (2017). Causal Effect Inference with Deep Latent-Variable Models. NeurIPS
- yaoSurveyCausalInference2021 - Section 3.5.4