Google Colaboratory + PyStanを使ってみた話
この記事について
Stan Advent Calendar 2020 第14日目の記事です。
14日目のカレンダーが空いていたので、何か埋めなきゃということで、「そういえば使ったことないから、Python環境でStanを使ってみよう」という思い付きで書いています。
普段はR環境でStanを使っているので、全てが手探りです。そんなわけで、この記事は自分のための備忘録です。世の中にはもっと体系的にまとまった記事がたくさんあるので、ぜひそれらを参考にしてください。
今回はGoogle Colaboratory上で、PyStanを用いることにします。なお今回はGoogle Driveのマウントはしませんが、Google ColaboratoryをGoogle Driveと連携させることで、より使いやすくなると思います。
実演
.stanファイルを用意する
これはRやPythonとは関係のない作業なので*1、好きなエディタを用いて書きます。ここでは説明変数が1つの、単回帰モデルを例とします。
個のデータがあり、目的変数
を、説明変数
で予測しています。
が切片、
が回帰係数、
が残差です。
これは以下の確率モデルと同義です。
つまり、
というモデルを立てたことになります。以下のStanコードを、regression.stan
という名前で保存しておきます。
data { int<lower=0> n_data; //サンプルサイズ real y[n_data]; //目的変数 vector[n_data] x; //説明変数 } parameters { real beta_0; //切片 real beta_1; //回帰係数 real<lower=0> sigma; //標準偏差 } model { y ~ normal(beta_0 + beta_1*x, sigma); }
Google Colaboratory上に.stanファイルをアップロードする
自分のPCにPython環境がなくても、インターネットブラウザ上でPythonコードが実行できる、Google Colaboratoryを使用します。ここにアクセスして、[ファイル → ノートブックを新規作成]します。
ローカルの.stanファイルは、以下の手順で容易にアップロードできます。
「注: アップロードしたファイルはランタイムのリサイクル時に削除されます。」というメッセージが現れると思いますが、これはGoogle Colaboratoryを使う以上仕方のないことです。
アップロードされた場所は、対象のファイル名にカーソルを合わせて、右端の「・・・」をクリックすると現れるメニューのうち、「パスをコピー」を選択するとクリップボードに格納されるので、知ることができます。実際にやってみると、/content/regression.stan
であることが分かります。
ライブラリの読み込み
最低限、これらのライブラリを読み込めば良いと思います。PyStanのサンプリング結果を、トレースプロットやチェインごとの事後分布といった形で可視化するためには、arviz
というライブラリが適しているようですが、Google Colaboratoryではまず!pip install arviz
でインストールする必要があります。
import numpy as np #サンプルデータを作成する際に、randomモジュールで乱数生成するため import pystan #stanのPython用インタフェース # 可視化用ライブラリ ---------------------- %matplotlib inline import matplotlib.pyplot as plt import seaborn as sns # PyStanのサンプリング結果の可視化用ライブラリ ------------------ !pip install arviz # arvizというライブラリをまずインストールする import arviz
サンプルデータの作成
単回帰モデルを真のモデルとして、パラメータの真値を指定してサンプルデータを作成します。
n = 50 #サンプルサイズ np.random.seed(seed = 1234) #乱数のシード x = np.random.randint(low = 1, high = 30, size = n) #説明変数。1~29の一様乱数(整数)をn個生成 y = np.random.normal(loc = 20 + 1.5 * x, scale = 3.0, size = n) #目的変数。正規分布に従い、その平均は20 + 1.5x。標準偏差は3.0
可視化してみます。
# サンプルデータの可視化
sns.jointplot(x = x, y = y)
.stanファイルのコンパイルとサンプリング
まずはアップロード済みの.stanファイルをコンパイルします。
sm = pystan.StanModel(file = '/content/regression.stan')
次に、サンプリングです。Rだと、rstan::sampling()
関数の中で、object = sm
とコンパイルしたオブジェクトを指定しますが、Pythonだとちょっと書き方が違うことに注意です。
また、Rでは、rstan::sampling()
関数の中で、list型でデータを渡しますが、PyStanでは辞書型でデータを渡す必要があるところも違いますね。辞書型のデータはdict()
などで作成できます。
その他の引数はRStanでもPyStanでもほとんど同じですね。
fit = sm.sampling( data = dict( n_data = n, y = y, x = x ), seed = 1234, iter = 2000, warmup = 1000, chains = 4)
収束診断とサンプリング結果の出力
まずはうまく収束したかどうか、視覚的に確認してみましょう。ここではarviz
ライブラリのplot_trace()
を用いて、トレースプロットやチェインごとの事後分布を可視化します。
arviz.plot_trace(fit)
収束していると考えてよさそうです。
要約統計量は以下の通りです。うまくパラメータリカバリできていますね。
print(fit)
おわり
PyStan、初めて使ってみたんですが、Google Colaboratoryを併用することで、環境構築の手間もほとんどいらず、手軽に利用することが出来ました。
Enjoy !!
*1:もっとも、RStudioの文法チェック機能が優れているので、慣れないうちはRStudio上で書くのが良いような気はします