見出し画像

StanとRでベイズ統計モデリングをPyMC Ver.5で写経~第7章「7.3 非線形の関係」

第7章「回帰分析の悩みどころ」

書籍の著者 松浦健太郎 先生


この記事は、テキスト第7章「回帰分析の悩みどころ」の7.3節「非線形の関係」の PyMC5写経 を取り扱います。
「二次曲線の例」と指数が関係する「時系列データの例」の2例に取り組みます。

はじめに


StanとRでベイズ統計モデリングの紹介

この記事は書籍「StanとRでベイズ統計モデリング」(共立出版、「テキスト」と呼びます)のベイズモデルを用いて、PyMC Ver.5で「実験的」に写経する翻訳的ドキュメンタリーです。

テキストは、2016年10月に発売され、ベイズモデリングのモデル式とプログラミングに関する丁寧な解説とモデリングの改善ポイントを網羅するチュートリアル「実践解説書」です。もちろん素晴らしいです!
アヒル本」の愛称で多くのベイジアンに愛されてきた書籍です!

テキストに従ってStanとRで実践する予定でしたが、RのStan環境を整えることができませんでした(泣)
そこでこのシリーズは、テキストのベイズモデルをPyMC Ver.5に書き換えて実践します。

引用表記

この記事は、出典に記載の書籍に掲載された文章及びコードを引用し、適宜、掲載文章とコードを改変して書いています。
【出典】
「StanとRでベイズ統計モデリング」初版第13刷、著者 松浦健太郎、共立出版

記事中のイラストは、「かわいいフリー素材集いらすとや」さんのイラストをお借りしています。
ありがとうございます!

PyMC環境の準備

Anacondaを用いる環境構築とGoogle ColaboratoryでPyMCを動かす方法について、次の記事にまとめています。
「PyMCを動かすまでの準備」章をご覧ください。


7.3 非線形の関係


インポート

### インポート

# 数値・確率計算
import pandas as pd
import numpy as np

# PyMC
import pymc as pm
import pytensor.tensor as pt
import arviz as az

# 描画
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'Meiryo'

# ワーニング表示の抑制
import warnings
warnings.simplefilter('ignore')

二次曲線の例

データの読み込み

サンプルコードのデータを読み込みます。

### データの読み込み ◆data-aircon.txt
# Y:ある家庭の夏から冬にかけての1日あたりのエアコン消費電力(kWh),
# X:屋外の平均気温(℃)

data = pd.read_csv('./data/data-aircon.txt')
print('data.shape: ', data.shape)
display(data.head())

【実行結果】

データの要約統計量と相関係数を算出します。

### 要約統計量の表示
data.describe().round(2)

【実行結果】

### 相関係数の表示
data.corr().round(3)

【実行結果】

散布図を描画します。
テキスト図7.5左に相当します。

### 散布図の描画 ◆図7.5左

# 散布図の描画
sns.lmplot(data=data, x='X', y='Y',
           scatter_kws={'s': 80, 'alpha': 0.4},
           line_kws={'color': 'tomato','alpha': 0.4})
# 修飾
plt.title('放物線状のプロット')
plt.grid(lw=0.5)
plt.show()

【実行結果】
回帰直線を入れてみましたが、まったく線形にフィットしていないことが分かります。

モデルの構築

Y の予測に用いる X の値を設定します。

### Yの予測分布に用いるXの値の設定
X_new = np.linspace(-3, 32, 60)

モデルの定義です。

### モデルの定義 ◆モデル式7-3

with pm.Model() as model1:
    
    ### データ関連定義
    ## coordの定義
    model1.add_coord('data', values=data.index, mutable=True)
    model1.add_coord('dataNew', values=range(len(X_new)), mutable=True)
    ## dataの定義
    # 目的変数 Y
    Y = pm.ConstantData('Y', value=data['Y'].values, dims='data')
    # 説明変数 X
    X = pm.ConstantData('X', value=data['X'].values, dims='data')
    # 予測用の説明変数 XNew
    XNew = pm.ConstantData('XNew', value=X_new, dims='dataNew')

    ### 事前分布
    a = pm.Uniform('a', lower=-1000, upper=1000)
    b = pm.Uniform('b', lower=-100, upper=100)
    x0 = pm.Uniform('x0', lower=0, upper=30)  # 快適温度:範囲を限定
    sigma = pm.Uniform('sigma', lower=0, upper=1000)

    ### mu
    mu = pm.Deterministic('mu', a + b * pt.pow(X - x0, 2), dims='data')
    
    ### 尤度関数
    obs = pm.Normal('obs', mu=mu, sigma=sigma, observed=Y, dims='data')

    ### 計算値
    yNew = pm.Normal('yNew', mu=a + b * pt.pow(XNew - x0, 2), sigma=sigma,
                     dims='dataNew')

