
Unity Project "ML-Agents: Penguin" Summary (Extra)

0. Introduction

This page gathers the download links for the software used in the Unity project "ML-Agents: Penguin," along with the programs and other data.
It is long, but the full source of every program is included. Feel free to compare it against the code you wrote yourself, download it as a replacement, or use it however you like.

1. Software

・Unity

・Anaconda

・ML-Agents (ver 0.13.1)

・Assets for ML-Agents: Penguin

・Visual Studio 2017
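
For reference, here is one way to set up the matching Python side; a minimal sketch, assuming an Anaconda environment named ml-agents (the environment name and Python version are just examples):

# Create and activate a Python environment for training
conda create -n ml-agents python=3.7
conda activate ml-agents

# Install the ML-Agents Python package matching the Unity-side version
pip install mlagents==0.13.1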


2. Programs

・PenguinAcademy.cs

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents;

public class PenguinAcademy : Academy
{
   /// Gets/sets the current fish speed
   public float FishSpeed { get; private set; }

   /// Gets/sets the current acceptable feed radius
   public float FeedRadius { get; private set; }
   
   /// Called when the academy first gets initialized
   public override void InitializeAcademy()
   {
       FishSpeed = 0f;
       FeedRadius = 0f;

       // Set up code to be called every time the fish_speed parameter changes 
       // during curriculum learning
       FloatProperties.RegisterCallback("fish_speed", f =>
       {
           FishSpeed = f;
       });

       // Set up code to be called every time the feed_radius parameter changes 
       // during curriculum learning
       FloatProperties.RegisterCallback("feed_radius", f =>
       {
           FeedRadius = f;
       });
   }
}
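
Instead of registering callbacks, the same values can also be read on demand; a minimal sketch, assuming the FloatProperties API of ML-Agents 0.13:

// Read a curriculum parameter directly, falling back to a default
// value if curriculum learning has not set it yet
float currentFishSpeed = FloatProperties.GetPropertyWithDefault("fish_speed", 0f);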

・PenguinArea.cs

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents;
using TMPro;

public class PenguinArea : Area
{
   [Tooltip("The agent inside the area")]
   public PenguinAgent penguinAgent;

   [Tooltip("The baby penguin inside the area")]
   public GameObject penguinBaby;

   [Tooltip("The TextMeshPro text that shows the cumulative reward of the agent")]
   public TextMeshPro cumulativeRewardText;

   [Tooltip("Prefab of a live fish")]
   public Fish fishPrefab;

   private PenguinAcademy penguinAcademy;
   private List<GameObject> fishList;

   /// Reset the area, including fish and penguin placement
   public override void ResetArea()
   {
       RemoveAllFish();
       PlacePenguin();
       PlaceBaby();
       SpawnFish(4, penguinAcademy.FishSpeed);
   }

   /// Remove a specific fish from the area when it is eaten
   /// <param name="fishObject">The fish to remove</param>
   public void RemoveSpecificFish(GameObject fishObject)
   {
       fishList.Remove(fishObject);
       Destroy(fishObject);
   }

   /// The number of fish remaining
   public int FishRemaining
   {
       get { return fishList.Count; }
   }


   /// Choose a random position on the X-Z plane within a partial donut shape
   /// <param name="center">The center of the donut</param>
   /// <param name="minAngle">Minimum angle of the wedge</param>
   /// <param name="maxAngle">Maximum angle of the wedge</param>
   /// <param name="minRadius">Minimum distance from the center</param>
   /// <param name="maxRadius">Maximum distance from the center</param>
   /// <returns>A position falling within the specified region</returns>
   public static Vector3 ChooseRandomPosition(Vector3 center, float minAngle, float maxAngle, float minRadius, float maxRadius)
   {
       float radius = minRadius;
       float angle = minAngle;

       if (maxRadius > minRadius)
       {
           // Pick a random radius
           radius = UnityEngine.Random.Range(minRadius, maxRadius);
       }

       if (maxAngle > minAngle)
       {
           // Pick a random angle
           angle = UnityEngine.Random.Range(minAngle, maxAngle);
       }

       // Center position + forward vector rotated around the Y axis by "angle" degrees, multiplied by "radius"
       return center + Quaternion.Euler(0f, angle, 0f) * Vector3.forward * radius;
   }

   /// Remove all fish from the area
   private void RemoveAllFish()
   {
       if (fishList != null)
       {
           for (int i = 0; i < fishList.Count; i++)
           {
               if (fishList[i] != null)
               {
                   Destroy(fishList[i]);
               }
           }
       }

       fishList = new List<GameObject>();
   }

   /// Place the penguin in the area
   private void PlacePenguin()
   {
       Rigidbody rigidbody = penguinAgent.GetComponent<Rigidbody>();
       rigidbody.velocity = Vector3.zero;
       rigidbody.angularVelocity = Vector3.zero;
       penguinAgent.transform.position = ChooseRandomPosition(transform.position, 0f, 360f, 0f, 9f) + Vector3.up * .5f;
       penguinAgent.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
   }

   /// Place the baby in the area
   private void PlaceBaby()
   {
       Rigidbody rigidbody = penguinBaby.GetComponent<Rigidbody>();
       rigidbody.velocity = Vector3.zero;
       rigidbody.angularVelocity = Vector3.zero;
       penguinBaby.transform.position = ChooseRandomPosition(transform.position, -45f, 45f, 4f, 9f) + Vector3.up * .5f;
       penguinBaby.transform.rotation = Quaternion.Euler(0f, 180f, 0f);
   }

   /// Spawn some number of fish in the area and set their swim speed
   /// <param name="count">The number to spawn</param>
   /// <param name="fishSpeed">The swim speed</param>
   private void SpawnFish(int count, float fishSpeed)
   {
       for (int i = 0; i < count; i++)
       {
           // Spawn and place the fish
           GameObject fishObject = Instantiate<GameObject>(fishPrefab.gameObject);
           fishObject.transform.position = ChooseRandomPosition(transform.position, 100f, 260f, 2f, 13f) + Vector3.up * .5f;
           fishObject.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);

           // Set the fish's parent to this area's transform
           fishObject.transform.SetParent(transform);

           // Keep track of the fish
           fishList.Add(fishObject);

           // Set the fish speed
           fishObject.GetComponent<Fish>().fishSpeed = fishSpeed;
       }
   }

   /// Called when the game starts
   private void Start()
   {
       penguinAcademy = FindObjectOfType<PenguinAcademy>();
       ResetArea();
   }

   /// Called every frame
   private void Update()
   {
       // Update the cumulative reward text
       cumulativeRewardText.text = penguinAgent.GetCumulativeReward().ToString("0.00");
   }
}
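
As a concrete example of ChooseRandomPosition: the call in PlaceBaby, ChooseRandomPosition(transform.position, -45f, 45f, 4f, 9f), picks a point in a 90-degree wedge in front of the area center, somewhere between 4 and 9 units away, so the baby always spawns in roughly the same region while the exact spot varies.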

・PenguinAgent.cs

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents;

public class PenguinAgent : Agent
{
   [Tooltip("How fast the agent moves forward")]
   public float moveSpeed = 5f;

   [Tooltip("How fast the agent turns")]
   public float turnSpeed = 180f;

   [Tooltip("Prefab of the heart that appears when the baby is fed")]
   public GameObject heartPrefab;

   [Tooltip("Prefab of the regurgitated fish that appears when the baby is fed")]
   public GameObject regurgitatedFishPrefab;

   private PenguinArea penguinArea;
   private PenguinAcademy penguinAcademy;
   new private Rigidbody rigidbody;
   private GameObject baby;

   private bool isFull; // If true, penguin has a full stomach
   private float feedRadius = 0f;

   /// Initial setup, called when the agent is enabled
   public override void InitializeAgent()
   {
       base.InitializeAgent();
       penguinArea = GetComponentInParent<PenguinArea>();
       penguinAcademy = FindObjectOfType<PenguinAcademy>();
       baby = penguinArea.penguinBaby;
       rigidbody = GetComponent<Rigidbody>();
   }

   /// Perform actions based on a vector of numbers
   /// <param name="vectorAction">The list of actions to take</param>
   public override void AgentAction(float[] vectorAction)
   {
       // Convert the first action to forward movement
       float forwardAmount = vectorAction[0];

       // Convert the second action to turning left or right
       float turnAmount = 0f;
       if (vectorAction[1] == 1f)
       {
           turnAmount = -1f;
       }
       else if (vectorAction[1] == 2f)
       {
           turnAmount = 1f;
       }

       // Apply movement
       rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);
       transform.Rotate(transform.up * turnAmount * turnSpeed * Time.fixedDeltaTime);

       // Apply a tiny negative reward every step to encourage action
       AddReward(-1f / agentParameters.maxStep);
   }

   /// Read inputs from the keyboard and convert them to a list of actions.
   /// This is called only when the player wants to control the agent and has set
   /// Behavior Type to "Heuristic Only" in the Behavior Parameters inspector.
   /// <returns>A vectorAction array of floats that will be passed into <see cref="AgentAction(float[])"/></returns>
   public override float[] Heuristic()
   {
       float forwardAction = 0f;
       float turnAction = 0f;
       if (Input.GetKey(KeyCode.W))
       {
           // move forward
           forwardAction = 1f;
       }
       if (Input.GetKey(KeyCode.A))
       {
           // turn left
           turnAction = 1f;
       }
       else if (Input.GetKey(KeyCode.D))
       {
           // turn right
           turnAction = 2f;
       }

       // Put the actions into an array and return
       return new float[] { forwardAction, turnAction };
   }

   /// Reset the agent and area
   public override void AgentReset()
   {
       isFull = false;
       penguinArea.ResetArea();
       feedRadius = penguinAcademy.FeedRadius;
   }

   /// Collect all non-Raycast observations
   public override void CollectObservations()
   {
       // Whether the penguin has eaten a fish (1 float = 1 value)
       AddVectorObs(isFull);

       // Distance to the baby (1 float = 1 value)
       AddVectorObs(Vector3.Distance(baby.transform.position, transform.position));

       // Direction to baby (1 Vector3 = 3 values)
       AddVectorObs((baby.transform.position - transform.position).normalized);

       // Direction penguin is facing (1 Vector3 = 3 values)
       AddVectorObs(transform.forward);

       // 1 + 1 + 3 + 3 = 8 total values
   }

   private void FixedUpdate()
   {
       // Test if the agent is close enough to feed the baby
       if (Vector3.Distance(transform.position, baby.transform.position) < feedRadius)
       {
           // Close enough, try to feed the baby
           RegurgitateFish();
       }
   }

   /// When the agent collides with something, take action
   /// <param name="collision">The collision info</param>
   private void OnCollisionEnter(Collision collision)
   {
       if (collision.transform.CompareTag("fish"))
       {
           // Try to eat the fish
           EatFish(collision.gameObject);
       }
       else if (collision.transform.CompareTag("baby"))
       {
           // Try to feed the baby
           RegurgitateFish();
       }
   }

   /// Check if agent is full, if not, eat the fish and get a reward
   /// <param name="fishObject">The fish to eat</param>
   private void EatFish(GameObject fishObject)
   {
       if (isFull) return; // Can't eat another fish while full
       isFull = true;

       penguinArea.RemoveSpecificFish(fishObject);

       AddReward(1f);
   }

   /// Check if agent is full, if yes, feed the baby
   private void RegurgitateFish()
   {
       if (!isFull) return; // Nothing to regurgitate
       isFull = false;

       // Spawn regurgitated fish
       GameObject regurgitatedFish = Instantiate<GameObject>(regurgitatedFishPrefab);
       regurgitatedFish.transform.parent = transform.parent;
       regurgitatedFish.transform.position = baby.transform.position;
       Destroy(regurgitatedFish, 4f);

       // Spawn heart
       GameObject heart = Instantiate<GameObject>(heartPrefab);
       heart.transform.parent = transform.parent;
       heart.transform.position = baby.transform.position + Vector3.up;
       Destroy(heart, 4f);

       AddReward(1f);

       if (penguinArea.FishRemaining <= 0)
       {
           Done();
       }
   }
}
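
Two details worth noting here. The action interpretation assumes a discrete action space with two branches (presumably of sizes 2 and 3 in the Behavior Parameters inspector: stop/forward and straight/left/right). Also, the per-step penalty of -1f / agentParameters.maxStep sums to exactly -1 over a full episode, so an agent that never accomplishes anything trends toward a cumulative reward of about -1, while each successful eat-and-feed cycle is worth +2.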

・Fish.cs

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class Fish : MonoBehaviour
{
   [Tooltip("The swim speed")]
   public float fishSpeed;

   private float randomizedSpeed = 0f;
   private float nextActionTime = -1f;
   private Vector3 targetPosition;

   /// Called every timestep
   private void FixedUpdate()
   {
       if (fishSpeed > 0f)
       {
           Swim();
       }
   }

   /// Swim between random positions
   private void Swim()
   {
       // If it's time for the next action, pick a new speed and destination
       // Else, swim toward the destination
       if (Time.fixedTime >= nextActionTime)
       {
           // Randomize the speed
           randomizedSpeed = fishSpeed * UnityEngine.Random.Range(.5f, 1.5f);

           // Pick a random target
           targetPosition = PenguinArea.ChooseRandomPosition(transform.parent.position, 100f, 260f, 2f, 13f);

           // Rotate toward the target
           transform.rotation = Quaternion.LookRotation(targetPosition - transform.position, Vector3.up);

           // Calculate the time to get there
           float timeToGetThere = Vector3.Distance(transform.position, targetPosition) / randomizedSpeed;
           nextActionTime = Time.fixedTime + timeToGetThere;
       }
       else
       {
           // Make sure that the fish does not swim past the target
           Vector3 moveVector = randomizedSpeed * transform.forward * Time.fixedDeltaTime;
           if (moveVector.magnitude <= Vector3.Distance(transform.position, targetPosition))
           {
               transform.position += moveVector;
           }
           else
           {
               transform.position = targetPosition;
               nextActionTime = Time.fixedTime;
           }
       }
   }
}

3. Additions to ML-Agents 0.13.1

・trainer_config.yaml

default:
   trainer: ppo
   batch_size: 1024
   beta: 5.0e-3
   buffer_size: 10240
   epsilon: 0.2
   hidden_units: 128
   lambd: 0.95
   learning_rate: 3.0e-4
   learning_rate_schedule: linear
   max_steps: 5.0e4
   memory_size: 256
   normalize: false
   num_epoch: 3
   num_layers: 2
   time_horizon: 64
   sequence_length: 64
   summary_freq: 1000
   use_recurrent: false
   vis_encode_type: simple
   reward_signals:
       extrinsic:
           strength: 1.0
           gamma: 0.99

FoodCollector:
   normalize: false
   beta: 5.0e-3
   batch_size: 1024
   buffer_size: 10240
   max_steps: 1.0e5

Bouncer:
   normalize: true
   max_steps: 1.0e6
   num_layers: 2
   hidden_units: 64

PushBlock:
   max_steps: 5.0e4
   batch_size: 128
   buffer_size: 2048
   beta: 1.0e-2
   hidden_units: 256
   summary_freq: 2000
   time_horizon: 64
   num_layers: 2

SmallWallJump:
   max_steps: 1.0e6
   batch_size: 128
   buffer_size: 2048
   beta: 5.0e-3
   hidden_units: 256
   summary_freq: 2000
   time_horizon: 128
   num_layers: 2
   normalize: false

BigWallJump:
   max_steps: 1.0e6
   batch_size: 128
   buffer_size: 2048
   beta: 5.0e-3
   hidden_units: 256
   summary_freq: 2000
   time_horizon: 128
   num_layers: 2
   normalize: false

Striker:
   max_steps: 5.0e5
   learning_rate: 1e-3
   batch_size: 128
   num_epoch: 3
   buffer_size: 2000
   beta: 1.0e-2
   hidden_units: 256
   summary_freq: 2000
   time_horizon: 128
   num_layers: 2
   normalize: false

Goalie:
   max_steps: 5.0e5
   learning_rate: 1e-3
   batch_size: 320
   num_epoch: 3
   buffer_size: 2000
   beta: 1.0e-2
   hidden_units: 256
   summary_freq: 2000
   time_horizon: 128
   num_layers: 2
   normalize: false

Pyramids:
   summary_freq: 2000
   time_horizon: 128
   batch_size: 128
   buffer_size: 2048
   hidden_units: 512
   num_layers: 2
   beta: 1.0e-2
   max_steps: 5.0e5
   num_epoch: 3
   reward_signals:
       extrinsic:
           strength: 1.0
           gamma: 0.99
       curiosity:
           strength: 0.02
           gamma: 0.99
           encoding_size: 256

VisualPyramids:
   time_horizon: 128
   batch_size: 64
   buffer_size: 2024
   hidden_units: 256
   num_layers: 1
   beta: 1.0e-2
   max_steps: 5.0e5
   num_epoch: 3
   reward_signals:
       extrinsic:
           strength: 1.0
           gamma: 0.99
       curiosity:
           strength: 0.01
           gamma: 0.99
           encoding_size: 256

3DBall:
   normalize: true
   batch_size: 64
   buffer_size: 12000
   summary_freq: 1000
   time_horizon: 1000
   lambd: 0.99
   beta: 0.001

3DBallHard:
   normalize: true
   batch_size: 1200
   buffer_size: 12000
   summary_freq: 1000
   time_horizon: 1000
   max_steps: 5.0e5
   beta: 0.001
   reward_signals:
       extrinsic:
           strength: 1.0
           gamma: 0.995

Tennis:
   normalize: true
   max_steps: 2e5

CrawlerStatic:
   normalize: true
   num_epoch: 3
   time_horizon: 1000
   batch_size: 2024
   buffer_size: 20240
   max_steps: 1e6
   summary_freq: 3000
   num_layers: 3
   hidden_units: 512
   reward_signals:
       extrinsic:
           strength: 1.0
           gamma: 0.995

CrawlerDynamic:
   normalize: true
   num_epoch: 3
   time_horizon: 1000
   batch_size: 2024
   buffer_size: 20240
   max_steps: 1e6
   summary_freq: 3000
   num_layers: 3
   hidden_units: 512
   reward_signals:
       extrinsic:
           strength: 1.0
           gamma: 0.995

Walker:
   normalize: true
   num_epoch: 3
   time_horizon: 1000
   batch_size: 2048
   buffer_size: 20480
   max_steps: 2e6
   summary_freq: 3000
   num_layers: 3
   hidden_units: 512
   reward_signals:
       extrinsic:
           strength: 1.0
           gamma: 0.995

Reacher:
   normalize: true
   num_epoch: 3
   time_horizon: 1000
   batch_size: 2024
   buffer_size: 20240
   max_steps: 1e6
   summary_freq: 3000
   reward_signals:
       extrinsic:
           strength: 1.0
           gamma: 0.995

Hallway:
   use_recurrent: true
   sequence_length: 64
   num_layers: 2
   hidden_units: 128
   memory_size: 256
   beta: 1.0e-2
   num_epoch: 3
   buffer_size: 1024
   batch_size: 128
   max_steps: 5.0e5
   summary_freq: 1000
   time_horizon: 64

VisualHallway:
   use_recurrent: true
   sequence_length: 64
   num_layers: 1
   hidden_units: 128
   memory_size: 256
   beta: 1.0e-2
   num_epoch: 3
   buffer_size: 1024
   batch_size: 64
   max_steps: 5.0e5
   summary_freq: 1000
   time_horizon: 64

VisualPushBlock:
   use_recurrent: true
   sequence_length: 32
   num_layers: 1
   hidden_units: 128
   memory_size: 256
   beta: 1.0e-2
   num_epoch: 3
   buffer_size: 1024
   batch_size: 64
   max_steps: 5.0e5
   summary_freq: 1000
   time_horizon: 64

GridWorld:
   batch_size: 32
   normalize: false
   num_layers: 1
   hidden_units: 256
   beta: 5.0e-3
   buffer_size: 256
   max_steps: 50000
   summary_freq: 2000
   time_horizon: 5
   reward_signals:
       extrinsic:
           strength: 1.0
           gamma: 0.9

Basic:
   batch_size: 32
   normalize: false
   num_layers: 1
   hidden_units: 20
   beta: 5.0e-3
   buffer_size: 256
   max_steps: 5.0e5
   summary_freq: 2000
   time_horizon: 3
   reward_signals:
       extrinsic:
           strength: 1.0
           gamma: 0.9

PenguinLearning:
   summary_freq: 5000
   time_horizon: 128
   batch_size: 128
   buffer_size: 2048
   hidden_units: 256
   beta: 1.0e-2
   max_steps: 1.0e6
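
Any hyperparameter not listed under PenguinLearning is inherited from the default block at the top of the file; for example, the effective trainer is ppo, the learning_rate is 3.0e-4, and num_epoch is 3.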

・PenguinLearning.json (inside config → curricula → penguin)

{
   "measure": "reward",
   "thresholds": [ -0.1, 0.7, 1.7, 1.7, 1.7, 2.7, 2.7 ],
   "min_lesson_length": 80,
   "signal_smoothing": true,
   "parameters": {
       "fish_speed": [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5 ],
       "feed_radius": [ 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.5, 0.2 ]
   }
}
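
The parameters arrays define eight lessons, and the seven thresholds are the mean-reward values at which training advances from one lesson to the next; for example, fish_speed stays at 0.0 until the agent reliably earns a mean reward of around 2.7. Training with this curriculum can then be started with something like the following; a minimal sketch, assuming the ML-Agents 0.13 command-line interface (the run-id is just an example):

mlagents-learn config/trainer_config.yaml --curriculum=config/curricula/penguin --run-id=penguin_01 --train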


4. Trained Model

The model from my own training run is attached here.


5. Closing

If something isn't working out, I hope the downloads on this page help you get things moving again.
