Stable Baselines入門 / エージェントの保存と読み込み
1. エージェントの保存と読み込み
エージェントの保存と読み込みを行います。
保存はagent.save()、読み込みはPPO2.load()を使います。
import gym
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2
# 環境の生成
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])
# エージェントの生成
agent = PPO2(MlpPolicy, env, verbose=1)
# エージェントの学習
agent.learn(total_timesteps=10000)
# エージェントの保存
agent.save("sample")
# エージェントの削除
del agent
# エージェントの読み込み
agent = PPO2.load("sample")
# テスト
state = env.reset()
for i in range(200):
action, _ = agent.predict(state)
state, reward, done, info = env.step(action)
env.render()
この記事が気に入ったらサポートをしてみませんか?