【ポイント】
尤度の正規分布の平均パラメータ mu に用いる二次曲線の数式です。
$${Y}$$は消費電力、$${X}$$は屋外の平均気温、$${x_0}$$は快適温度です。

$$
Y = a + b\ (X - x_0)^2
$$

モデルの定義内容を見ます。

### モデルの表示
model1

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model1)

【実行結果】

MCMCを実行します。

### 事後分布からのサンプリング 30秒
with model1:
    idata1  = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.8,
                        nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata1        # idata名
threshold = 1.01         # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
var_names = ['a', 'b', 'sigma', 'x0', 'mu']
pm.summary(idata1, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
var_names = ['a', 'b', 'sigma', 'x0', 'mu', 'yNew']
pm.plot_trace(idata1, compact=True, var_names=var_names)
plt.tight_layout();

【実行結果】

パラメータの事後統計量の要約を算出します。

### パラメータの要約を確認

# mean,sd,2.5%,25%,50%,75%,97.5%パーセンタイル点をデータフレーム化する関数の定義
def make_stats_df(y):
    probs = [2.5, 25, 50, 75, 97.5]
    columns = ['mean', 'sd'] + [str(s) + '%' for s in probs]
    quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
    tmp_df = pd.concat([y.mean(axis=0), y.std(axis=0), quantiles], axis=1)
    tmp_df.columns=columns
    return tmp_df

# 要約統計量の算出・表示
vars = ['a', 'b', 'sigma', 'x0']
param_samples = idata1.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(2))

【実行結果】
快適温度$${X_0}$$は$${18^{\circ} \text{C}}$$程度のようです。

テキスト図7.5の散布図&予測分布を描画します。

### 散布図の描画 ◆図7.5

## 描画用データの作成
# MCMCサンプリングデータからyNewを取り出し
y_pred_samples = idata1.posterior.yNew.stack(sample=('chain','draw')).data
# y_predの中央値、95%CI、50%CIを算出
y_pred_median = np.median(y_pred_samples, axis=1)
y_pred_95ci = np.quantile(y_pred_samples, q=[0.025, 0.975], axis=1)
y_pred_50ci = np.quantile(y_pred_samples, q=[0.250, 0.750], axis=1)

## 描画処理
# 描画領域の設定
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), sharex=True, sharey=True)

# 左の散布図の描画:観測値のみ
sns.scatterplot(ax=ax1, data=data, x='X', y='Y', s=80, alpha=0.3)
ax1.set(title='y の観測値のプロット', )
ax1.grid(lw=0.5)

# 右の散布図の描画:観測値と予測値
sns.scatterplot(ax=ax2, data=data, x='X', y='Y', s=80, alpha=0.3, label='観測値')
ax2.plot(X_new, y_pred_median, color='tab:red', label='予測値(中央値)')
ax2.fill_between(X_new, y_pred_95ci[0], y_pred_95ci[1], color='tomato',
                 alpha=0.2, label='95%CI')
ax2.fill_between(X_new, y_pred_50ci[0], y_pred_50ci[1], color='tomato',
                 alpha=0.5, label='50%CI')
ax2.set(title='y の観測値と予測値のプロット')
ax2.grid(lw=0.5)
ax2.legend(bbox_to_anchor=(1.43, 1))
plt.tight_layout();

【実行結果】

パラメータを描画しましょう。
まずは事後分布プロットから。

### 事後分布プロットの描画
var_names = ['a', 'b', 'sigma', 'x0']
pm.plot_posterior(idata1, hdi_prob=0.95, var_names=var_names, round_to=3,
                  figsize=(10, 3))
plt.tight_layout();

【実行結果】

次はフォレストプロットです。

