すっごいマニアックな備忘録だが、こういうの書いてまとめて自分事として納得しないと秒速で忘却してしまうので。
What's the problem?
(因果)効果の異質性をデータから探索的に見出すためのcausal treeという手法がある。
この手法についての論文も書いてくれている、Susan AtheyさんがRでcausalTreeというパッケージを出してくれている。
- GitHub - susanathey/causalTree: Working repository for Causal Tree and extensions
- Athey and Imbens(2016)
最近は日本語での解説記事・論文も増えていてありがたい限りである
じゃあいいじゃん、わかりやすいし面白い手法だし使いないYO!!めでたしめでたし~
と、なればいいのだが、ひとつ難しいことがあって、学習してつくりあげられた因果木をテストデータでためしたときにフィッティングを測るための方法がない(MSEを出してくれない)
Atheyさんは開発者というよりは研究者の人っぽいなのでcausalTreeパッケージもupdateは4年ほど行われておらず、
そこらへんをユーザーフレンドリーな感じにすることにインセンティブもなかろうなので、
(Imbens大先生と共著の論文とかたくさん出してるし、スーパーエリートなのでしょう)
でも、ハイパラの調整をグリッドサーチとかでやるときとか困るので、出し方を調べた覚書。
いきなり寄り道:Rのrpartについての注意点
自分も普段は決定木/回帰木をつかわない人なので、N年ぶりにrpartと(それに依存したcausalTree)に向き合ったのだが、
(Rで決定木を実行するときの王道パッケージのわりには)結構このパッケージ特殊だな、と感じたので注意点を備忘する。
主にrpartはどうやって木をprune(剪定)しているか、という話について。
rpartについては、ながい紹介記事を開発者が公開しているので、それを読めばだいたいの仕様はわかる(はずだが数学音痴の私には以下略)。
戻り値で$cptableってのがあるんだけど、その内容があんまり理解できていなかったので復習。
↑の12pにはcomplex parameterの定義がある。
k個の終端ノードをもつ木Tに対するrisk functionを
と定義し、を終端ノード数とすると、
終端ノード数に応じた罰則を付加した
という関数を考えることができる。このがcomplex parameterで、0~∞までの値をとる。
ここで生成される木Tをの関数と考える。
すなわち、あるに対してを最小化するようなTをとあらわすことにする。
はfull modelになるし、は分割をもたない木(木とは?)になる。
ここで重要な性質として、
なら、あるいはがのstrict sub treeになる。
が導かれる(らしい)
まぁ簡単にいうと、complex parameterを大きくしていくと、
①大きくする前と同じ木が最適な木として生成される
②大きくする前のsub treeが最適な木として生成される
の二通りしかないってことで、要するにnestedな構造になっている、という含意。
だからnestedな木のsequenseを考えると、のとりうる区間(=0~∞)を以下のように区切ることができて
同一の区間においては同じ区間が生成される。
で、ここまでが長い前置きで、次からpruneのためのcross-validationの話になる。
〈STEP1〉
同じ「最適な木」を生成されるようなcomplex parameterの値域区分
について、以下のような"typical value"を計算することができる
〈STEP2〉
これを計算したうえで、データをs個に分割して、分割により得られた群をとする。
それぞれの分割に対して以下を行う。
- Full Modelを以外のデータセットに適用する。そのときm個のtypical value()を用いて、m個の木をつくる。
- 残しておいたにm個の木を適用して予測(分類 or 回帰)を行う。
- それぞれの木についての、リスクを計算する
〈STEP3〉
各分割について得られたのリスクを足し合わせる。
リスクを最小化するような[\beta]に対応するを最適なtrimmed treeとして選択する。
※ただし、とriskの関係には急落したあと停滞して(プラトーに入る)という関係がみられることが多いので、
CV時にリスクの値だけでなくそのS.E.も計算しておいて、1×S.E.分の変化がない場合はよりsimpleなtreeを最適なtreeとして保持する("1-SE rule")
※ちなみに、rpartではデフォルトではcross-validationにおけるデータ分割はが1つだけのケースして、残りを生成用にしている。
いやーなかなか複雑なこと、やってるんですね。
まぁとりあえず覚えておくべきなのは、
ここで「cross-validation」と言われているのは、あくまでも木の剪定のためのデータ分割であって、
機械学習の本の「交差検証」の項でやられているような予測精度の算定・改善のための交差検証ではない、ということ。
rpart.controlに指定するxvalって木の剪定に使われるほうの「cross-validation」の制御に使われるだけで、われわれが機械学習の色々な局面で行う「交差検証」をrpartがやってくれるわけではないんですな。
たぶんここらへんが分かりにくいので、海外でも沼に迷い込んでしまった以下のような質問が見つかる
stackoverflow.com
stackoverflow.com
たぶん”cross-validation”っていう言葉を使っているのがいけない(まぎらわしい)んだけど、確かにやってることは交差検証なので、しょうがないのかもなと。
金明哲先生のページでもそこらへんのことは詳解されていないので、日本語を母語とする学習者も割と同じ沼にハマりがちなのではないでしょうか。
うーんめんどくさい。
めんどくさいんだけど、rpartの戻り値はこういう内部処理に関する値までちゃんと格納してくれてたほうがわかりやすいんじゃ?と思います。
いや、使ってる側がこんな文句いうのはアレですが。
本題:Causal TreeにおけるMSEの求め方。
さて、長い前置きが終わり、本題に入る。
causal treeの手法そのものについては上記の参考URLが詳しいので、より実践的でspecificな課題として、
causal treeで生成された木からテストデータへのfittingを求めるにはどうすれば?っていうのを理解したい
前提:因果木のつくられかた
上の解説資料でも書かれていることだけど、簡単に。
因果木を従来的な決定木/回帰木と比したとき、technicalな面からいって一番特徴的なのが、予測対象の値がグループごとの真の効果が観測不能であることである。
そこで出てくるのが、Honest推定における学習用サンプルの分割という工夫である。
つまり、通常の機械学習ではデータを学習用/評価用、に分割するが、因果木ではさらに学習用データを分割して
木の分割(split)に使うサンプルと、分割のなかでの効果を推定(estimate)するためのサンプルに分けている。
上のbrief intro(pp.6-7)でsplitting criterion functionとして出ているEMSEは以下の形をしている。
ここで、左辺に「-」がついているのが肝で、EMSEを最小化するためには右辺は最大化したい、ってことになる。
右辺第一項は、無理やり言葉で表現すると
「分割用サンプルのすべてのケースについての(推定用サンプルから得られる)係数推定値の二乗の平均」となる。
これを最大化したいってのは、要するに係数推定値の分散を最大化したい、ということ(X→Yの効果の異質性を捉えたい)である。
右辺第二項におけるは"within-leaf variance"、つまり葉(木の末端ノード)におけるYの分散を基準化したものである。
これを(-1がかけられているので)最小化したい、ということはつまり「ある葉において統制/処置群それぞれYの分散を最小化したい」という基準である。
とりあえずこういった基準をもとに木の分割が行われている、と分かればええのや。
上のintroのp.9からはCVとpruneについての説明(というか基準)が記載されているが、
ここでrpart同様にcausalTreeでのcv.option等のCV関係も基本は剪定のためのパラメータであって、いわゆる我々が想起するような「交差検証」をしているわけではない。
※これがわからんかったから,rpartまで話が遡ったわけですね....
本題:テストデータへのフィッティング求める
ここでやっと本題(前置きが長すぎる)が、causal Treeの場合、学習モデルをテストデータに適用しようとしても、
pred()とかではただ「どの葉に属するか」だけしか分からないので、自分で計算する必要がある。
上のURLの386目あたりで、AtheyさんがMSEを求めるコードを書いてくれている。
ので、それを見てみよう。
#honest estimation of predicted y dataEst$leaff <- as.factor(round(predict(tree_honest_prune_list[[i]], newdata=dataEst,type="vector"),4)) yPredHonestTable <- melt(tapply(dataEst$y, list(dataEst$leaff, dataEst$w), mean), varnames=c("leaff","w")) yPredHonestTable <- rename(yPredHonestTable,replace=c("value"="ypredhon")) dataTest <- merge(dataTest,yPredHonestTable, by.x=c("leaff", "w"), by.y=c("leaff","w")) MSEy_honest[[i]] = mean((dataTest$ypredhon-dataTest$y)^2, na.rm=T)
詳解すると、以下のようなことをやっている(なにぶん昔のコードなのであんまりtidyな感じになっていなくて読みにくい..)
- predict()を使って、テストデータに葉を割り当てる
- leaff(予測された葉)と、w(treatment : 0/1, 上の文に合わせると)に関して推定用データの平均をとって、それを予測値とする
- 葉とtreatmentの値をkeyとしてテストデータに予測値を紐づける
- テストデータのの実測値と紐づけた予測値の差の、二乗の平均をとって、MSEを計算する
二番目の工程は、Honest推定を使っているなら、そもそも推定用データ使って条件付き効果を求めているはずなので、いらないはず...と思ったけど、葉の名前として出力されるのは、条件付き効果ではない(葉ごとEMSEの推定値…なのでvariance項を含んでいたりする)なので、やはり推定用サンプルから求めてくる必要はある。
とりあえず、以下のような関数を実装すればいい(コードは省略)
① causalTreeの戻り値、テストデータ、推定用データを使って 推定データに葉を割り付け
② 葉× 処置変数の値ごとに推定データにおけるの平均値をとることでYの予測値を計算
③ テストデータとYの予測値を、{葉、処置変数の値}をkeyとして紐づける
④ テストデータに紐づけたYの予測値と実測値の差の二乗の平均を求める(これがMSE)
枝や葉ごとの、「処置群のYの平均値」-「統制群のYの平均値」を推定用データから求めて、それを「本来の因果効果」の代わりとして使う、というイメージですね。
まぁ、とりあえず目的は達成できた。
新しい手法が出るたびに、ちゃんとパッケージの実装の中身も確認しないと駄目だなぁ.....
Enjoy!!