ロゴ ロゴ

【強化学習】コピペでスーパーマリオをクリアしよう!【PyTorch】

強化学習って聞いたことあるけど…

皆さんは強化学習というものを知っていますか?聞いたことはある、Youtubeで物理エンジンを使った学習をしている人を見たことがあるという人はこのページを見ている人の中には多いかもしれません。しかし、いざ自分が作ってみようとなった時にどうやって作ったらいいかわからない人も多いと思います(自分もそうだった)。

何がわかんない?

実際に作ってみようにも正直説明が意味わからないなと思ったことはありませんか?自分はその典型例で画像の判別やオセロのAI、はたまた研究室ではセマンティック セグメンテーションを利用した研究をしていたりとDeep Learningに触れていてもなお意味が理解できていません。。。。

というのも、報酬がより多くなるような動きを行うように学習というのがどのように行われているのかに関するわかりやすい説明がない(今まで見かけたことが無い)からです。一般的には学習は入力と出力が近づくように学習するという仕組みである以上、そのペア、またはそれに準ずる何かを生成できないといけません。教師なし学習の有名な例としてはGANがありますが、あれも生成器の出力に対しての評価を判別器による一定の出力(本物である確率)において達成しています。それに対し、報酬型の強化学習は報酬の上限がいくつかというのは設定する報酬やその時の状態に応じて変わってしまいます。要するに最大値がわからないことがあるのです。

でも動くものが見たい

ということで、いろいろとぐちぐち言ってきましたが、とりあえず動くものを見た方がモチベも上がるし面白いよなと思ったのでコピペで何とかできないかなと思ったのが今回の主題です。

さっそく今回のソースコードの元なのですが、PyTorchの公式チュートリアルにこんなものがありました。ページを見てもらえばわかりますが、ファミコンのマリオの学習ができるチュートリアルです。

さっそく動かしていく

とりあえず今回実際に自分が動かした環境です。これに当てはまらない人用もあるのでその場合は飛ばしてみてください。

  • CPU : Intel Core i7-10700K
  • GPU : NVIDIA RTX3060ti 8GB
  • RAM : 32GB
  • OS : Windows10 Pro
  • Python : 3.7.9

と雑に書いておきます。とりあえずPythonはAnacondaはダメです(最重要)。多分動くかもしれませんが自分がAnaconda嫌いなので動かなくても知りません。他にはNVIDIAのGPUがあるよという人でなければ次の説明のやり方はお勧めしません。AMDのRadeonではダメです(GPU推論はできません)。ちなみに、使うGPUもそれなりのものでないとちょっと厳しい気はします。2000番台なら70番以上、3000番台なら60番以上、RTX Aシリーズなら2000でも大丈夫なのかな?4000なら問題ないです。GPUメモリはある程度(8GB以上)は欲しいよといった感じです。

PyTorchのインストール

Pytorchのインストール、GPUで利用するまでの話に関しては様々なサイトに書いてあるので特にCUDAの周りは書くのが大変なので省略させてもらいます。ここで挫折した方はさらに下に書いた方法を使ってください。できればPython純正の仮想環境を利用しましょう。

その他ライブラリのインストール

他に利用するものは以下のものになります。(順番にコマンドを実行してね)

pip install gym
pip install nes_py
pip install gym_super_mario_bros
pip install matplotlib

これらをインストール出来たら準備完了です。

ソースコードをコピー

次に実行するソースコードについてですが、こちらをコピーして実行するファイルにペーストしてください。しっかりと上記の環境ができていればGPU(またはCPU)で実行されるはずです。

実行したらマリオがランダムに動いて見えたら成功です。だんだん学習していくと思います。

報酬等を変更したりしたい場合は以下に解説を載せているので気になる方は見てみてください。

パソコンがよわよわ、AMDのGPUを使ってる人向け

ノートパソコンの場合はGoogle Colaboratoryを利用するのが手っ取り早いです。とりあえずGoogle Drive等からColabのファイルを作成しましょう。開けたら上の方にある「ランタイム」>「ランタイムのタイプの変更」を選択して、ハードウェアアクセラレータをGPUに変更して保存を押します。

その後こちらのソースコードをコピーして実行します。初期設定のままだとエピソードが100の倍数の時にvideoのフォルダに動画が保存されます。

ソースコードの簡易解説

ここからはソースコードの部分的な解説を行っています。

報酬

まずは報酬の部分についてです。関係する部分はこちら。

