見出し画像

PettingZoo 入門 (1) - マルチエージェント強化学習環境セット

マルチエージェント強化学習環境セット「PettingZoo」の基本的な使い方をまとめました。

・Python 3.8.12
・Stable Baselines 1.6.0
・PettingZoo 1.20.1

1. PettingZoo

「PettingZoo」は、マルチエージェント強化学習環境セットです。「OpenAI Gym」のマルチエージェント版のようなものになります。

2. Stable Baselines 3とPettingZooのインストール

「Stable Baselines 3」と「PettingZoo」のインストールのインストール手順は、次のとおりです。

(1) Pythonの仮想環境を準備。
「Python 3.7以降」をインストールします。

WindowsでのPythonの開発環境の準備
MacでのPythonの開発環境の準備

(2) 「Stable Baselines 3」のインストール。

$ pip install 'stable-baselines3[extra]'

(3) 「PettingZoo」のインストール。
今回は「Pistonball」を使うので「Butterfly」をインストールします。

$ pip install 'pettingzoo[butterfly]'

(4) 「SuperSuit」のインストール。
「SuperSuite」は、強化学習環境をラップして前処理を行う小さな関数群を提供するパッケージです。「OpenAI Gym」と「PettingZoo」をサポートしています。

$ pip install supersuit

4. Pistonballの学習

「Pistonball」はカードを左右移動させて、棒を倒さないようにバランスをとるゲームです。

(1) 「Pistonball」の学習および推論を行うコードの作成。

・train_pistonball.py

import supersuit as ss
from pettingzoo.butterfly import pistonball_v6
from stable_baselines3 import PPO
from stable_baselines3.ppo import CnnPolicy

# 学習環境の準備
env = pistonball_v6.parallel_env(
    n_pistons=20,
    time_penalty=-0.1,
    continuous=True,
    random_drop=True,
    random_rotate=True,
    ball_mass=0.75,
    ball_friction=0.3,
    ball_elasticity=1.5,
    max_cycles=125,
)

# 学習環境の前処理の付加
env = ss.color_reduction_v0(env, mode="B")
env = ss.resize_v1(env, x_size=84, y_size=84)
env = ss.frame_stack_v1(env, 3)
env = ss.pettingzoo_env_to_vec_env_v1(env)
env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3")

# モデルの準備
model = PPO(
    CnnPolicy,
    env,
    verbose=3,
    gamma=0.95,
    n_steps=256,
    ent_coef=0.0905168,
    learning_rate=0.00062211,
    vf_coef=0.042202,
    max_grad_norm=0.9,
    gae_lambda=0.99,
    n_epochs=5,
    clip_range=0.3,
    batch_size=256,
)

# 学習の実行
model.learn(total_timesteps=2000000)

# モデルの保存
model.save("policy")

# 推論用の学習環境の準備
env = pistonball_v6.env()
env = ss.color_reduction_v0(env, mode="B")
env = ss.resize_v1(env, x_size=84, y_size=84)
env = ss.frame_stack_v1(env, 3)

# モデルの読み込み
model = PPO.load("policy")

# 推論の実行
env.reset()
for agent in env.agent_iter():
    # 観察、報酬、エピソード完了、情報の取得
    obs, reward, done, info = env.last()

    # モデルの推論
    act = model.predict(obs, deterministic=True)[0] if not done else None

    # 1ステップ実行
    env.step(act)

    # 学習状況の描画
    env.render()

(2) 「Pistonball」の学習および推論を行うコードの実行。

$ python train_pistonball.py

学習中は、学習状況のログが出力されます。

UsUsing cpu device
Wrapping the env in a VecTransposeImage.
------------------------------
| time/              |       |
|    fps             | 1471  |
|    iterations      | 1     |
|    time_elapsed    | 27    |
|    total_timesteps | 40960 |
------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 526         |
|    iterations           | 2           |
|    time_elapsed         | 155         |
|    total_timesteps      | 81920       |
| train/                  |             |
|    approx_kl            | 0.015285267 |
|    clip_fraction        | 0.0431      |
|    clip_range           | 0.3         |
|    entropy_loss         | -1.51       |
|    explained_variance   | 0.00631     |
|    learning_rate        | 0.000622    |
|    loss                 | 0.112       |
|    n_updates            | 5           |
|    policy_gradient_loss | 0.00325     |
|    std                  | 1.14        |
|    value_loss           | 12.4        |
-----------------------------------------

           :

-----------------------------------------
| time/                   |             |
|    fps                  | 351         |
|    iterations           | 49          |
|    time_elapsed         | 5706        |
|    total_timesteps      | 2007040     |
| train/                  |             |
|    approx_kl            | 0.007980399 |
|    clip_fraction        | 0.033       |
|    clip_range           | 0.3         |
|    entropy_loss         | -5.3        |
|    explained_variance   | 0.0196      |
|    learning_rate        | 0.000622    |
|    loss                 | 1.59        |
|    n_updates            | 240         |
|    policy_gradient_loss | 0.00187     |
|    std                  | 49          |
|    value_loss           | 53.2        |
-----------------------------------------

学習後は、学習済みモデルで推論が実行され、画面表示で動作確認できます。

次回





この記事が気に入ったらサポートをしてみませんか?