Google Colaboratory + PyStanを使ってみた話

この記事について

Stan Advent Calendar 2020 第14日目の記事です。
14日目のカレンダーが空いていたので、何か埋めなきゃということで、「そういえば使ったことないから、Python環境でStanを使ってみよう」という思い付きで書いています。

普段はR環境でStanを使っているので、全てが手探りです。そんなわけで、この記事は自分のための備忘録です。世の中にはもっと体系的にまとまった記事がたくさんあるので、ぜひそれらを参考にしてください。

今回はGoogle Colaboratory上で、PyStanを用いることにします。なお今回はGoogle Driveのマウントはしませんが、Google ColaboratoryをGoogle Driveと連携させることで、より使いやすくなると思います。

実演

.stanファイルを用意する

これはRやPythonとは関係のない作業なので*1、好きなエディタを用いて書きます。ここでは説明変数が1つの、単回帰モデルを例とします。


y_i=\beta_0+\beta_1x_i+\varepsilon_i \qquad i = 1, 2, ..., n \\
\varepsilon_i \sim {\rm Normal}(0, \sigma)

 n個のデータがあり、目的変数 yを、説明変数 x_1で予測しています。 \beta_0が切片、 \beta_1が回帰係数、 \varepsilonが残差です。
これは以下の確率モデルと同義です。


y_i \sim {\rm Normal}(\beta_0+\beta_1x_i, \sigma) \qquad i = 1, 2, ..., n \\

つまり、

目的変数 y_iは、正規分布に従う。その平均は \beta_0+\beta_1x_iで、標準偏差 \sigmaである

というモデルを立てたことになります。以下のStanコードを、regression.stanという名前で保存しておきます。

data {
  int<lower=0> n_data;   //サンプルサイズ
  real y[n_data];        //目的変数
  vector[n_data]  x;     //説明変数
}

parameters {
  real beta_0;           //切片
  real beta_1;           //回帰係数
  real<lower=0> sigma;   //標準偏差
}

model {
  y ~ normal(beta_0 + beta_1*x, sigma);
}

Google Colaboratory上に.stanファイルをアップロードする

自分のPCにPython環境がなくても、インターネットブラウザ上でPythonコードが実行できる、Google Colaboratoryを使用します。ここにアクセスして、[ファイル → ノートブックを新規作成]します。

f:id:das_Kino:20201211103021p:plain

ローカルの.stanファイルは、以下の手順で容易にアップロードできます。 f:id:das_Kino:20201211103808p:plain

「注: アップロードしたファイルはランタイムのリサイクル時に削除されます。」というメッセージが現れると思いますが、これはGoogle Colaboratoryを使う以上仕方のないことです。

アップロードされた場所は、対象のファイル名にカーソルを合わせて、右端の「・・・」をクリックすると現れるメニューのうち、「パスをコピー」を選択するとクリップボードに格納されるので、知ることができます。実際にやってみると、/content/regression.stanであることが分かります。

f:id:das_Kino:20201211104119p:plain

ライブラリの読み込み

最低限、これらのライブラリを読み込めば良いと思います。PyStanのサンプリング結果を、トレースプロットやチェインごとの事後分布といった形で可視化するためには、arvizというライブラリが適しているようですが、Google Colaboratoryではまず!pip install arvizでインストールする必要があります。

import numpy as np #サンプルデータを作成する際に、randomモジュールで乱数生成するため
import pystan #stanのPython用インタフェース

# 可視化用ライブラリ ----------------------
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

# PyStanのサンプリング結果の可視化用ライブラリ ------------------
!pip install arviz # arvizというライブラリをまずインストールする
import arviz

サンプルデータの作成

単回帰モデルを真のモデルとして、パラメータの真値を指定してサンプルデータを作成します。

n = 50 #サンプルサイズ

np.random.seed(seed = 1234) #乱数のシード
x = np.random.randint(low = 1, high = 30, size = n) #説明変数。1~29の一様乱数(整数)をn個生成
y = np.random.normal(loc = 20 + 1.5 * x,
                     scale = 3.0,
                     size = n) #目的変数。正規分布に従い、その平均は20 + 1.5x。標準偏差は3.0

可視化してみます。

# サンプルデータの可視化
sns.jointplot(x = x, y = y) 

f:id:das_Kino:20201211110044p:plain

.stanファイルのコンパイルとサンプリング

まずはアップロード済みの.stanファイルをコンパイルします。

sm = pystan.StanModel(file = '/content/regression.stan')

次に、サンプリングです。Rだと、rstan::sampling()関数の中で、object = smコンパイルしたオブジェクトを指定しますが、Pythonだとちょっと書き方が違うことに注意です。

また、Rでは、rstan::sampling()関数の中で、list型でデータを渡しますが、PyStanでは辞書型でデータを渡す必要があるところも違いますね。辞書型のデータはdict()などで作成できます。

その他の引数はRStanでもPyStanでもほとんど同じですね。

fit = sm.sampling(
    data = dict(
        n_data = n,
        y = y,
        x = x
        ),
    seed = 1234,
    iter = 2000,
    warmup = 1000,
    chains = 4)

収束診断とサンプリング結果の出力

まずはうまく収束したかどうか、視覚的に確認してみましょう。ここではarvizライブラリのplot_trace()を用いて、トレースプロットやチェインごとの事後分布を可視化します。

arviz.plot_trace(fit)

f:id:das_Kino:20201211111419p:plain

収束していると考えてよさそうです。
要約統計量は以下の通りです。うまくパラメータリカバリできていますね。

print(fit)

f:id:das_Kino:20201211111202p:plain

おわり

PyStan、初めて使ってみたんですが、Google Colaboratoryを併用することで、環境構築の手間もほとんどいらず、手軽に利用することが出来ました。

Enjoy !!

*1:もっとも、RStudioの文法チェック機能が優れているので、慣れないうちはRStudio上で書くのが良いような気はします