class SkipFrame(gym.Wrapper):
    def __init__(self, env, skip):
        super().__init__(env)
        self._skip = skip

    def reset(self, **kwargs):
        self._cur_v = 0
        self._cur_x = 0
        self._max_x = 0
        self.reward = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        self.reward = 0
        done = False
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)

            if info['x_pos'] > self._cur_x:
                self.reward += 1
            else:
                self.reward -= 2

            if info['x_pos'] - self._cur_x >= self._cur_v:
                self.reward += 1
            else:
                self.reward -= 0

            if info['x_pos'] > self._max_x:
                self.reward += 1

            if info['flag_get']:
                self.reward += 10

            self._cur_v = info['x_pos'] - self._cur_x
            self._cur_x = info['x_pos']
            if self._max_x < info['x_pos']:
                self._max_x = info['x_pos']

            if done:
                self.reward -= 10
                break
        self.reward /= self._skip
        return obs, self.reward, done, info

step()関数が実行されると報酬が出力されます。なので、ここで報酬が決定されるのですが、現在の状況はinfoの中に格納されているので、そこに入っている情報から報酬を決めます。

infoの中には

  • coins : コインの枚数
  • flag_get : クリアか否かのBool
  • life : 残り残基
  • score : スコア
  • stage : ステージ番号
  • status : マリオの状態(small, tall, fireball)
  • time : 残り時間(ステージ1, 2は最大400[s], 3, 4は最大300[s], 特殊仕様で8-1は最大300[s])
  • world : ワールドナンバー
  • x_pos : マリオのx座標
  • y_pos : マリオのy座標

の情報が入っており、info['coins']のような形で取得します。

学習処理部分

次に学習処理部分についてです。

env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0")

まずは環境の読み込みです。SuperMarioBros-1-1-v0の部分に関してなのですが、後半の-1-1の部分が1-1のマップを読み込むことを表しています。そのあとの-v0の部分は読み込むROMの種類が変わります。

  • v0 : 通常ROM
  • v1 : ダウンサンプルROM
  • v2 : ピクセル(もっと荒い)
  • v3 : レクトアングル(四角)

加えて、SuperMarioBrosの部分をSuperMarioBros2に変えるとマリオブラザーズ2で学習ができます(v3, v4はなし)。同様にSuperMarioBros-1-1の部分をSuperMarioBrosRandomStagesに変更すればステージをランダムに実行してくれます。

env = JoypadSpace(env, [['NOOP'], ["right"], ["right", "A"],["right", "B"], ["right", "A", "B"],["left"], ["left", "A"],["left", "B"], ["left", "A", "B"],["right","left"], ["right","left", "A"],["right","left", "B"], ["right","left", "A", "B"],["down"]])

次に、操作部分に関してです。操作の種類を上記のようにリスト形式で記述します。種類は以下の通りです。

  • NOOP : 何も入力なし
  • right : 右入力
  • left : 左入力
  • up : 上入力
  • down : 下入力
  • A : Aボタン入力
  • B : Bボタン入力
env = SkipFrame(env, skip=1)

次に上記のskip=1の部分を変えると一回の推論で実行されるフレーム数が変更されます。数字を大きくすると学習が早くなる(?)のかもしれませんが、小回りが利かなくなります。

chkpt_file=None
mario = Mario(state_dim=(4, 84, 84), action_dim=env.action_space.n, save_dir=save_dir, chkpt_file=chkpt_file)

chkpt_file=Noneに関しては実行していると定期的に生成されるchkptファイルのパスを文字列で指定すると読み込んでくれます。学習が途中で終了した場合は最新のものを読み込むことで継続学習が可能になります。chkptファイルの保存タイミングはMarioクラスの__init__関数内にself.save_everyで指定されています(1回の保存までのステップ数)。

episodes = 100
for e in range(episodes):
    state = env.reset()
    while True:
        action = mario.act(state)
        next_state, reward, done, info = env.step(action)
        mario.cache(state, next_state, action, reward, done)
        q, loss = mario.learn()
        logger.log_step(reward, loss, q)
        state = next_state
        env.render()
        if done or info["flag_get"]:
            break

    logger.log_episode()
    if e % 1 == 0:
        logger.record(episode=e, epsilon=mario.exploration_rate, step=mario.curr_step)

episodes = 100は学習エピソード数の合計になります。この数だと学習回数が少なすぎてchkptファイルは保存されないはずです。

