Tae Hyun Kim (Lowell)

Representation Learning Overview

개요

처치와 독립적이면서 결과 예측에 유용한 표현(representation)을 학습하는 방법

Representation Learning Overview

Mermaid source (click to expand)
> flowchart LR
>     X[Covariates X] --> Phi[Representation Φ(X)]
>     Phi --> H1[h₁: Y(1) prediction]
>     Phi --> H0[h₀: Y(0) prediction]
> 
>     subgraph "Regularization"
>         Phi --> D[Distribution Matching]
>     end
>

핵심 아이디어

도메인 적응 관점

처치군/대조군을 서로 다른 도메인으로 간주

minΦ,hLpred+λdisc(Φ)\min_{\Phi, h} \mathcal{L}_{pred} + \lambda \cdot \text{disc}(\Phi)
  • Lpred\mathcal{L}_{pred}: 예측 손실
  • disc(Φ)\text{disc}(\Phi): 처치/대조 표현 분포 불일치
  • λ\lambda: 균형-예측 트레이드오프

목표

  1. 균형: P(Φ(X)W=1)P(Φ(X)W=0)P(\Phi(X) \mid W=1) \approx P(\Phi(X) \mid W=0)
  2. 예측력: Φ(X)\Phi(X)YY 예측에 유용

주요 방법

1. CFR (Counterfactual Regression)

IPM(Integral Probability Metric)으로 분포 일치

L=Lfactual+λIPM(PΦT,PΦC)\mathcal{L} = \mathcal{L}_{factual} + \lambda \cdot \text{IPM}(P^T_\Phi, P^C_\Phi)

자세한 내용: CFR

2. CEVAE (Causal Effect VAE)

VAE로 잠재 교란변수 추론

ZX,ZW,Z,WYZ \to X, \quad Z \to W, \quad Z, W \to Y

자세한 내용: CEVAE

3. BNN (Balancing Neural Network)

공유 표현 + 분리된 예측 헤드

L=L0+L1+λE[Φ(X)W=1]E[Φ(X)W=0]2\mathcal{L} = \mathcal{L}_0 + \mathcal{L}_1 + \lambda \|E[\Phi(X) \mid W=1] - E[\Phi(X) \mid W=0]\|^2

자세한 내용: BNN

4. GANITE

GAN으로 Counterfactual 생성

  • Generator: 반사실적 결과 생성
  • Discriminator: 실제/생성 구별

자세한 내용: GANITE


분포 불일치 측정 (IPM)

MMD (Maximum Mean Discrepancy)

MMD2=E[ϕ(X)W=1]E[ϕ(X)W=0]H2\text{MMD}^2 = \|E[\phi(X) \mid W=1] - E[\phi(X) \mid W=0]\|^2_{\mathcal{H}}
  • H\mathcal{H}: Reproducing Kernel Hilbert Space
  • 커널 선택 필요 (RBF 일반적)

Wasserstein Distance

W1(P,Q)=supfL1EP[f]EQ[f]W_1(P, Q) = \sup_{\|f\|_L \leq 1} |E_P[f] - E_Q[f]|
  • Lipschitz 제약 필요
  • Gradient penalty로 구현

방법 비교

방법접근특징적합 상황
CFRIPM regularization간단, 확장성기본 선택
CEVAEVAE + 잠재 변수불확실성, 생성 모델복잡한 DGP
BNN모멘트 매칭간단한 균형빠른 실험
GANITEGANITE 직접 추정개인화 필요

네트워크 구조

일반 구조

Input X


┌─────────────┐
│ Representation │
│    Network    │
│    Φ(X)       │
└─────────────┘

    ├─────────────┐
    ▼             ▼
┌─────────┐  ┌─────────┐
│ h₀(Φ)  │  │ h₁(Φ)  │
│ Control │  │ Treated │
│ Head    │  │ Head    │
└─────────┘  └─────────┘

손실 함수

L=i:Wi=0(Yih0(Φ(Xi)))2+i:Wi=1(Yih1(Φ(Xi)))2Factual loss+λdisc(Φ)Balance\mathcal{L} = \underbrace{\sum_{i: W_i=0} (Y_i - h_0(\Phi(X_i)))^2 + \sum_{i: W_i=1} (Y_i - h_1(\Phi(X_i)))^2}_{\text{Factual loss}} + \lambda \cdot \underbrace{\text{disc}(\Phi)}_{\text{Balance}}

장단점

장점

장점설명
End-to-end전체 파이프라인 최적화
고차원이미지, 텍스트 등 처리 가능
유연성다양한 네트워크 구조 적용
자동 특성수동 특성 공학 불필요

단점

단점설명
대규모 데이터 필요딥러닝 특성
블랙박스해석 어려움
하이퍼파라미터λ\lambda 등 선택 민감
이론 부족점근적 보장 제한적

구현

Python (PyTorch)

import torch
import torch.nn as nn

class CFRNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, repr_dim):
        super().__init__()
        self.representation = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, repr_dim)
        )
        self.head_0 = nn.Linear(repr_dim, 1)
        self.head_1 = nn.Linear(repr_dim, 1)

    def forward(self, x, w):
        phi = self.representation(x)
        y0 = self.head_0(phi)
        y1 = self.head_1(phi)
        y = w * y1 + (1 - w) * y0
        return y, y0, y1, phi

관련 개념

  • CFR - Counterfactual Regression
  • CEVAE - Causal Effect VAE
  • BNN - Balancing Neural Network
  • GANITE - GAN for ITE
  • Selection Bias - 해결 대상
  • HTE - 추정 대상

참고 논문

  • Johansson, F. D., Shalit, U., & Sontag, D. (2016). Learning representations for counterfactual inference. ICML
  • Shalit, U., Johansson, F. D., & Sontag, D. (2017). Estimating individual treatment effect: Generalization bounds and algorithms. ICML
  • yaoSurveyCausalInference2021 - Section 3.5

연결 그래프