論理の流刑地

地獄の底を、爆笑しながら闊歩する

rpartと因果木についての覚え書き

すっごいマニアックな備忘録だが、こういうの書いてまとめて自分事として納得しないと秒速で忘却してしまうので。

What's the problem?

(因果)効果の異質性をデータから探索的に見出すためのcausal treeという手法がある。
この手法についての論文も書いてくれている、Susan AtheyさんがRでcausalTreeというパッケージを出してくれている。

最近は日本語での解説記事・論文も増えていてありがたい限りである

じゃあいいじゃん、わかりやすいし面白い手法だし使いないYO!!めでたしめでたし~
と、なればいいのだが、ひとつ難しいことがあって、学習してつくりあげられた因果木をテストデータでためしたときにフィッティングを測るための方法がない(MSEを出してくれない)

Atheyさんは開発者というよりは研究者の人っぽいなのでcausalTreeパッケージもupdateは4年ほど行われておらず、
そこらへんをユーザーフレンドリーな感じにすることにインセンティブもなかろうなので、
(Imbens大先生と共著の論文とかたくさん出してるし、スーパーエリートなのでしょう)

でも、ハイパラの調整をグリッドサーチとかでやるときとか困るので、出し方を調べた覚書。

いきなり寄り道:Rのrpartについての注意点

自分も普段は決定木/回帰木をつかわない人なので、N年ぶりにrpartと(それに依存したcausalTree)に向き合ったのだが、
(Rで決定木を実行するときの王道パッケージのわりには)結構このパッケージ特殊だな、と感じたので注意点を備忘する。

主にrpartはどうやって木をprune(剪定)しているか、という話について。

rpartについては、ながい紹介記事を開発者が公開しているので、それを読めばだいたいの仕様はわかる(はずだが数学音痴の私には以下略)。
戻り値で$cptableってのがあるんだけど、その内容があんまり理解できていなかったので復習。

↑の12pにはcomplex parameterの定義がある。

k個の終端ノードをもつ木Tに対するrisk functionを
R(T) = \Sigma_{i=1}^k P(T_i) R(T_i)と定義し、|T|を終端ノード数とすると、
終端ノード数に応じた罰則を付加した
R_\alpha (T) = R(T) + \alpha |T|
という関数を考えることができる。この\alphaがcomplex parameterで、0~∞までの値をとる。

ここで生成される木Tを\alphaの関数と考える。
すなわち、ある\alphaに対してR_\alpha(T)を最小化するようなTをT_\alphaとあらわすことにする。
T_0はfull modelになるし、T_\infは分割をもたない木(木とは?)になる。

ここで重要な性質として、
\alpha > \betaなら、T_\alpha = T_\betaあるいはT_\alphaT_\betaのstrict sub treeになる。
が導かれる(らしい)
まぁ簡単にいうと、complex parameterを大きくしていくと、
①大きくする前と同じ木が最適な木として生成される
②大きくする前のsub treeが最適な木として生成される
の二通りしかないってことで、要するにnestedな構造になっている、という含意。

だからnestedな木のsequenseを考えると、\alphaのとりうる区間(=0~∞)を以下のように区切ることができて
同一の区間 \alpha \in I_iにおいては同じ区間が生成される。
f:id:ronri_rukeichi:20210522073825p:plain

で、ここまでが長い前置きで、次からpruneのためのcross-validationの話になる。

〈STEP1〉
同じ「最適な木」を生成されるようなcomplex parameterの値域区分I_i
について、以下のような"typical value"を計算することができる

f:id:ronri_rukeichi:20210522074618p:plain

〈STEP2〉
これを計算したうえで、データをs個に分割して、分割により得られた群をG_1, G_2, ...., G_sとする。
それぞれの分割に対して以下を行う。

  1. Full ModelをG_i以外のデータセットに適用する。そのときm個のtypical value(B_i)を用いて、m個の木T_{B_i}をつくる。
  2. 残しておいたG_iにm個の木を適用して予測(分類 or 回帰)を行う。
  3. それぞれの木についての、リスクを計算する

〈STEP3〉
各分割について得られた\beta_jのリスクを足し合わせる。
リスクを最小化するような[\beta]に対応するT_\betaを最適なtrimmed treeとして選択する。
※ただし、\betaとriskの関係には急落したあと停滞して(プラトーに入る)という関係がみられることが多いので、
 CV時にリスクの値だけでなくそのS.E.も計算しておいて、1×S.E.分の変化がない場合はよりsimpleなtreeを最適なtreeとして保持する("1-SE rule")
※ちなみに、rpartではデフォルトではcross-validationにおけるデータ分割はG_iが1つだけのケースして、残りをT_{B_i}生成用にしている。

いやーなかなか複雑なこと、やってるんですね。
まぁとりあえず覚えておくべきなのは、
ここで「cross-validation」と言われているのは、あくまでも木の剪定のためのデータ分割であって、
機械学習の本の「交差検証」の項でやられているような予測精度の算定・改善のための交差検証ではない、ということ。

rpart.controlに指定するxvalって木の剪定に使われるほうの「cross-validation」の制御に使われるだけで、われわれが機械学習の色々な局面で行う「交差検証」をrpartがやってくれるわけではないんですな。

たぶんここらへんが分かりにくいので、海外でも沼に迷い込んでしまった以下のような質問が見つかる
stackoverflow.com
stackoverflow.com

たぶん”cross-validation”っていう言葉を使っているのがいけない(まぎらわしい)んだけど、確かにやってることは交差検証なので、しょうがないのかもなと。