env.render()で実行されている状況を出力(新規ウィンドウ)しているのですが、Colabでは利用できないので代わりに次のものを利用しています。

video_every = 100
env = gym.wrappers.Monitor(env, "./video", video_callable=lambda episode_id: (episode_id%video_every)==0, force=True)

これをepisodes = 100よりも前の行に挟んでおくことで一定のエピソードごとに動画がmp4形式で保存されます。video_every = 100となっているので、100エピソードごとに動画が保存され、"./video"フォルダに保存されます(フォルダがない場合はフォルダも作成されます)。

まとめ

ということで、今回はとりあえず強化学習をコピペで動かすことを目的にした人に向けたものを書いてみたのですがどうだったでしょうか?基本的なPythonの利用方法を知らないとこの説明でも意味が分からないかもしれませんが、少しでもこれを見た人がわかってくれたらうれしいです。自分もこれを動かしながら他のものでもできるように勉強していこうかなと思いました。

最後に学習してゴールまで行った動画を載せておきます。

以下はコピペ用のソースコードになります。

ソースコード

左上の方にカーソルを合わせるとコピーと出ると思うのでクリックするとコピーできます。

import torch
from torch import nn
from torchvision import transforms as T
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os, copy
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
import numpy as np
import time, datetime
import matplotlib.pyplot as plt

class SkipFrame(gym.Wrapper):
    def __init__(self, env, skip):
        super().__init__(env)
        self._skip = skip

    def reset(self, **kwargs):
        self._cur_v = 0
        self._cur_x = 0
        self._max_x = 0
        self.reward = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        self.reward = 0
        done = False
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)

            if info['x_pos'] > self._cur_x:
                self.reward += 1
            else:
                self.reward -= 2

            if info['x_pos'] - self._cur_x >= self._cur_v:
                self.reward += 1
            else:
                self.reward -= 0

            if info['x_pos'] > self._max_x:
                self.reward += 1

            if info['flag_get']:
                self.reward += 10

            self._cur_v = info['x_pos'] - self._cur_x
            self._cur_x = info['x_pos']
            if self._max_x < info['x_pos']:
                self._max_x = info['x_pos']

            if done:
                self.reward -= 10
                break
        self.reward /= self._skip
        return obs, self.reward, done, info

class GrayScaleObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape[:2]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def permute_orientation(self, observation):
        observation = np.transpose(observation, (2, 0, 1))
        observation = torch.tensor(observation.copy(), dtype=torch.float)
        return observation

    def observation(self, observation):
        observation = self.permute_orientation(observation)
        transform = T.Grayscale()
        observation = transform(observation)
        return observation

class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env, shape):
        super().__init__(env)
        if isinstance(shape, int):
            self.shape = (shape, shape)
        else:
            self.shape = tuple(shape)

        obs_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        transforms = T.Compose(
            [T.Resize(self.shape), T.Normalize(0, 255)]
        )
        observation = transforms(observation).squeeze(0)
        return observation

