見出し画像

Stable Baselines入門 / エージェントの保存と読み込み

1. エージェントの保存と読み込み

エージェントの保存と読み込みを行います。
保存はagent.save()、読み込みはPPO2.load()を使います。

import gym
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

# 環境の生成
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])

# エージェントの生成
agent = PPO2(MlpPolicy, env, verbose=1)

# エージェントの学習
agent.learn(total_timesteps=10000)

# エージェントの保存
agent.save("sample")

# エージェントの削除
del agent

# エージェントの読み込み
agent = PPO2.load("sample")

# テスト
state = env.reset()
for i in range(200):
   action, _ = agent.predict(state)
   state, reward, done, info = env.step(action)
   env.render()


この記事が気に入ったらサポートをしてみませんか?