前回はCartPoleのモデルの作成を行いました. ここではml-agentを使って実際にモデルを学習してみましょう.
前回:https://jp.magicode.io/hrintaro121/articles/4ae629d944b14547af4a51ed3738a3e2
ML-Agentsのインストールは以下を参考にインストールしてください. https://github.com/Unity-Technologies/ml-agents
モデルを学習するにあたって必要な手順は以下の通りです.
まず,下のようにProjectフォルダ内で右クリックし,create => C# scriptを選んでスクリプトを作成しましょう.
スクリプトの名前は適当で大丈夫です.ここではCartPoleAgentとしました.
作成されたスクリプトを開くとVScodeなどのエディタが立ち上がると思います. エディタ内では以下のようにコードを記述してください.
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using System.Collections;
using System.Collections.Generic;
public class CartPoleAgent : Agent
{
public GameObject pole;
Rigidbody poleRB;
Rigidbody cartRB;
EnvironmentParameters m_ResetParams;
public override void Initialize()
{
poleRB = pole.GetComponent<Rigidbody>();
cartRB = gameObject.GetComponent<Rigidbody>();
m_ResetParams = Academy.Instance.EnvironmentParameters;
SetResetParameters();
}
public override void CollectObservations(VectorSensor sensor)
{
sensor.AddObservation(gameObject.transform.localPosition.z);
sensor.AddObservation(cartRB.velocity.z);
sensor.AddObservation(pole.transform.localRotation.eulerAngles.x);
sensor.AddObservation(poleRB.angularVelocity.x);
}
public override void OnActionReceived(ActionBuffers actionBuffers)
{
Vector3 controlSignal = Vector3.zero;
controlSignal.x = actionBuffers.ContinuousActions[0];
controlSignal.z = actionBuffers.ContinuousActions[1];
var actionZ = 200f * Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f);
cartRB.AddForce(new Vector3(0.0f, 0.0f, actionZ), ForceMode.Force);
float cart_z = this.gameObject.transform.localPosition.z;
float angle_x = pole.transform.localRotation.eulerAngles.x;
if(180f < angle_x && angle_x < 360f)
{
angle_x = angle_x - 360f;
}
if((-180f < angle_x && angle_x < -45f) || (45f < angle_x && angle_x < 180f))
{
SetReward(-1.0f);
EndEpisode();
}
else{
SetReward(0.1f);
}
if(cart_z < -10f || 10f < cart_z)
{
SetReward(-1.0f);
EndEpisode();
}
}
public override void OnActionReceived(float[] verctorAction)
{
var actionZ = 200f * Mathf.Clamp(verctorAction[0], -1f, 1f);
cartRB.AddForce(new Vector3(0.0f, 0.0f, actionZ), ForceMode.Force);
float cart_z = this.gameObject.transform.localPosition.z;
float angle_x = pole.transform.localRotation.eulerAngles.x;
if(180f < angle_x && angle_x < 360f)
{
angle_x = angle_x - 360f;
}
if((-180f < angle_x && angle_x < -45f) || (45f < angle_x && angle_x < 180f))
{
SetReward(-1.0f);
EndEpisode();
}
else{
SetReward(0.1f);
}
if(cart_z < -10f || 10f < cart_z)
{
SetReward(-1.0f);
EndEpisode();
}
}
public override void OnEpisodeBegin()
{
gameObject.transform.localPosition = new Vector3(0f, 0f, 0f);
pole.transform.localPosition = new Vector3(0f, 2.5f, 0f);
pole.transform.localRotation = Quaternion.Euler(0f, 0f, 0f);
poleRB.angularVelocity = new Vector3(0f, 0f, 0f);
poleRB.velocity = new Vector3(0f, 0f, 0f);
poleRB.angularVelocity = new Vector3(Random.Range(-0.1f, 0.1f), 0f, 0f);
SetResetParameters();
}
public void SetPole()
{
poleRB.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
pole.transform.localScale = new Vector3(0.4f, 2f, 0.4f);
}
public void SetResetParameters()
{
SetPole();
}
}
簡単に上記のコードの説明をします.
学習のための初期設定を行います. 具体的にはスクリプト内の変数と実際のオブジェクトを紐づけることなどを行っています.
エージェントが学習に必要な情報をここで取得します. 今回使う情報は以下の通りです.
エージェントの行動・エピソードの終了判定・また報酬をエージェントに与えることなどをここで行います.
引数には,強化学習で学習したモデルが出力した行動をactionBuffers
として受け取ります.
学習のために,エピソードが始まるときの設定を行います. 具体的には,エージェントの位置をリセット,棒をランダムに傾ける,といったことを行います.
次に作成したスクリプトをエージェントに追加します.またinspectorのPoleという位置にhierarchyウィンドウからPoleオブジェクトをドラッグしてください.
またAdd Component から、Behavior Parameters と Decision Requester を追加しましょう.
各パラメータは以下のように設定してください.
学習に必要な準備は終了です.次は実際に学習してみましょう!
まず,configの設定が必要になります.
configの位置はダウンロードしたリポジトリのml-agents/config/ppo/
になります.
ML-Agentsのダウンロードは以下から行ってください.
https://github.com/Unity-Technologies/ml-agents
configの内容は以下のようにして作成してください.またファイルの名前は「2D_cartpole.yaml」としました.
2行目のCartPoleという名前はbehavior parametersのbehavior Nameと同じ名前にしましょう.
behaviors:
CartPole:
trainer_type: ppo
hyperparameters:
batch_size: 32
buffer_size: 12000
learning_rate: 0.0003
beta: 0.001
epsilon: 0.2
lambd: 0.99
num_epoch: 3
learning_rate_schedule: linear
network_settings:
normalize: true
hidden_units: 128
num_layers: 2
vis_encode_type: simple
reward_signals:
extrinsic:
gamma: 0.99
strength: 1.0
keep_checkpoints: 5
max_steps: 200000
time_horizon: 1000
summary_freq: 12000
あとはターミナルからコマンドを入力して学習を行います.
インストールしたリポジトリのml-agentsディレクトリから以下を実行しましょう.
mlagents-learn ../config/ppo/2D_cartpole.yaml --run-id=2Dcartpole --train
無事に実行できると以下のような画面が表示されます.
そして,出力されているようにUnity側で実行ボタン(矢印のマーク)を押すと実際に学習が始まります.
自分の環境では学習は20万ステップほどで完了し,棒を倒さずカートが動く様子が確認できました!