Blog
GANで因果推論
はじめに
6月中旬からAI LabのAD Econチームでインターンをしている齋藤 (@moshumoshu1205)です. 普段はこちらのブログで記事を書いたりしています.
唐突ですが, 広告や教育, そして医療に至るまである介入が個々の被介入者に対して有する因果効果(個別的因果効果: Individual Treatment Effect)を予測できると嬉しいことがたくさんあります. 例えば, 本当に効果がある人を事前に特定してその人たちにのみ広告を打てるとしたら, 広告配信の費用対効果を改善できます. さらに, 望ましい治癒効果の望まれる人にのみ投薬や治療を施すことができたら, 治療効果の最大化が期待できるでしょう. これらの例にとどまらず多くの分野で, 介入の適切な個別化はKPIの改善に寄与すると考えられるためここ数年で盛んに研究されています.
本記事では, それらの研究の一つであるGAN (Generative Adversarial Nets)を活用して個別的因果効果を予測する手法を提案している研究[Yoon et al. (2018)]の紹介をしたいと思います.
定式化
ここでは, 因果効果予測の問題を簡潔に定式化します. 細かい仮定などは[Yoon et al. (2018)]のSection 2をご覧ください. また本記事では, 単に介入がbinaryである状況のみを記述しています(例えば, ある介入をするかしないかのみに興味がある場合). しかし, 本記事で紹介する手法は介入が多値である場合にも適用できます (介入が3種類以上存在し, それぞれの介入効果に興味がある場合).
\( \mathcal{X} \)を特徴量空間, \( \mathcal{Y} \)を出力空間とします. また介入空間を単に\( \mathcal{T} = \{0, 1\} \)としておきます. それぞれが \( \mathcal{X}, \mathcal{T}, \mathcal{Y} \) 上に値をとる確率変数を \( \boldsymbol{X}, T, \boldsymbol{Y} \)とした時, これらの確率変数が同時確率分布\( \mu \)に従うとしておきます (i.e., \( \left( \boldsymbol{X}, T, \boldsymbol{Y} \right) \sim \mu \)). ここで出力に関する確率変数\( \boldsymbol{Y} \)がベクトルである \( \boldsymbol{Y} = \left( Y^{(0)}, Y^{(1)} \right) \) ことには注意が必要です. これはRubin-Neyman causal modelに従ったnotationで, 介入有無のそれぞれに対応する目的変数(Potential outcomes)がそれぞれ独立に存在することを想定しています. このような想定を立てることで因果効果を記述したり扱いやすくするためのモデルがRubin-Neyman causal modelです. 詳しくは拙著のqiita記事をご覧ください.
さてこのモデル化のもとではあるデータ\( i \)に対するIndividual Treatment Effect (ITE)は次のように定義されます.
$$
\tau_i = Y_i^{(1)} – Y_i^{(0)}
$$
これは介入した時の目的変数と介入しなかった時の目的変数の差分を意味するので, 介入の因果効果の定義として妥当そうです. もしもこのITEの値が観測されているならば, 特徴量\( \boldsymbol{x} \)からITE \( \tau \)を予測するような教師あり学習を行うことで未知データに対するITEの予測値を比較的容易に手に入れることができるでしょう.
しかし事はそう簡単には運びません. なぜならば私たちはいかなるデータについてもそのITEの実現値を観測することはできないからです. これは私たちは各データに対してたかだか一つの介入しか施すことができないことに起因します. 実際に私たちが観測可能な目的変数\( Y^{obs} \)は, \( \left( Y^{(0)}, Y^{(1)} \right) \)と\( T \)を用いて次のように表されます.
$$
Y^{obs} = T Y^{(1)} + (1 – T) Y^{(0)}
$$
これにより, 介入した時(\( T = 1 \))はそれに対応する目的変数\( Y^{(1)} \)のみが, 介入しなかった時(\( T = 0 \))はそれに対応する目的変数\( Y^{(0)} \)のみがそれぞれ観測されるという事実が明確化されるかなと思います. よって私たちが用いることができる訓練データは, \( \mathcal{D} = \{ \boldsymbol{x}_i, t_i, y_i^{obs} \}_{i=1}^{n} \)のような形をしており, この情報のみを用いて\( \tau \)をよく予測するような予測器を作らねばならないというのがこの分野が直面する大きな困難でありかつ面白さでもあるわけです. ここで観測される側の目的変数\( Y^{(T)} \)をfactual outcome, 観測されない側の目的変数\( Y^{(1 – T)} \)をcounterfactual outcomeと呼んだりします.
真のITEが一切観測されない状況で未知データに対すてITEを予測する手法はいくつか存在します. 中でも本記事で紹介するGANITEは, 観測されないcounterfactual outcomeをGANのような考え方に基づき人工的に生成することで, 擬似的なITEを含む訓練データを生成し, それに基づきITEの予測器を学習するというアイデアに基づいています.
GAN
GANITEは, GANのアイデアを活用しているので先立って簡潔にGANの説明をしておきます.
しかしGANに関してはお詳しい方も多いと思うので, 本章は読み飛ばして頂いても構いません.
GANの目標は, ある確率分布\( p_{data} \)を出来るだけよく近似するようなgenerator \( \boldsymbol{G} \)と呼ばれるNetworkを学習することです. Generatorの入力はランダムノイズ\( \boldsymbol{z} \sim p_{\boldsymbol{z}} \)で \( p_{\boldsymbol{z}} \)には一様分布などが用いられます.
より良いgenerator \( \boldsymbol{G} \)を学習するために, GANはもう一つのdiscriminatorと呼ばれるNetwork \( \boldsymbol{D} \)を活用します. generatorはノイズから出来るだけ\( p_{data} \)のサンプルと見なせるような偽のデータを生成しようとしますが, discriminatorはgeneratorが生成したデータと真の分布\( p_{data} \)からサンプルされたデータを精度よく判別できるように学習されます. このように2つのNetworkを敵対的に学習することにより, discriminatorの判別精度が低い所で落ち着いたならばそれはすなわち, generatorが真の分布からのサンプルと遜色ないデータを生成できていることを意味しそうです.
GANは上記のような意図を反映させたmin-max構造を持つ損失関数をgeneratorとdiscriminatorに関して交互に最適化していくという手順を踏みます.
$$
\min_{\boldsymbol{G}} \max_{\boldsymbol{D}} E_{\boldsymbol{x} \sim p_{\text { data }}} \left[ \log \boldsymbol{D}(\boldsymbol{x}) \right]+ E_{\boldsymbol{z} \sim p_{\boldsymbol{z}}} \left[\log (1- \boldsymbol{D}( \boldsymbol{G}(\boldsymbol{z}))) \right]
$$
実際には観測されないcounterfactual outcomeの確率分布をGAN的な損失を最適化することで近似しようというのが[Yoon et al. (2018)]の基本的なアイデアになります.
GANITE
さてここからが本記事の主題になります. [Yoon et al. (2018)]における提案手法であるGenerative Adversarial Nets for inference of Individual Treatment Effects (GANITE)は, 大きく2つのパートから構成されます.
- 実際には観測されないcounterfactual outcomeを生成するパート
- もともと観測されるfactual outcomeと1. で生成されるcounterfactual outcomeを用いてITE予測器を学習するパート
以降この2つのパートをそれぞれ紹介し, そのあとで全体的なアルゴリズムを記述します.
Counterfactual outcomeの生成
このパートは, Counterfactual Generator \( \boldsymbol{G} \) とCounterfactual Discriminator \( \boldsymbol{D_G} \) 間の敵対的学習によって構成されます.
・Counterfactual generator \( \boldsymbol{G} \): 特徴量ベクトル, 介入, 観測される目的変数をそれぞれ\( \boldsymbol{X}, T, Y^{obs} \)とし, \( \boldsymbol{Z}_G \)を\( [-1, 1] \times [-1, 1] \)上に値をとる一様分布に従う確率変数とします. この時, これら4つの変数を入力とする関数を\( g: \mathcal{X} \times \mathcal{T} \times \mathcal{Y} \times [-1, 1]^2 \rightarrow \mathcal{Y} \times \mathcal{Y} \)とします. Counterfactual generator \( \boldsymbol{G} \)はこの関数\( g \)の出力によって定義される確率分布です.
$$
G \left( \boldsymbol{X}, T, Y^{obs} \right) = g \left( \boldsymbol{X}, T, Y^{obs}, \boldsymbol{Z}_G \right)
$$
Counterfactual generatorを得る上での目標は, 確率分布\( \boldsymbol{G} \)がpotential outcomesの周辺分布をよりよく近似するような関数\( g \)を学習することになります. もしそのような\( g \)を得ることができれば, \( \boldsymbol{G} \)によって得られるサンプルが実際には観測されないcounterfactual outcomeを含めた目的変数をよく代替してくれると考えることができそうです.
・Counterfactual discriminator \( \boldsymbol{D_G} \): Counterfactual generatorに従う確率変数を\( \tilde{\boldsymbol{Y}} = \left( \tilde{Y}^{(0)}, \tilde{Y}^{(1)} \right) \sim \boldsymbol{G} \)とします. また\( \tilde{\boldsymbol{Y}} \)のうち, 実際に実現した介入に対応する要素を\( Y^{obs} \)に入れ替えた確率変数を\( \bar{\boldsymbol{Y}} \)で表します (例えば, \( T = 1 \)ならば, \( \bar{\boldsymbol{Y}} = \left( \tilde{Y}^{(0)}, Y^{obs} \right) \)となります. ).
この時, Counterfactual discriminator \( \boldsymbol{D_G}: \mathcal{X} \times \mathcal{Y} \times \mathcal{Y} \rightarrow [0,1] \times [0, 1] \)は, \( \left( \boldsymbol{X}, \bar{\boldsymbol{Y}} \right) \)を入力とした時に, \( \bar{\boldsymbol{Y}} \)のどちらの要素がどれくらいの確率で実際に観測されたfactual outcome (\( Y^{obs} \))なのかを予測します. このCounterfactual discriminatorの予測精度が悪ければ悪いほど実際に観測されたoutcomeとの判別ができないようなcounterfactual outcomeをCounterfactual generatorが生成できていることを意味します.
最終的に\( \boldsymbol{G} \)と\( \boldsymbol{D_G} \)に関する敵対的な学習は次のように定式化されます.
$$
\min_{\boldsymbol{G}} \, \max_{\boldsymbol{D_G}} \; E_{( \boldsymbol{X}, T, Y^{obs} )} \left[ E_{\boldsymbol{Z}_G} \left[ \boldsymbol{T}’ \log \boldsymbol{D_G} \left( \boldsymbol{X}, \bar{\boldsymbol{Y}} \right) + ( \boldsymbol{1} – \boldsymbol{T} )’ \log \left( \boldsymbol{1} – \boldsymbol{D_G}\left( \boldsymbol{X}, \bar{\boldsymbol{Y}} \right) \right) \right] \right]
$$
ただし, \( \boldsymbol{T} \)は\( \boldsymbol{T} = (1 – T, T) \)で表されるベクトルです.
ここでの定式化に従ってCounterfactual generatorを得たら, それによって擬似的なITEを含むデータセット\( \tilde{D} = \{ \boldsymbol{x}_i, t_i, \tilde{y}_i \}_{i=1}^n \)を生成し, ITE予測器の学習に移ります.
ITE予測器の学習
ここでは, 前節のような方法で生成された擬似的なデータセット\( \tilde{D} = \{ \boldsymbol{x}_i, t_i, \tilde{y}_i \}_{i=1}^n \)を用いてITE予測器を生成する方法を紹介します. [Yoon et al. (2018)]はこのITE予測器の学習にもITE generatorとITE discriminatorによる敵対的学習を行う方法を採用しています.
・ITE generator \( \boldsymbol{I} \): 特徴量ベクトル\( \boldsymbol{X} \)と, \( [-1, 1] \times [-1, 1] \)上に値をとる一様分布に従う確率変数\( \boldsymbol{Z}_I \)を入力とする関数を\( h: \mathcal{X} \times [-1, 1]^2 \rightarrow \mathcal{Y} \times \mathcal{Y} \)とします. ITE generator \( \boldsymbol{I} \)はこの関数\( h \)の出力によって定義される確率分布です.
$$
I \left( \boldsymbol{X} \right) = g \left( \boldsymbol{X}, \boldsymbol{Z}_I \right)
$$
ITE generatorを得る上での目標は, 確率分布\( I \)がpotential outcomesの周辺分布をよりよく近似するような関数\( h \)を学習することです.
・ITE discriminator \( \boldsymbol{D_I} \): ここでは, Counterfactual discreminatorの学習とは異なり, (擬似的ではありますが)必要な変数が全て格納されている完全なデータセット\( \tilde{D} \)にアクセスできるので, conditional GANで用いられるようなdiscreminatorの損失を適用します. つまり, ITE discreminator \( \boldsymbol{D_I}: \mathcal{X} \times \mathcal{Y} \times \mathcal{Y} \rightarrow [0,1] \)は, \( \left( \boldsymbol{X}, \boldsymbol{Y}^* \right) \)を入力とした時に, \( \boldsymbol{Y}^* \)が\( \tilde{D} \)からのサンプルなのかはたまたITE generator \( \boldsymbol{I} \)からのサンプルなのかをよく判別できるように学習します.
結局のところ, ITE予測器を得る部分の最適化問題は次のように定式化されます.
$$
\min_{\boldsymbol{I}} \, \max_{\boldsymbol{D_I}} \; E_{\boldsymbol{X}} \left[ E_{\boldsymbol{Y}^* \sim \mu_{\boldsymbol{Y}} ( \boldsymbol{X} ) } \left[\log \boldsymbol{D_I} ( \boldsymbol{X}, \boldsymbol{Y}^* ) \right] + E_{\boldsymbol{Y}^* \sim \boldsymbol{I} ( \boldsymbol{X} ) } \left[\log ( 1 – \boldsymbol{D_I} ( \boldsymbol{X}, \boldsymbol{Y}^* )) \right] \right]
$$
ここで, \( \mu_{\boldsymbol{Y}} ( \boldsymbol{X} ) \)は特徴量ベクトル\( \boldsymbol{X} \)で条件付けた時の\( \boldsymbol{Y} \)の周辺分布です.
Algorithm
さてGANを用いてITE予測器を構築するための考え方をこれまでに述べてきましたが, 具体的には細かい正則化が入っていたりするので, ここではそれらについて説明します.
Counterfactual Block (\( \boldsymbol{G}, \boldsymbol{D_G} \))
Counterfactual generatorとdiscreminatorのmin-max最適化問題をある経験的なサンプルについて定義すると,
$$
V_{CF} \left( \boldsymbol{x}_i, \boldsymbol{t}_i, \bar{\boldsymbol{y}}_i \right)
= \boldsymbol{t}_i’ \log \left( \boldsymbol{D_G} \left(\boldsymbol{x}_i, \bar{\boldsymbol{y}}_i \right) \right) + (\boldsymbol{1} – \boldsymbol{t}_i )’ \log \left( \boldsymbol{1} – \boldsymbol{D_G} \left(\boldsymbol{x}_i, \bar{\boldsymbol{y}}_i \right) \right)
$$
また実際に観測されるfactual outcomeに対するL2損失も定義します. このL2損失を組み込むことにより, Counterfactual generatorが観測されているfactual outcomeに対して一定の構成精度を担保するようにしてあげます. (\( \mathcal{L} \)の下付き文字\( S \)は’supervised’の頭文字です.)
$$
\mathcal{L}^G_S \left( y_i^{obs}, \tilde{y}^{(t_i)}_i \right) = \left(y_i^{obs} – \tilde{y}^{(t_i)}_i \right)^2
$$
最終的にCounterfactual Blockでは次の損失を\( \boldsymbol{G} \)と\( \boldsymbol{D_G} \)で交互に最適化します.
$$
\min_{ \boldsymbol{D_G} } \; – \sum_{i=1}^{k_G} V_{CF} \left( \boldsymbol{x}_i, \boldsymbol{t}_i, \bar{\boldsymbol{y}}_i \right)
$$
$$
\min_{ \boldsymbol{G} }\; \sum_{i=1}^{k_G} \left[ V_{CF} \left( \boldsymbol{x}_i, \boldsymbol{t}_i, \bar{\boldsymbol{y}}_i \right) + \alpha \cdot \mathcal{L}^G_S \left( y_i^{obs}, \tilde{y}_i \right) \right]
$$
ここで, \( k_G \)はバッチサイズ, \( \alpha \)はhyperparameterです.
ITE Block (\( \boldsymbol{I}, \boldsymbol{D_I} \))
ITE generatorとdiscreminatorのmin-max最適化問題をある経験的なサンプルについて定義すると,
$$
V_{ITE} \left( \boldsymbol{x}_i, \bar{\boldsymbol{y}}_i, \hat{\boldsymbol{y}}_i \right) = \log \left( \boldsymbol{D_I} \left(\boldsymbol{x}_i, \bar{\boldsymbol{y}}_i \right) \right) + \log \left(1 – \boldsymbol{D_I} \left(\boldsymbol{x}_i, \hat{\boldsymbol{y}}_i \right) \right)
$$
またcounterfactual blockで生成される擬似的なデータセット\( \tilde{D} \)に含まれるITEに対する直接的なL2損失も定義します.
$$
\mathcal{L}_S^I \left( \bar{\boldsymbol{y}}_i, \hat{\boldsymbol{y}}_i \right) = \left( ( \bar{y}^{(1)} – \bar{y}^{(0)} ) – ( \hat{y}^{(1)} – \hat{y}^{(0)} ) \right)^2
$$
最終的にITE Blockでは次の損失を\( \boldsymbol{I} \)と\( \boldsymbol{D_I} \)で交互に最適化します.
$$
\min_{ \boldsymbol{D_I} } \; – \sum_{i=1}^{k_I} V_{ITE} \left( \boldsymbol{x}_i, \bar{\boldsymbol{y}}_i, \hat{\boldsymbol{y}}_i \right)
$$
$$
\min_{ \boldsymbol{I} }\; \sum_{i=1}^{k_I} \left[ V_{ITE} \left( \boldsymbol{x}_i, \bar{\boldsymbol{y}}_i, \hat{\boldsymbol{y}}_i\right) + \beta \cdot \mathcal{L}^I_S \left( \bar{\boldsymbol{y}}_i, \hat{\boldsymbol{y}}_i \right) \right]
$$
ここで, \( k_I \)はバッチサイズ, \( \beta \)はhyperparameterです.
論文では今まで説明してきたCounterfactual BlockとITE Blockの学習を擬似コードにまとめてくれています ([Yoon et al. (2018)]のAlgorithm 1).
実験結果
最後に論文で行われている実験結果を部分的に紹介します.
datasets & metrics
論文では, 以下の2つのsemi-syntheticデータと1つのreal-worldデータを用いて精度検証を行っています.
IHDP: このデータは, [Hill. (2011)]で提供されているInfant Health and Development Program (IHDP)で集められたデータが元になっています. この実データを元に人工的にpotential outcomes (\( Y^{(0)}, Y^{(1)} \))が生成されているため, これらをgrount truthとして用いることができます. データ数は747ととても小さく, 介入が割り当てられたのがその内139と介入群のデータが少ないのも特徴です. 各データについて25個の特徴量が格納されています.
Twins: Twinsデータは, [Almond et al. (2005)]によって提供された1989-1991の間にアメリカで生まれた子供に関するデータセットです. 本論文ではこのデータのうち出生時体重が2kg以下のTwins(双子)のペアのみを用います. 双子のうち, 出生時体重の軽かった方を統制群(\( t = 0 \)), 重かった方を介入群(\( t = 1 \))とします. 目的変数は生後1年での死亡率であり, 体重の重かった方の生死と軽かった方の生死の差がITEの実現値になるので, これをground truthとして用います. Twinsデータでは, 各データについて30個の特徴量を用いています.
Jobs: Jobsデータは, [Lalonde. (1986)]で用いられた因果推論ではとても有名な職業訓練に関するデータセットです. このデータセットには, RCTによって収集された722のサンプルと2490の観察データが納められています. 各データについて7個の特徴量が格納されています. 本論文では, 職業訓練がその後の失業有無に対して有する因果効果を予測することを目標にしています. このJobsデータは他の2つのデータとは異なり完全なreal-worldデータなので, testデータにおいてもpotential outcomesが未知です. 論文では, RCTで集められたデータのみを用いて次のように計算されるPolicy riskと呼ばれる評価指標により手法の性能を評価しています.
$$
R_{\text{pol}} = 1 – \frac{1}{N} \sum_{i=1}^N \sum_{t’=0}^k \left[ \frac{1}{| \pi_i \cap T_i \cap E |} \sum_{ \boldsymbol{x} \in \pi_{t’} \cap T_{t’} \cap E } y_i^{(t’)} \times \frac{ |\pi_{t’} \cap E| }{|E|} \right]
$$
ここで, \( \pi_{t’} = \{ \boldsymbol{x} : t’ = \arg \max \hat{\boldsymbol{y}} \} \), \( T_{t’} = \{ \boldsymbol{x} : t’ = t\} \), \( E \)はRCTによって収集されたデータの集合です. 若干計算式がややこしいですが, つまりはITE予測値に基づいて決定された介入方策によって導かれるだろう目的変数の期待値の良さを測っています.
results
前節で述べた3つのデータセットにGANITEとその他の比較手法を適用した時の結果は次の通りだったそうです. ([Yoon et al. (2018)]のTable2) ここで\( \epsilon_{PEHE} \)はITEのground truthに対するMSEだと思ってください.
これを見ると, GANITEはその他の手法と比較してIHDPデータセットではあまり良い性能を示しているとは言えませんが, その他の2つのデータセットではもっとも良い性能を示していることがわかります. 筆者らは, IHDPは全てのデータ数が747であるのに対し他の2つのデータセットは数千~数万程度のデータ数があること, GANITEが多くのhyperparameterを含み学習が難しいこと等から, データ数が多い時に提案のGANITEが有効であると主張していました.
さいごに
本記事では, GANのような敵対的学習を用いて訓練データにおいてさえも観測されないIndividual Treatment Effectを予測する手法について紹介しました. 同様の考え方は欠損データ解析などの文脈でも提案されているようなので, これからも動向を注視していきたいと思っています.
References
[Yoon et al. (2018)]: Jinsung Yoon, James Jordon, and Mihaela van der Schaar. GANITE: Estimation of Individualized Treatment Effects using Generative Adversarial Nets. In International Conference on Learning Representations, 2018.
[Hill (2011)]: Bayesian nonparametric modeling for causal inference. Journal of Computational and Graphical Statistics, 20(1):217–240, 2011.
[Almond et al. (2005)]: Douglas Almond, Kenneth Y Chay, and David S Lee. The costs of low birth weight. The Quarterly Journal of Economics, 120(3):1031–1083, 2005.
[Lalonde (1986)]: Robert J LaLonde. Evaluating the econometric evaluations of training programs with experimental data. The American economic review, pp. 604–620, 1986.
- アイキャッチ画像は, [Yoon et al. (2018)]のFigure 1を引用しました.
Author