金明哲先生のページでもそこらへんのことは詳解されていないので、日本語を母語とする学習者も割と同じ沼にハマりがちなのではないでしょうか。

うーんめんどくさい。
めんどくさいんだけど、rpartの戻り値はこういう内部処理に関する値までちゃんと格納してくれてたほうがわかりやすいんじゃ?と思います。
いや、使ってる側がこんな文句いうのはアレですが。

本題:Causal TreeにおけるMSEの求め方。

さて、長い前置きが終わり、本題に入る。
causal treeの手法そのものについては上記の参考URLが詳しいので、より実践的でspecificな課題として、
causal treeで生成された木からテストデータへのfittingを求めるにはどうすれば?っていうのを理解したい

前提:因果木のつくられかた

github.com

上の解説資料でも書かれていることだけど、簡単に。

因果木を従来的な決定木/回帰木と比したとき、technicalな面からいって一番特徴的なのが、予測対象の値がグループごとの真の効果\tauが観測不能であることである。

そこで出てくるのが、Honest推定における学習用サンプルの分割という工夫である。
つまり、通常の機械学習ではデータを学習用/評価用、に分割するが、因果木ではさらに学習用データを分割して
木の分割(split)に使うサンプルと、分割のなかでの効果を推定(estimate)するためのサンプルに分けている。

上のbrief intro(pp.6-7)でsplitting criterion functionとして出ているEMSEは以下の形をしている。

f:id:ronri_rukeichi:20210522142725p:plain
CausalTreeのHonest推定におけるsplit基準

ここで、左辺に「-」がついているのが肝で、EMSEを最小化するためには右辺は最大化したい、ってことになる。

右辺第一項は、無理やり言葉で表現すると
「分割用サンプルのすべてのケースについての(推定用サンプルから得られる)係数推定値の二乗の平均」となる。
これを最大化したいってのは、要するに係数推定値の分散を最大化したい、ということ(X→Yの効果の異質性を捉えたい)である。

右辺第二項におけるS^2_{S_{treat/control}^{tr}}は"within-leaf variance"、つまり葉(木の末端ノード)におけるYの分散を基準化したものである。
これを(-1がかけられているので)最小化したい、ということはつまり「ある葉において統制/処置群それぞれYの分散を最小化したい」という基準である。

とりあえずこういった基準をもとに木の分割が行われている、と分かればええのや。

上のintroのp.9からはCVとpruneについての説明(というか基準)が記載されているが、
ここでrpart同様にcausalTreeでのcv.option等のCV関係も基本は剪定のためのパラメータであって、いわゆる我々が想起するような「交差検証」をしているわけではない
※これがわからんかったから,rpartまで話が遡ったわけですね....

本題:テストデータへのフィッティング求める

ここでやっと本題(前置きが長すぎる)が、causal Treeの場合、学習モデルをテストデータに適用しようとしても、
pred()とかではただ「どの葉に属するか」だけしか分からないので、自分で計算する必要がある。

github.com

上のURLの386目あたりで、AtheyさんがMSEを求めるコードを書いてくれている。
ので、それを見てみよう。

 #honest estimation of predicted y
  dataEst$leaff <- as.factor(round(predict(tree_honest_prune_list[[i]], newdata=dataEst,type="vector"),4))
  yPredHonestTable <- melt(tapply(dataEst$y, list(dataEst$leaff, dataEst$w), mean), varnames=c("leaff","w"))
  yPredHonestTable <- rename(yPredHonestTable,replace=c("value"="ypredhon"))
  dataTest <- merge(dataTest,yPredHonestTable, by.x=c("leaff", "w"), by.y=c("leaff","w"))
  MSEy_honest[[i]] = mean((dataTest$ypredhon-dataTest$y)^2, na.rm=T)


詳解すると、以下のようなことをやっている(なにぶん昔のコードなのであんまりtidyな感じになっていなくて読みにくい..)

  1. predict()を使って、テストデータに葉を割り当てる
  2. leaff(予測された葉)と、w(treatment : 0/1, 上の文に合わせるとX)に関して推定用データの平均をとって、それを予測値 \hat{\mu}(l, X)とする
  3. 葉とtreatmentの値をkeyとしてテストデータに予測値を紐づける
  4. テストデータのYの実測値と紐づけた予測値の差の、二乗の平均をとって、MSEを計算する

二番目の工程は、Honest推定を使っているなら、そもそも推定用データ使って条件付き効果を求めているはずなので、いらないはず...と思ったけど、葉の名前として出力されるのは、条件付き効果ではない(葉ごとEMSEの推定値…なのでvariance項を含んでいたりする)なので、やはり推定用サンプルから求めてくる必要はある。

とりあえず、以下のような関数を実装すればいい(コードは省略)
① causalTreeの戻り値、テストデータ、推定用データを使って 推定データに葉を割り付け
② 葉× 処置変数の値ごとに推定データにおけるYの平均値をとることでYの予測値を計算
③ テストデータとYの予測値\hat{y}_iを、{葉、処置変数の値}をkeyとして紐づける
④ テストデータに紐づけたYの予測値\hat{y}_iと実測値y_iの差の二乗の平均を求める(これがMSE)

枝や葉ごとの、「処置群のYの平均値」-「統制群のYの平均値」を推定用データから求めて、それを「本来の因果効果」の代わりとして使う、というイメージですね。


まぁ、とりあえず目的は達成できた。
新しい手法が出るたびに、ちゃんとパッケージの実装の中身も確認しないと駄目だなぁ.....


www.youtube.com


Enjoy!!