見出し画像

強化学習に必要な「Qテーブル」と「離散値で表す関数」をつくるには?

こんにちは!

ぷもんです。


前回、離散値ってなんや?というnoteで
強化学習で必要な離散値とは何か?なぜ必要なのか?
について書きました。

今回は具体的に離散値に変換していきます。


今回やるのはこちらのコードです。

q_table = np.random.uniform(low=-1, high=1, size=(4 ** 4, env.action_space.n))

def bins(clip_min, clip_max, num):
   return np.linspace(clip_min, clip_max, num + 1)[1:-1]

def digitize_state(observation):
   cart_pos, cart_v, pole_angle, pole_v = observation
   digitized = [np.digitize(cart_pos, bins=bins(-2.4, 2.4, 4)),
                np.digitize(cart_v, bins=bins(-3.0, 3.0, 4)),
                np.digitize(pole_angle, bins=bins(-0.5, 0.5, 4)),
                np.digitize(pole_v, bins=bins(-2.0, 2.0, 4))]
   return sum([x * (4 ** i) for i, x in enumerate(digitized)])


やっていることは
・Qテーブルという表のようなものを作る
・フィードバックで得られる値を4の4乗の256の離散値で表す関数をつくる
の2つです。

ではやって行きます!!


・Qテーブルってなんや?

強化学習の手法のうち、Q学習というものをやっています。

Q学習ではある時間tのある状態sである行動aを取った時どうなるかを
関数で表したQ関数というものを作ります。
Q関数を表で表したものがQテーブルです。

q_table = np.random.uniform(low=-1, high=1, size=(4 ** 4, env.action_space.n))

ではQテーブルを作っています。


Qテーブルという表みたいなものを作っているのはわかったけど
右辺の意味が気になります...。

・np.random.uniform(low, high, size)

np.random.uniform(low, high, size)は
low以上high未満の一様乱数をsize個つくるという意味です。

一様乱数はランダムな数という意味です。


今回の場合は

q_table = np.random.uniform(low=-1, high=1, size=(4 ** 4, env.action_space.n))

−1以上1未満の一様乱数を
(4 ** 4, env.action_space.n)個つくるという意味になります。


....。

sizeの内容が(4 ** 4, env.action_space.n)で意味不明です。


4 ** 4は4の4乗を表しており
前回、離散値ってなんや?で説明したように

・カート位置
・カート速度
・棒の角度
・棒の角速度
の4つの値を4つの領域に分けたことを表す4の4乗です。


env.action_space.nはこのゲームで、有効なactionを表していて
今回やっているCartPoleでは右に移動するか左に移動するかの2択なので
2になります。


ここまでわかると
縦が4の4乗の256、横が2の表が作られているのが
イメージできるのではないでしょうか?


やっと1行目を理解できました...。


ここまでで
・Qテーブルという表のようなものを作る
が終わりました。

ここからは
・フィードバックで得られる値を4の4乗の256の離散値で表す関数をつくる
に入ります。


まずは

def bins(clip_min, clip_max, num):
   return np.linspace(clip_min, clip_max, num + 1)[1:-1]

です。


・defで関数を宣言!

Pythonではdefを使って
returnの後に書いた値を戻す
関数を宣言することができます。

今回の場合は
np.linspace(clip_min, clip_max, num + 1)[1:-1]という動きをする
bins(clip_min, clip_max, num)という関数を作っているのがわかります。


・np.linspace(始点, 終点, 何分割)[スライス]

np.linspaceを使うと等差数列がつくれます。

1、2、3、4、5、6、7、8、9、10は
公差が1の等差数列です。

np.linspaceに続く値では始点、終点、何分割するかを示すことができて
[]の値ではどれくらいの値を切り取るかを指定できます。

今回の場合

np.linspace(clip_min, clip_max, num + 1)[1:-1]

clip_minを始点、clip_maxを終点とするnum + 1分割した等差数列のうち
1〜−1を切り取るという意味になります。

clipは報酬の値を示すっぽいのですが
よくわからないので進めていくうちに理解できるのを待ちます。


つまり、ここでは

def bins(clip_min, clip_max, num):
   return np.linspace(clip_min, clip_max, num + 1)[1:-1]

