見出し画像

Stable Baselines入門 / Atariゲーム

1.Atariゲーム

Atariゲームの環境には、専用の環境設定を行うmake_atari_env()が用意されています。さらに、フレームスキップを行うために、VecFrameStackが用意されています。

◎make_atari_env()
make_atari_env()は、Atari用のラップされた監視対象のSubprocVecEnvの生成するメソッドです。

make_atari_env(env_id, num_env, seed, wrapper_kwargs=None,
   start_index=0, allow_early_resets=True, start_method=None)
   env_id: (str) 環境ID
   num_env: (int) 環境数
   seed: (int) 乱数シード
   wrapper_kwargs: (dict) wrap_deepmind関数のパラメータ
   start_index: (int) 開始インデックス
   allow_early_resets: (bool) 環境の早期リセットの有効
   start_method: (str) サブプロセスを開始するために使用するメソッド
   return: (Gym Environment) Atari環境

◎VecFrameStack
VecFrameStackは、環境にフレームスキップを追加するラッパーです。

Atariでは1秒間に画面が60回更新されます。毎フレーム行動選択するのは、計算コストの面で効率的ではありません。そこで、4フレームに1回行動選択するようにします。スキップするフレームでは前回取った行動をリピートするようにします。

Atariゲーム「PongNoFrameskip-v4」を学習させるコードは、次の通りです。

from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.policies import CnnPolicy
from stable_baselines.common.vec_env import VecFrameStack
from stable_baselines.common.policies import MlpPolicy
from stable_baselines import PPO2

# 環境の生成
env = make_atari_env('PongNoFrameskip-v4', num_env=4, seed=0)
env = VecFrameStack(env, n_stack=4)

# エージェントの生成
agent = PPO2(MlpPolicy, env, verbose=1)

# エージェントの学習
agent.learn(total_timesteps=25000)

# テスト
state = env.reset()
for i in range(200):
   env.render()
   action, _ = agent.predict(state)
   state, reward, done, info = env.step(action)

画像1


この記事が気に入ったらサポートをしてみませんか?