### フォレストプロットの描画
var_names = ['a', 'b', 'sigma', 'x0']
pm.plot_forest(idata1, combined=True, hdi_prob=0.95, var_names=var_names,
               figsize=(5, 3))
plt.axvline(0, color='tab:red', ls='--')
plt.grid(lw=0.3);

【実行結果】

時系列データの例

データの読み込み

サンプルコードのデータを読み込みます。

### データの読み込み ◆data-conc.txt
# Time: 薬剤投与からの経過時間(hour)
# Y: 薬の血中濃度(mg/mL)

data2 = pd.read_csv('./data/data-conc.txt')
print('data2.shape: ', data2.shape)
display(data2)

【実行結果】

データの要約統計量を算出します。

### 要約統計量の表示
data2.describe().round(2)

【実行結果】

時系列プロットをを描画します。
テキスト図7.6左に相当します。

### 時系列プロットの描画 ◆図7.6左

# 描画領域の設定
ax = plt.subplot()
ax.plot(data2.Time, data2.Y, '-o')
ax.set(title='放物線状のプロット', xlabel='$Time$ (hour)', ylabel='$Y$',
       xticks=data2.Time.values, ylim=(-1, 16), yticks=(0, 5, 10, 15))
plt.grid(lw=0.5)
plt.show()

【実行結果】
対数のグラフのように見えます。

モデルの構築

Y の予測に用いる Time の値を設定します。

### Yの予測分布に用いるTimeの値の設定
Time_new = np.linspace(0, 24, 60)

モデルの定義です。

### モデルの定義 ◆モデル式7-4 model7-4.stan

with pm.Model() as model2:
    
    ### データ関連定義
    ## coordの定義
    model2.add_coord('data', values=data2.index, mutable=True)
    model2.add_coord('dataNew', values=range(len(Time_new)), mutable=True)
    ## dataの定義
    # 目的変数 Y
    Y = pm.ConstantData('Y', value=data2['Y'].values, dims='data')
    # 説明変数 Time
    Time = pm.ConstantData('Time', value=data2['Time'].values, dims='data')
    # 予測用の説明変数 XNew
    TimeNew = pm.ConstantData('XNew', value=Time_new, dims='dataNew')

    ### 事前分布
    a = pm.Uniform('a', lower=0, upper=100)
    b = pm.Uniform('b', lower=0, upper=5)           # 範囲に縛り
    sigma = pm.Uniform('sigma', lower=0, upper=10)  # 範囲に縛り

    ### mu
    mu = pm.Deterministic('mu', a * (1 - pt.exp(-b * Time)), dims='data')
    
    ### 尤度関数
    obs = pm.Normal('obs', mu=mu, sigma=sigma, observed=Y, dims='data')

    ### 計算値
    yNew = pm.Normal('yNew', mu=a * (1 - pt.exp(-b * TimeNew)), sigma=sigma,
                     dims='dataNew')

【ポイント】
尤度の正規分布の平均パラメータ mu に用いる二次曲線の数式です。
$${y}$$は薬の血中濃度、$${t}$$は投与からの経過時間です。

$$
y = a\ \{1 - \exp\ (-b\ t)\}
$$

この曲線は$${t=0}$$のときに$${y=0}$$です。
$${t}$$が大きくなるにつれて$${y}$$は急激に大きくなり、$${y=a}$$で頭打ちになります。
テキストでは、$${a}$$を頭打ちの大きさを決めるパラメータ、$${b}$$を頭打ちになるまでの時間を決めるパラメータ、と説明しています。

モデルの定義内容を見ます。

### モデルの表示
model2

【実行結果】

### モデルの可視化
pm.model_to_graphviz(model2)

【実行結果】

MCMCを実行します。

### 事後分布からのサンプリング 20秒
with model2:
    idata2  = pm.sample(draws=1000, tune=1000, chains=4, target_accept=0.85,
                        nuts_sampler='numpyro', random_seed=1234)

【実行結果】省略

Pythonで事後分布からのサンプリングデータの確認を行います。
Rhatの確認から。
テキストの収束条件は「chainを3以上にして$${\hat{R}<1.1}$$のとき」です。

### r_hat>1.1の確認
# 設定
idata_in = idata2        # idata名
threshold = 1.1          # しきい値

