Representation Learning Overview
개요
처치와 독립적이면서 결과 예측에 유용한 표현(representation)을 학습하는 방법
Mermaid source (click to expand)
> flowchart LR > X[Covariates X] --> Phi[Representation Φ(X)] > Phi --> H1[h₁: Y(1) prediction] > Phi --> H0[h₀: Y(0) prediction] > > subgraph "Regularization" > Phi --> D[Distribution Matching] > end >
핵심 아이디어
도메인 적응 관점
처치군/대조군을 서로 다른 도메인으로 간주
- : 예측 손실
- : 처치/대조 표현 분포 불일치
- : 균형-예측 트레이드오프
목표
- 균형:
- 예측력: 가 예측에 유용
주요 방법
1. CFR (Counterfactual Regression)
IPM(Integral Probability Metric)으로 분포 일치
자세한 내용: CFR
2. CEVAE (Causal Effect VAE)
VAE로 잠재 교란변수 추론
자세한 내용: CEVAE
3. BNN (Balancing Neural Network)
공유 표현 + 분리된 예측 헤드
자세한 내용: BNN
4. GANITE
GAN으로 Counterfactual 생성
- Generator: 반사실적 결과 생성
- Discriminator: 실제/생성 구별
자세한 내용: GANITE
분포 불일치 측정 (IPM)
MMD (Maximum Mean Discrepancy)
- : Reproducing Kernel Hilbert Space
- 커널 선택 필요 (RBF 일반적)
Wasserstein Distance
- Lipschitz 제약 필요
- Gradient penalty로 구현
방법 비교
| 방법 | 접근 | 특징 | 적합 상황 |
|---|---|---|---|
| CFR | IPM regularization | 간단, 확장성 | 기본 선택 |
| CEVAE | VAE + 잠재 변수 | 불확실성, 생성 모델 | 복잡한 DGP |
| BNN | 모멘트 매칭 | 간단한 균형 | 빠른 실험 |
| GANITE | GAN | ITE 직접 추정 | 개인화 필요 |
네트워크 구조
일반 구조
Input X
│
▼
┌─────────────┐
│ Representation │
│ Network │
│ Φ(X) │
└─────────────┘
│
├─────────────┐
▼ ▼
┌─────────┐ ┌─────────┐
│ h₀(Φ) │ │ h₁(Φ) │
│ Control │ │ Treated │
│ Head │ │ Head │
└─────────┘ └─────────┘
손실 함수
장단점
장점
| 장점 | 설명 |
|---|---|
| End-to-end | 전체 파이프라인 최적화 |
| 고차원 | 이미지, 텍스트 등 처리 가능 |
| 유연성 | 다양한 네트워크 구조 적용 |
| 자동 특성 | 수동 특성 공학 불필요 |
단점
| 단점 | 설명 |
|---|---|
| 대규모 데이터 필요 | 딥러닝 특성 |
| 블랙박스 | 해석 어려움 |
| 하이퍼파라미터 | 등 선택 민감 |
| 이론 부족 | 점근적 보장 제한적 |
구현
Python (PyTorch)
import torch
import torch.nn as nn
class CFRNet(nn.Module):
def __init__(self, input_dim, hidden_dim, repr_dim):
super().__init__()
self.representation = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, repr_dim)
)
self.head_0 = nn.Linear(repr_dim, 1)
self.head_1 = nn.Linear(repr_dim, 1)
def forward(self, x, w):
phi = self.representation(x)
y0 = self.head_0(phi)
y1 = self.head_1(phi)
y = w * y1 + (1 - w) * y0
return y, y0, y1, phi
관련 개념
- CFR - Counterfactual Regression
- CEVAE - Causal Effect VAE
- BNN - Balancing Neural Network
- GANITE - GAN for ITE
- Selection Bias - 해결 대상
- HTE - 추정 대상
참고 논문
- Johansson, F. D., Shalit, U., & Sontag, D. (2016). Learning representations for counterfactual inference. ICML
- Shalit, U., Johansson, F. D., & Sontag, D. (2017). Estimating individual treatment effect: Generalization bounds and algorithms. ICML
- yaoSurveyCausalInference2021 - Section 3.5