見出し画像

Stable Baselines 3 入門 (2) - Monitor

「Stable Baselines 3」の「Monitor」の使い方をまとめました。

・Python 3.8.12
・Stable Baselines 1.6.0
・gym 0.21.0

前回

1. Monitor

「Monitor」は、「報酬」(r)「エピソード長」(l)「時間」(t)をログ出力するためのラッパーです。使い方は、EnvをMonitorでラップするだけです。

import gym
import os
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor

# ログフォルダの準備
log_dir = './logs/'
os.makedirs(log_dir, exist_ok=True)

# 学習環境の準備
env = gym.make('CartPole-v1')
env = Monitor(env, log_dir, allow_early_resets=True)  # Monitorの利用

# モデルの準備
model = PPO('MlpPolicy', env, verbose=1)

# 学習の実行
model.learn(total_timesteps=128000)

# 推論の実行
state = env.reset()
while True:
    # 学習環境の描画
    env.render()

    # モデルの推論
    action, _ = model.predict(state, deterministic=True)

    # 1ステップ実行
    state, rewards, done, info = env.step(action)

    # エピソード完了
    if done:
        break

# 学習環境の解放
env.close()

実行すると、logフォルダに以下のログが出力されます。

・monitor.csv

#{"t_start": 1661003542.954774, "env_id": "CartPole-v1"}
r,l,t
12.0,12,0.138155
46.0,46,0.155652
11.0,11,0.159838
    :

2. グラフのプロット

グラフのプロットには、「qt」が必要なのでインストールします。

$ pip install PyQt5

グラフをプロットするコードは、次のとおりです。

・monitor_plot.py

import pandas as pd
import matplotlib.pyplot as plt

# monitor.csvの読み込み
df = pd.read_csv('./logs/monitor.csv', names=['r', 'l','t'])
df = df.drop(range(2)) # 1〜2行目の削除

# 報酬のプロット
x = range(len(df['r']))
y = df['r'].astype(float)
plt.plot(x, y)
plt.xlabel('episode')
plt.ylabel('reward')
plt.show()

# エピソード長のプロット
x = range(len(df['l']))
y = df['l'].astype(float)
plt.plot(x, y)
plt.xlabel('episode')
plt.ylabel('episode len')
plt.show()

実行すると、グラフがプロットされます。

$ python monitor_plot.py



この記事が気に入ったらサポートをしてみませんか?