class Mario:
    def __init__(self, state_dim, action_dim, save_dir, chkpt_file=None):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.net = MarioNet(self.state_dim, self.action_dim).float()
        if chkpt_file!=None:
            checkpoint = torch.load(chkpt_file)
            self.net.load_state_dict(checkpoint['model'])
        if self.use_cuda:
            self.net = self.net.to(device="cuda")

        self.exploration_rate = 1
        if chkpt_file!=None:
            self.exploration_rate = checkpoint['exploration_rate']
        self.exploration_rate_decay = 0.99999975
        self.exploration_rate_min = 0.1
        self.curr_step = 0

        self.save_every = 5e5

        self.memory = deque(maxlen=3000)#GPUメモリ容量に余裕があれば大きくしてよい
        self.batch_size = 32

        self.gamma = 0.9

        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)
        self.loss_fn = torch.nn.SmoothL1Loss()

        self.burnin = 1e4  # min. experiences before training
        self.learn_every = 3  # no. of experiences between updates to Q_online
        self.sync_every = 1e4  # no. of experiences between Q_target & Q_online sync

    def act(self, state):
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)

        # EXPLOIT
        else:
            state = state.__array__()
            if self.use_cuda:
                state = torch.tensor(state).cuda()
            else:
                state = torch.tensor(state)
            state = state.unsqueeze(0)
            action_values = self.net(state, model="online")
            action_idx = torch.argmax(action_values, axis=1).item()

        # decrease exploration_rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        # increment step
        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        state = state.__array__()
        next_state = next_state.__array__()

        if self.use_cuda:
            state = torch.tensor(state).cuda()
            next_state = torch.tensor(next_state).cuda()
            action = torch.tensor([action]).cuda()
            reward = torch.tensor([reward]).cuda()
            done = torch.tensor([done]).cuda()
        else:
            state = torch.tensor(state)
            next_state = torch.tensor(next_state)
            action = torch.tensor([action])
            reward = torch.tensor([reward])
            done = torch.tensor([done])

        self.memory.append((state, next_state, action, reward, done,))

    def recall(self):
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    def learn(self):
        if self.curr_step % self.sync_every == 0:
            self.sync_Q_target()
        if self.curr_step % self.save_every == 0:
            self.save()
        if self.curr_step < self.burnin:
            return None, None
        if self.curr_step % self.learn_every != 0:
            return None, None
        state, next_state, action, reward, done = self.recall()
        td_est = self.td_estimate(state, action)
        td_tgt = self.td_target(reward, next_state, done)
        loss = self.update_Q_online(td_est, td_tgt)
        return (td_est.mean().item(), loss)

    def td_estimate(self, state, action):
        current_Q = self.net(state, model="online")[
            np.arange(0, self.batch_size), action
        ]
        return current_Q

    @torch.no_grad()
    def td_target(self, reward, next_state, done):
        next_state_Q = self.net(next_state, model="online")
        best_action = torch.argmax(next_state_Q, axis=1)
        next_Q = self.net(next_state, model="target")[
            np.arange(0, self.batch_size), best_action
        ]
        return (reward + (1 - done.float()) * self.gamma * next_Q).float()

    def update_Q_online(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        self.net.target.load_state_dict(self.net.online.state_dict())

    def save(self):
        save_path = (
            self.save_dir / f"mario_net_{int(self.curr_step // self.save_every)}.chkpt"
        )
        torch.save(
            dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),
            save_path,
        )
        print(f"MarioNet saved to {save_path} at step {self.curr_step}")

    def act_only_AI(self, state):
        state = state.__array__()
        if self.use_cuda:
            state = torch.tensor(state).cuda()
        else:
            state = torch.tensor(state)
        state = state.unsqueeze(0)
        action_values = self.net(state, model="online")
        action_idx = torch.argmax(action_values, axis=1).item()
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
        self.curr_step += 1
        return action_idx

class MarioNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        c, h, w = input_dim
        if h != 84:
            raise ValueError(f"Expecting input height: 84, got: {h}")
        if w != 84:
            raise ValueError(f"Expecting input width: 84, got: {w}")
        self.online = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        )
        self.target = copy.deepcopy(self.online)
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        if model == "online":
            return self.online(input)
        elif model == "target":
            return self.target(input)

class MetricLogger:
    def __init__(self, save_dir):
        self.save_log = save_dir / "log"
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = save_dir / "reward_plot.jpg"
        self.ep_lengths_plot = save_dir / "length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / "q_plot.jpg"

        # History metrics
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages, added for every call to record()
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Current episode metric
        self.init_episode()

        # Timing
        self.record_time = time.time()

    def log_step(self, reward, loss, q):
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        if loss:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self):
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        if self.curr_ep_loss_length == 0:
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)

        self.init_episode()

    def init_episode(self):
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)

        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )

        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
            plt.plot(getattr(self, f"moving_avg_{metric}"))
            plt.savefig(getattr(self, f"{metric}_plot"))
            plt.clf()

env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0")
env = JoypadSpace(env, [['NOOP'], ["right"], ["right", "A"],["right", "B"], ["right", "A", "B"],["left"], ["left", "A"],["left", "B"], ["left", "A", "B"],["right","left"], ["right","left", "A"],["right","left", "B"], ["right","left", "A", "B"],["down"]])
env = SkipFrame(env, skip=1)
env = GrayScaleObservation(env)
env = ResizeObservation(env, shape=84)
env = FrameStack(env, num_stack=4)
use_cuda = torch.cuda.is_available()
save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)
chkpt_file=None
mario = Mario(state_dim=(4, 84, 84), action_dim=env.action_space.n, save_dir=save_dir, chkpt_file=chkpt_file)
logger = MetricLogger(save_dir)