clip_minを始点、clip_maxを終点とするnum + 1分割した等差数列のうち
1〜−1を切り取りとった値を戻す
bins(clip_min, clip_max, num)という関数を定義しました。


ラストがこちらです!

def digitize_state(observation):
   cart_pos, cart_v, pole_angle, pole_v = observation
   digitized = [np.digitize(cart_pos, bins=bins(-2.4, 2.4, 4)),
                np.digitize(cart_v, bins=bins(-3.0, 3.0, 4)),
                np.digitize(pole_angle, bins=bins(-0.5, 0.5, 4)),
                np.digitize(pole_v, bins=bins(-2.0, 2.0, 4))]
   return sum([x * (4 ** i) for i, x in enumerate(digitized)])


まずはdigitizedというリストをつくっています。

・observationをバラバラに

cart_pos, cart_v, pole_angle, pole_v = observation

observationは観測したデータを表していて
それをcart_pos, cart_v, pole_angle, pole_v
つまりカート位置、カート速度、棒の角度、棒の角速度に分けています。


・digitize関数って何や?

digitize関数ではある値がどの範囲に入るか?を求めることができます。

np.digitize(値, 範囲)を意味していて

np.digitize(cart_pos, bins=bins(-2.4, 2.4, 4))

の場合cart_pos(=カート位置)が
-2.4,から2.4,を4つに分けた範囲のどの範囲に入るかを示します。
(先ほど作ったbins(clip_min, clip_max, num)が使われていますね!)

ちなみに4つの範囲のどれに入るのかは
0〜3の値で返してくれます。


他の3つも同じように働くとすると

digitized = [np.digitize(cart_pos, bins=bins(-2.4, 2.4, 4)),
             np.digitize(cart_v, bins=bins(-3.0, 3.0, 4)),
             np.digitize(pole_angle, bins=bins(-0.5, 0.5, 4)),
             np.digitize(pole_v, bins=bins(-2.0, 2.0, 4))]

は[0,1,1,3]のような0〜3が入るリストの形になることがわかります。


続いてreturnの中を見てみます。

・enumerate()

enumerate関数というものでforと一緒に使うことで
何番目に要素が入っているかと一緒に表示することができます。

list = ['1', '2', '3']
for i, name x enumerate(list)

この場合iが何番目かxがリストの中身を示すので
0,1
1,2
2,3
のように0番目が1というように示されます。

今回の場合はlistが

digitized = [np.digitize(cart_pos, bins=bins(-2.4, 2.4, 4)),
                np.digitize(cart_v, bins=bins(-3.0, 3.0, 4)),
                np.digitize(pole_angle, bins=bins(-0.5, 0.5, 4)),
                np.digitize(pole_v, bins=bins(-2.0, 2.0, 4))]

の部分になります。


これを踏まえて

return sum([x * (4 ** i) for i, x in enumerate(digitized)])

を見ていきます。


・sum([式 変数)])

sum関数は合計する関数なのですが
今回のように好きと変数が入っている場合は変わります。

今回の場合は
式=x * (4 ** i)
変数=for i, x in enumerate(digitized)
のようになっていて

digitizedのリストからi番目のxという要素かという変数を
x * (4 ** i)の式に代入したものの合計を返します。

この値は0〜255になって
4の4乗の256の離散値のどれかになっていることがわかります。


ここまでで
・フィードバックで得られる値を4の4乗の256の離散値で表す関数をつくる
ことができ

・Qテーブルという表のようなものを作る
・フィードバックで得られる値を4の4乗の256の離散値で表す関数をつくる
の2つができました!


めっちゃむずかった...。

でも初めは絶対無理だと思っていた
複雑なコードも基本のコードの組み合わせで
一つずつ読んでいけば理解できることがわかりました。

めっちゃ根気が必要ですが...。


次はこの関数を使って
強化学習を入れるはずです!!


参考にしたサイトはこちらです。


最後まで読んでいただきありがとうございました。

ぷもんでした!

noteを日々投稿してます! もしいいなと思ってもらえたら サポートしてもらえるとありがたいです。 VRやパソコンの設備投資に使わせていただきます。 ご意見、質問等ありましたらコメントください。 #ぷもん でつぶやいてもらえると励みになります。 一緒に頑張りましょう!