# しきい値を超えるR_hatの個数を表示
print((az.rhat(idata_in) > threshold).sum())

【実行結果】
収束条件を満たしています。

事後統計量を表示します。

### 推論データの要約統計情報の表示
var_names = ['a', 'b', 'sigma', 'mu']
pm.summary(idata2, hdi_prob=0.95, var_names=var_names, round_to=3)

【実行結果】

トレースプロットを描画します。

### トレースプロットの表示
var_names = ['a', 'b', 'sigma', 'mu', 'yNew']
pm.plot_trace(idata1, compact=True, var_names=var_names)
plt.tight_layout();

【実行結果】

パラメータの事後統計量の要約を算出します。

### パラメータの要約を確認

# mean,sd,2.5%,25%,50%,75%,97.5%パーセンタイル点をデータフレーム化する関数の定義
def make_stats_df(y):
    probs = [2.5, 25, 50, 75, 97.5]
    columns = ['mean', 'sd'] + [str(s) + '%' for s in probs]
    quantiles = pd.DataFrame(np.percentile(y, probs, axis=0).T, index=y.columns)
    tmp_df = pd.concat([y.mean(axis=0), y.std(axis=0), quantiles], axis=1)
    tmp_df.columns=columns
    return tmp_df

# 要約統計量の算出・表示
vars = ['a', 'b', 'sigma']
param_samples = idata1.posterior[vars].to_dataframe().reset_index(drop=True)
display(make_stats_df(param_samples).round(2))

【実行結果】
頭打ちの大きさ$${a}$$は$${14}$$前後のようです。

テキスト図7.6の散布図&予測分布を描画します。

### 散布図の描画 ◆図7.6

## 描画用データの作成
# MCMCサンプリングデータからyNewを取り出し
y_pred_samples2 = idata2.posterior.yNew.stack(sample=('chain','draw')).data
# y_predの中央値、95%CI、50%CIを算出
y_pred_median2 = np.median(y_pred_samples2, axis=1)
y_pred_95ci2 = np.quantile(y_pred_samples2, q=[0.025, 0.975], axis=1)
y_pred_50ci2 = np.quantile(y_pred_samples2, q=[0.250, 0.750], axis=1)

## 描画処理
# 描画領域の設定
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), sharex=True, sharey=True)

# 左の時系列プロットの描画:観測値のみ
ax1.plot(data2.Time, data2.Y, '-o')
ax1.set(title='y の観測値のプロット', xlabel='$Time$ (hour)', ylabel='$Y$',
       xticks=data2.Time.values, ylim=(-1, 17), yticks=(0, 5, 10, 15))
ax1.grid(lw=0.5)

# 右の時系列プロットの描画:観測値と予測値
ax2.plot(data2.Time, data2.Y, 'o', label='観測値')
ax2.plot(Time_new, y_pred_median2, color='tab:red', label='予測値(中央値)')
ax2.fill_between(Time_new, y_pred_95ci2[0], y_pred_95ci2[1], color='tomato',
                 alpha=0.2, label='95%CI')
ax2.fill_between(Time_new, y_pred_50ci2[0], y_pred_50ci2[1], color='tomato',
                 alpha=0.5, label='50%CI')
ax2.set(title='y の観測値と予測値のプロット', xlabel='$Time$ (hour)', ylabel='$Y$')
ax2.grid(lw=0.5)
ax2.legend(bbox_to_anchor=(1.43, 1))
plt.tight_layout();

【実行結果】

パラメータを描画しましょう。
まずは事後分布プロットから。

### 事後分布プロットの描画
var_names = ['a', 'b', 'sigma']
pm.plot_posterior(idata2, hdi_prob=0.95, var_names=var_names, round_to=3,
                  figsize=(10, 3))
plt.tight_layout();

【実行結果】

次はフォレストプロットです。

### フォレストプロットの描画
var_names = ['a', 'b', 'sigma']
pm.plot_forest(idata2, combined=True, hdi_prob=0.95, var_names=var_names,
               figsize=(5, 3))
plt.axvline(0, color='tab:red', ls='--')
plt.grid(lw=0.3);

【実行結果】

テキスト図7.7の時系列における非線形曲線の例を描画します。

