この記事について
Stan Advent Calendar 2020 9日目の記事です。本記事で紹介する階層線形モデルは、第4日目に紹介された重回帰モデルや、第5日目に紹介されたロジスティック回帰モデルなどの、一般線形モデルや一般化線形モデルを拡張したモデルになります。Stanで階層線形モデルの確率モデルを書く方法を紹介します。
なおこの記事では階層線形モデルという名称を使用しますが、線形混合モデルやマルチレベルモデルとも本質的には同じです。
一般線形モデルの復習
ここでは説明変数が1つの、単回帰モデルを例とします。
個のデータがあり、目的変数を、説明変数で予測しています。が切片、が回帰係数、が残差です。
これは以下の確率モデルと同義です。
つまり、
というモデルを立てたことになります。
階層線形モデルへの拡張
データの階層性を考慮すべき状況
StanとRでベイズ統計モデリングの121ページから引用すると、階層モデルとは
説明変数だけでは説明がつかない、グループに由来する差異(グループ差)を上手く扱うための一手法
です。
「グループ」というのは例えば、
実験参加者ID | グループ |
---|---|
1 | ◇◇大学 |
2 | ◇◇大学 |
3 | ◇◇大学 |
4 | ▲▲大学 |
5 | ▲▲大学 |
6 | ▲▲大学 |
7 | ☆☆大学 |
8 | ☆☆大学 |
9 | ☆☆大学 |
などです。グループは上表のような「所属集団」に限らず、以下のように個人をグループと見なすことも可能です。下表では、「実験参加者ID」がグループに相当しています。
実験参加者ID | 試行番号 |
---|---|
1 | 1 |
1 | 2 |
1 | 3 |
2 | 1 |
2 | 2 |
2 | 3 |
3 | 1 |
3 | 2 |
3 | 3 |
上の2つの表ではいずれも、9つのデータが得られていることになっています。では、目的変数と説明変数も同様に9つずつデータが得られているとき、
という回帰モデルを考えてよいかというと...そうとは限りません。ここで、データの階層性を考慮しなければならない場合があります。その理由は、以下のスライドの12枚目に詳しいです。
階層線形モデルの確率モデル
上記の回帰モデルを階層線形モデルに拡張したい場合、色々な拡張の方法がありますが、典型的には以下の3通りが考えられます。
- 切片のみに、グループ差を考慮する
- 回帰係数のみに、グループ差を考慮する
- 切片と回帰係数に、グループ差を考慮する
ここでは一例として、「3. 切片と回帰係数に、グループ差を考慮する」場合の階層線形モデルの確率モデルを紹介します。
添え字は一つひとつのデータを、添え字は一つひとつのグループを表します。上の表でいえば(一つ目の表)、添え字は一人一人の実験参加者を、添え字は各大学を表しています。切片パラメータと回帰係数パラメータに添え字が付いて、とになっているので、グループごとにこれらのパラメータが異なることを仮定しています。
では、グループごとに異なると仮定した切片と回帰係数は、どのように、どれくらいグループごとに異なるのでしょうか。それを反映しているのが、
です。つまりグループごとの切片と回帰係数の違いは、正規分布に従うというモデルを立てたことになります。
実演
まずは、パラメータの真値と真のモデルが分かっている状況からサンプルデータを生成します。
library(ggplot2) library(rstan) set.seed(1234) n_data = 100 #総データ数 n_group = 5 #グループ数 n_each_group = n_data / n_group #各グループ内の人数 = 20 mu_0 = rnorm(n = n_group, mean = 30, sd = 10) #切片は、平均30、標準偏差10の正規分布に従う mu_1 = rnorm(n = n_group, mean = 3, sd = 1.5) #回帰係数、平均3、標準偏差1.5の正規分布に従う x = runif(n = n_data, min = 1, max = 30) #説明変数 y = rnorm(n = n_data, mean = (mu_0 + mu_1*x), sd = 15) #目的変数 group = rep(1:n_group, time = n_each_group) #グループを識別する変数
このようなサンプルデータの作り方は、StanとRでベイズ統計モデリングに載っています。
グループごとに回帰直線を描くと、こんな感じです。確かに切片も傾きも、グループごとに異なっています。
ggplot(data = data.frame(x, y, group), mapping = aes(x = x, y = y, color = factor(group))) + geom_point() + geom_smooth(method = "lm", se = FALSE)
この確率モデルを指定してパラメータを推定したら、真値に近くなることが期待できるはずです。やってみましょう。Stanコードは以下です。
方法1:回帰モデルを拡張した書き方
data { int<lower=0> n_data; //データ数 int<lower=0> n_group; //グループ数 real y[n_data]; //目的変数 vector[n_data] x; //説明変数 int group[n_data]; //グループを識別する変数 } parameters { vector[n_group] beta[2]; //beta[1]が各グループの切片 //beta[2]が各グループの回帰係数 real mu[2]; //mu[1]が各グループの切片の平均 //mu[2]が各グループの回帰係数の平均 real<lower=0> sigma[3]; //sigma[1]が各グループの切片の標準偏差 //sigma[2]が各グループの回帰係数の標準偏差 //sigma[3]が目的変数が従う正規分布の標準偏差 } model { beta[1] ~ normal(mu[1], sigma[1]); beta[2] ~ normal(mu[2], sigma[2]); y ~ normal(beta[1, group] + beta[2, group].*x, sigma[3]); }
.*
という演算子が用いられていることに注意してください。これは同じ長さのベクトルを、要素ごとに掛算するための演算子です。
これをmixed_1.stan
という名前で保存して、以下のRコードからサンプリングを行います。
sm_1 = rstan::stan_model("mixed_1.stan") fit_1 = rstan::sampling(object = sm_1, data = list( n_data = n_data, n_group = n_group, y = y, x = x, group = group ), seed = 1234, iter = 4000, warmup = 1000) print(fit_1)
トレースプロットやchainごとの事後分布を視覚的にチェックした限りでは、収束していると判断して良さそうです。
事後分布の要約統計量は以下の通りです。
- グループごとの切片が従う正規分布の平均]は、真値が30のところ、事後中央値は28.27
- グループごとの切片が従う正規分布の標準偏差]は、真値が10のところ、事後中央値は18.26
- グループごとの回帰係数が従う正規分布の平均]は、真値が3のところ、事後中央値は2.86
- グループごとの回帰係数が従う正規分布の標準偏差]は、真値が1.5のところ、事後中央値は1.62
- 目的変数が従う正規分布の標準偏差]は、真値が15のところ、事後中央値は15.22
グループごとの切片が従う正規分布の標準偏差]だけちょっと真値と事後中央値が離れていますが、おおむね良くパラメータリカバリできているように思います。
方法2:多変量正規分布を用いた書き方
上記の階層モデルは、多変量正規分布を用いて、以下のように書くこともできます。多変量正規分布は、Stan Advent Calendar 2020 第6日目の記事でも紹介されています。
※前述の確率モデルに対して、違うStanコードの書き方ができるということではなく、確率モデルも多変量正規分布を用いた別のものに変わっていることに注意してください
data { int<lower=0> n_data; //データ数 int<lower=0> n_group; //グループ数 int<lower=0> n_slope; //グループ差を考慮する回帰係数の数(切片含む) real y[n_data]; //目的変数 matrix[n_data, n_slope] X; //説明変数(切片含む)の行列 int group[n_data]; //グループを識別する変数 } parameters { vector[n_slope] beta[n_group]; //グループごとの、回帰係数(切片含む)のベクトル real<lower=0> sigma; //目的変数が従う正規分布の標準偏差 vector[n_slope] mu; //各回帰係数(切片含む)の平均ベクトル cov_matrix[n_slope] Cov; //多変量正規分布の分散共分散行列 } model { for(n in 1:n_data){ y[n] ~ normal(X[n]*beta[group[n]], sigma); } beta ~ multi_normal(mu, Cov); }
modelブロックのy[n] ~ normal(X[n]*beta[group[n]], sigma)
は、行列の掛算によって、目的変数が従う正規分布の平均を計算しています。重回帰モデルにおける同様の書き方は、Stan Advent Calendar2020の4日目の記事でも紹介されています。
ポイントは、各回帰係数(切片を含む)がvector[n_slope] beta[n_group]
で宣言されているところです。回帰係数の数(切片を含む)の要素を持つbeta
というパラメータが、グループの数(n_group
)だけあると宣言しています。
modelブロックのbeta ~ multi_normal(mu, Cov)
は、各回帰係数(切片を含む)beta
が、多変量正規分布に従うことを仮定しています。多変量正規分布のパラメータは、以下の2つです。
- 平均ベクトル
mu
- 分散共分散行列
Cov
要するに、それぞれの回帰係数は同時に正規分布に従っており、かつ互いに相関があるというモデルになっています。
ただし多変量正規分布の分散共分散行列を、quad_form_diag()
というStanの関数を用いて以下のように定義すると、サンプリング効率がよくなるようです。
data { int<lower=0> n_data; //データ数 int<lower=0> n_group; //グループ数 int<lower=0> n_slope; //グループ差を考慮する回帰係数の数(切片含む) real y[n_data]; //目的変数 matrix[n_data, n_slope] X; //説明変数(切片含む)の行列 int group[n_data]; //グループを識別する変数 } parameters { vector[n_slope] beta[n_group]; //グループごとの、回帰係数(切片含む)のベクトル real<lower=0> sigma; //目的変数が従う正規分布の標準偏差 vector[n_slope] mu; //各回帰係数(切片含む)の平均ベクトル corr_matrix[n_slope] Omega; //それぞれの正規分布同士の相関行列 vector<lower=0>[n_slope] tau; //それぞれの正規分布の標準偏差 } model { for(n in 1:n_data){ y[n] ~ normal(X[n]*beta[group[n]], sigma); } beta ~ multi_normal(mu, quad_form_diag(Omega, tau)); Omega ~ lkj_corr(2); //相関行列の弱情報事前分布。 }
Stan モデリング言語: ユーザーガイド・リファレンスマニュアル(日本語翻訳版)によればquad_form_diag()
という関数は、
quad_form_diag(Sigma,tau)
がdiag_matrix(tau) * Sigma * diag_matrix(tau)
と等価になるように定義されています. ここで,diag_matrix_(tau)
は, 対角成分がtau
となり, それ以外が0の行列を返します
という働きを持ちます(上記のStanコードでは、Sigma
というパラメータをOmega
という名前で宣言しています)。
これで本当に分散共分散行列が作れることを確かめてみましょう。今回のように、グループ差を仮定している回帰係数が2つの場合(切片と、説明変数の回帰係数の2つ)、以下のような計算になります。
対角成分は、標準偏差に相当するの2乗なので、分散を表します。非対角成分は、標準偏差の積に、相関係数を掛けているので、共分散を表します。確かに、分散共分散行列が作れていますね。
それでは上記のStanコードをmixed_2.stan
という名前で保存し、以下のRコードを実行してみましょう。使用しているサンプルデータは、mixed_1.stan
に渡したものと同じです。
model.matrix()
は、引数にformula(回帰式)を指定すると、説明変数を行列形式でまとめてくれる関数です。今回は切片も回帰係数とみなしているので、全ての値が1の説明変数を追加する必要があるため、X = model.matrix(y ~ 1 + x)
という書き方にしています(ただしX = model.matrix(y ~ x)
でも同じです)。
X = model.matrix(y ~ 1 + x) # = の左側は大文字のX、右側は小文字のxであることに注意 fit_2 = rstan::sampling(object = sm_2, data = list( n_data = n_data, n_group = n_group, n_slope = ncol(X), y = y, X = X, # = の左右ともに、大文字のXであることに注意 group = group ), seed = 1234, iter = 4000, warmup = 1000)
パラメータ数が多いので、トレースプロットやチェインごとの事後分布は掲載しませんが、収束していると判断して良さそうです。事後分布の要約統計量は以下の通りです。
- グループごとの切片が従う正規分布の平均]は、真値が30のところ、事後中央値は28.15
- グループごとの切片が従う正規分布の標準偏差]は、真値が10のところ、事後中央値は19.41
- グループごとの回帰係数が従う正規分布の平均]は、真値が3のところ、事後中央値は2.85
- グループごとの回帰係数が従う正規分布の標準偏差]は、真値が1.5のところ、事後中央値は1.67
- 目的変数が従う正規分布の標準偏差は、真値が15のところ、事後中央値は15.20
相変わらず、グループごとの切片が従う正規分布の標準偏差]だけちょっと真値と事後中央値が離れていますが、おおむね良くパラメータリカバリできていますね。
まとめ
いかがだったでしょうか
パラメータに階層性を仮定したモデルは、ちょっとだけ書き方が難しくなりますが、基本的には回帰モデルの拡張と考えられます。実際のデータ分析では、階層性を考慮したほうがよい状況も多いので、参考になれば幸いです。
Enjoy!