episodes = 100
for e in range(episodes):
    state = env.reset()
    while True:
        action = mario.act(state)
        next_state, reward, done, info = env.step(action)
        mario.cache(state, next_state, action, reward, done)
        q, loss = mario.learn()
        logger.log_step(reward, loss, q)
        state = next_state
        env.render()
        if done or info["flag_get"]:
            break

    logger.log_episode()
    if e % 1 == 0:
        logger.record(episode=e, epsilon=mario.exploration_rate, step=mario.curr_step)

ソースコード(Colab用)

左上の方にカーソルを合わせるとコピーと出ると思うのでクリックするとコピーできます。

!pip install nes_py
!pip install gym_super_mario_bros
import torch
from torch import nn
from torchvision import transforms as T
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os, copy
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
import numpy as np
import time, datetime
import matplotlib.pyplot as plt

class SkipFrame(gym.Wrapper):
    def __init__(self, env, skip):
        super().__init__(env)
        self._skip = skip

    def reset(self, **kwargs):
        self._cur_v = 0
        self._cur_x = 0
        self._max_x = 0
        self.reward = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        self.reward = 0
        done = False
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)

            if info['x_pos'] > self._cur_x:
                self.reward += 1
            else:
                self.reward -= 2

            if info['x_pos'] - self._cur_x >= self._cur_v:
                self.reward += 1
            else:
                self.reward -= 0

            if info['x_pos'] > self._max_x:
                self.reward += 1

            if info['flag_get']:
                self.reward += 10

            self._cur_v = info['x_pos'] - self._cur_x
            self._cur_x = info['x_pos']
            if self._max_x < info['x_pos']:
                self._max_x = info['x_pos']

            if done:
                self.reward -= 10
                break
        self.reward /= self._skip
        return obs, self.reward, done, info

class GrayScaleObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape[:2]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def permute_orientation(self, observation):
        observation = np.transpose(observation, (2, 0, 1))
        observation = torch.tensor(observation.copy(), dtype=torch.float)
        return observation

    def observation(self, observation):
        observation = self.permute_orientation(observation)
        transform = T.Grayscale()
        observation = transform(observation)
        return observation

class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env, shape):
        super().__init__(env)
        if isinstance(shape, int):
            self.shape = (shape, shape)
        else:
            self.shape = tuple(shape)

        obs_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        transforms = T.Compose(
            [T.Resize(self.shape), T.Normalize(0, 255)]
        )
        observation = transforms(observation).squeeze(0)
        return observation

class Mario:
    def __init__(self, state_dim, action_dim, save_dir, chkpt_file=None):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.net = MarioNet(self.state_dim, self.action_dim).float()
        if chkpt_file!=None:
            checkpoint = torch.load(chkpt_file)
            self.net.load_state_dict(checkpoint['model'])
        if self.use_cuda:
            self.net = self.net.to(device="cuda")

        self.exploration_rate = 1
        if chkpt_file!=None:
            self.exploration_rate = checkpoint['exploration_rate']
        self.exploration_rate_decay = 0.99999975
        self.exploration_rate_min = 0.1
        self.curr_step = 0

        self.save_every = 5e5

        self.memory = deque(maxlen=3000)#GPUメモリ容量に余裕があれば大きくしてよい
        self.batch_size = 32

        self.gamma = 0.9

        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)
        self.loss_fn = torch.nn.SmoothL1Loss()

        self.burnin = 1e4  # min. experiences before training
        self.learn_every = 3  # no. of experiences between updates to Q_online
        self.sync_every = 1e4  # no. of experiences between Q_target & Q_online sync

    def act(self, state):
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)

        # EXPLOIT
        else:
            state = state.__array__()
            if self.use_cuda:
                state = torch.tensor(state).cuda()
            else:
                state = torch.tensor(state)
            state = state.unsqueeze(0)
            action_values = self.net(state, model="online")
            action_idx = torch.argmax(action_values, axis=1).item()

        # decrease exploration_rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        # increment step
        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        state = state.__array__()
        next_state = next_state.__array__()

        if self.use_cuda:
            state = torch.tensor(state).cuda()
            next_state = torch.tensor(next_state).cuda()
            action = torch.tensor([action]).cuda()
            reward = torch.tensor([reward]).cuda()
            done = torch.tensor([done]).cuda()
        else:
            state = torch.tensor(state)
            next_state = torch.tensor(next_state)
            action = torch.tensor([action])
            reward = torch.tensor([reward])
            done = torch.tensor([done])

        self.memory.append((state, next_state, action, reward, done,))

    def recall(self):
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    def learn(self):
        if self.curr_step % self.sync_every == 0:
            self.sync_Q_target()
        if self.curr_step % self.save_every == 0:
            self.save()
        if self.curr_step < self.burnin:
            return None, None
        if self.curr_step % self.learn_every != 0:
            return None, None
        state, next_state, action, reward, done = self.recall()
        td_est = self.td_estimate(state, action)
        td_tgt = self.td_target(reward, next_state, done)
        loss = self.update_Q_online(td_est, td_tgt)
        return (td_est.mean().item(), loss)

    def td_estimate(self, state, action):
        current_Q = self.net(state, model="online")[
            np.arange(0, self.batch_size), action
        ]
        return current_Q

    @torch.no_grad()
    def td_target(self, reward, next_state, done):
        next_state_Q = self.net(next_state, model="online")
        best_action = torch.argmax(next_state_Q, axis=1)
        next_Q = self.net(next_state, model="target")[
            np.arange(0, self.batch_size), best_action
        ]
        return (reward + (1 - done.float()) * self.gamma * next_Q).float()

    def update_Q_online(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        self.net.target.load_state_dict(self.net.online.state_dict())

    def save(self):
        save_path = (
            self.save_dir / f"mario_net_{int(self.curr_step // self.save_every)}.chkpt"
        )
        torch.save(
            dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),
            save_path,
        )
        print(f"MarioNet saved to {save_path} at step {self.curr_step}")

    def act_only_AI(self, state):
        state = state.__array__()
        if self.use_cuda:
            state = torch.tensor(state).cuda()
        else:
            state = torch.tensor(state)
        state = state.unsqueeze(0)
        action_values = self.net(state, model="online")
        action_idx = torch.argmax(action_values, axis=1).item()
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
        self.curr_step += 1
        return action_idx

class MarioNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        c, h, w = input_dim
        if h != 84:
            raise ValueError(f"Expecting input height: 84, got: {h}")
        if w != 84:
            raise ValueError(f"Expecting input width: 84, got: {w}")
        self.online = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        )
        self.target = copy.deepcopy(self.online)
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        if model == "online":
            return self.online(input)
        elif model == "target":
            return self.target(input)

class MetricLogger:
    def __init__(self, save_dir):
        self.save_log = save_dir / "log"
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = save_dir / "reward_plot.jpg"
        self.ep_lengths_plot = save_dir / "length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / "q_plot.jpg"

        # History metrics
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages, added for every call to record()
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Current episode metric
        self.init_episode()

        # Timing
        self.record_time = time.time()

    def log_step(self, reward, loss, q):
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        if loss:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self):
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        if self.curr_ep_loss_length == 0:
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)

        self.init_episode()

    def init_episode(self):
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)

        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )

        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
            plt.plot(getattr(self, f"moving_avg_{metric}"))
            plt.savefig(getattr(self, f"{metric}_plot"))
            plt.clf()

env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0")
env = JoypadSpace(env, [['NOOP'], ["right"], ["right", "A"],["right", "B"], ["right", "A", "B"],["left"], ["left", "A"],["left", "B"], ["left", "A", "B"],["right","left"], ["right","left", "A"],["right","left", "B"], ["right","left", "A", "B"],["down"]])
env = SkipFrame(env, skip=1)
env = GrayScaleObservation(env)
env = ResizeObservation(env, shape=84)
env = FrameStack(env, num_stack=4)
video_every = 100
env = gym.wrappers.Monitor(env, "./video", video_callable=lambda episode_id: (episode_id%video_every)==0, force=True)
use_cuda = torch.cuda.is_available()
save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)
chkpt_file=None
mario = Mario(state_dim=(4, 84, 84), action_dim=env.action_space.n, save_dir=save_dir, chkpt_file=chkpt_file)
logger = MetricLogger(save_dir)

episodes = 20
for e in range(episodes):
    state = env.reset()
    while True:
        action = mario.act(state)
        next_state, reward, done, info = env.step(action)
        mario.cache(state, next_state, action, reward, done)
        q, loss = mario.learn()
        logger.log_step(reward, loss, q)
        state = next_state
        if done or info["flag_get"]:
            break

    logger.log_episode()
    if e % 1 == 0:
        logger.record(episode=e, epsilon=mario.exploration_rate, step=mario.curr_step)

コメント入力

関連サイト