### 時系列における非線形曲線の例 ◆図7.7

## 描画用データの作成
# 時間tの値の設定
t_val = np.linspace(0, 5, 1001)
# Model1のyの算出
y_model1 = 2 * np.exp(-t_val)
# Model2のyの算出
y_model2 = 1.8 / (1 + 50 * np.exp(-2 * t_val))
# Model3のyの算出
y_model3 = 8 * (np.exp(-t_val) - np.exp(-2 * t_val))

## 描画処理
# 描画領域の設定
ax = plt.subplot()
# Model1の曲線の描画
ax.plot(t_val, y_model1, color='tab:blue', label=r'1: $y=2\exp(-t)$')
# Model2の曲線の描画
ax.plot(t_val, y_model2, color='tab:red', ls='--',
        label=r'2: $y=1.8/\{1 + 50\exp(-2t)\}$')
# Model3の曲線の描画
ax.plot(t_val, y_model3, color='tab:green', ls='-.',
        label=r'3: $y=8\{\exp(-t) - \exp(-2t)\}$')
# 修飾
ax.set(xlabel='$Time$', ylabel='$y$', title='時系列における非線形曲線の例',
       yticks=np.arange(0, 2.1, 0.5))
ax.legend(bbox_to_anchor=(1.55, 1), title='Model')
ax.grid(lw=0.2);

【実行結果】

7.3 節は以上です。


シリーズの記事

次の記事

前の記事

目次

ブログの紹介


note で7つのシリーズ記事を書いています。
ぜひ覗いていってくださいね!

1.のんびり統計

統計検定2級の問題集を手がかりにして、確率・統計をざっくり掘り下げるブログです。
雑談感覚で大丈夫です。ぜひ覗いていってくださいね。
統計検定2級公式問題集CBT対応版に対応しています。
Python、EXCELのサンプルコードの配布もあります。

2.実験!たのしいベイズモデリング1&2をPyMC Ver.5で

書籍「たのしいベイズモデリング」・「たのしいベイズモデリング2」の心理学研究に用いられたベイズモデルを PyMC Ver.5で描いて分析します。
この書籍をはじめ、多くのベイズモデルは R言語+Stanで書かれています。
PyMCの可能性を探り出し、手軽にベイズモデリングを実践できるように努めます。
身近なテーマ、イメージしやすいテーマですので、ぜひぜひPyMCで動かして、一緒に楽しみましょう!

3.実験!岩波データサイエンス1のベイズモデリングをPyMC Ver.5で

書籍「実験!岩波データサイエンスvol.1」の4人のベイジアンによるベイズモデルを PyMC Ver.5で描いて分析します。
この書籍はベイズプログラミングのイロハをざっくりと学ぶことができる良書です。
楽しくPyMCモデルを動かして、ベイズと仲良しになれた気がします。
みなさんもぜひぜひPyMCで動かして、一緒に遊んで学びましょう!

4.楽しい写経 ベイズ・Python等

ベイズ、Python、その他の「書籍の写経活動」の成果をブログにします。
主にPythonへの翻訳に取り組んでいます。
写経に取り組むお仲間さんのサンプルコードになれば幸いです🍀

5.RとStanではじめる心理学のための時系列分析入門 を PythonとPyMC Ver.5 で

書籍「RとStanではじめる心理学のための時系列分析入門」の時系列分析をPythonとPyMC Ver.5 で実践します。
この書籍には時系列分析のテーマが盛りだくさん!
時系列分析の懐の深さを実感いたしました。
大好きなPythonで楽しく時系列分析を学びます。

6.データサイエンスっぽいことを綴る

統計、データ分析、AI、機械学習、Pythonのコラムを不定期に綴っています。
統計・データサイエンス書籍にまつわる記事が多いです。
「統計」「Python」「数学とPython」「R」のシリーズが生まれています。

7.Python機械学習プログラミング実践記

書籍「Python機械学習プログラミング PyTorch & scikit-learn編」を学んだときのさまざまな思いを記事にしました。
この書籍は、scikit-learnとPyTorchの教科書です。
よかったらぜひ、お試しくださいませ。

最後までお読みいただきまして、ありがとうございました。

この記事が参加している募集

仕事について話そう

この記事が気に入ったらサポートをしてみませんか?