How to write Chainer code, part 7


Overview



I have been looking into how to write Chainer code.
This time I tried OpenAI Gym: a DQN agent trained on CartPole-v0 with Chainer.
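
The sample below builds on the standard Gym interaction loop. A minimal sketch, using the same (older) Gym API as the code below, where env.reset() returns only the observation and env.step() returns a 4-tuple:

import gym

env = gym.make('CartPole-v0')
obs = env.reset()
done = False
total_reward = 0.0
while not done:
    env.render()
    action = env.action_space.sample()  # random actions, just to exercise the API
    obs, reward, done, _ = env.step(action)
    total_reward += reward
print(total_reward)
env.close()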

Figure

Sample code
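
The sample trains a DQN agent on CartPole-v0 with Chainer: QFunction is a small fully connected Q-network, transitions go into a replay buffer D, actions are chosen epsilon-greedily, and the TD update uses a periodically synced target network together with a Huber loss.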


import collections
import copy
import random
import gym
import numpy as np
import chainer
from chainer import functions as F
from chainer import links as L
from chainer import optimizers, Chain, no_backprop_mode
import matplotlib.pyplot as plt


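# Q-network: a 3-layer MLP that maps an observation to one Q-value per action.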
class QFunction(Chain):
    def __init__(self, obs_size, n_actions, n_units = 100):
        super(QFunction, self).__init__(
            l0 = L.Linear(obs_size, n_units),
            l1 = L.Linear(n_units, n_units),
            l2 = L.Linear(n_units, n_actions))

    def __call__(self, x):
        h = F.relu(self.l0(x))
        h = F.relu(self.l1(h))
        return self.l2(h)

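# Pick the action with the highest predicted Q-value (no gradients needed).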
def get_greedy_action(Q, obs):
    obs = Q.xp.asarray(obs[None], dtype = np.float32)
    with no_backprop_mode():
        q = Q(obs).data[0]
    return int(q.argmax())

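# Huber loss averaged over the minibatch; it is quadratic near zero and linear
# for large errors, which keeps the gradient magnitude bounded.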
def mean_clipped_loss(y, t):
    return F.mean(F.huber_loss(y, t, delta = 1.0, reduce = 'no'))

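# One gradient step on a minibatch of (obs, action, reward, done, obs_next).
# DQN target:        r + gamma * (1 - done) * max_a' target_Q(s', a')
# Double DQN target: r + gamma * (1 - done) * target_Q(s', argmax_a' Q(s', a'))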
def update(Q, target_Q, opt, samples, gamma = 0.99, target_type = 'double_dqn'):
    xp = Q.xp
    obs = xp.asarray([sample[0] for sample in samples], dtype = np.float32)
    action = xp.asarray([sample[1] for sample in samples], dtype = np.int32)
    reward = xp.asarray([sample[2] for sample in samples], dtype = np.float32)
    done = xp.asarray([sample[3] for sample in samples], dtype = np.float32)
    obs_next = xp.asarray([sample[4] for sample in samples], dtype = np.float32)
    y = F.select_item(Q(obs), action)
    with no_backprop_mode():
        if target_type == 'dqn':
            next_q = F.max(target_Q(obs_next), axis = 1)
        elif target_type == 'double_dqn':
            next_q = F.select_item(target_Q(obs_next), F.argmax(Q(obs_next), axis = 1))
        else:
            raise ValueError('Unsupported target_type: {}'.format(target_type))
        target = reward + gamma * (1 - done) * next_q
    loss = mean_clipped_loss(y, target)
    Q.cleargrads()
    loss.backward()
    opt.update()

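# Train a DQN agent on CartPole-v0 and plot the per-episode reward.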
def main():
    env = gym.make('CartPole-v0')
    assert isinstance(env.observation_space, gym.spaces.Box)
    assert isinstance(env.action_space, gym.spaces.Discrete)
    obs_size = env.observation_space.low.size
    n_actions = env.action_space.n
    reward_threshold = env.spec.reward_threshold
    print(obs_size, n_actions, reward_threshold)
    if reward_threshold is not None:
        print('{} defines "solving" as getting average reward of {} over 100 '
              'consecutive trials.'.format('CartPole-v0', reward_threshold))
    else:
        print('{} is an unsolved environment, which means it does not have a '
              'specified reward threshold at which it\'s considered '
              'solved.'.format('CartPole-v0'))
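    # D is the replay buffer; Rs keeps the last 100 episode rewards for the "solved" check.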
    D = collections.deque(maxlen = 10 ** 6)
    Rs = collections.deque(maxlen = 100)
    iteration = 0
    Q = QFunction(obs_size, n_actions, n_units = 100)
    target_Q = copy.deepcopy(Q)
    opt = optimizers.Adam(eps = 1e-2)
    opt.setup(Q)
    reward_trend = []
    for episode in range(200):
        obs = env.reset()
        done = False
        R = 0.0
        timestep = 0
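        # Run one episode; env.spec.timestep_limit caps its length
        # (newer gym releases expose this as env.spec.max_episode_steps).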
        while not done and timestep < env.spec.timestep_limit:
            env.render()
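            # Epsilon-greedy exploration: act fully at random until the buffer
            # holds 500 transitions, then anneal epsilon linearly from 1.0 to
            # 0.01 over 5000 iterations.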
            epsilon = 1.0 if len(D) < 500 else max(0.01, np.interp(iteration, [0, 5000], [1.0, 0.01]))
            if np.random.rand() < epsilon:
                action = env.action_space.sample()
            else:
                action = get_greedy_action(Q, obs)
            new_obs, reward, done, _ = env.step(action)
            R += reward
            D.append((obs, action, reward * 1e-2, done, new_obs))
            obs = new_obs
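            # Once the buffer holds at least 500 transitions, train on a random
            # minibatch of 64 (rewards were stored scaled by 1e-2 above).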
            if len(D) >= 500:
                sample_indices = random.sample(range(len(D)), 64)
                samples = [D[i] for i in sample_indices]
                update(Q, target_Q, opt, samples, target_type = 'dqn')
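            # Refresh the target network from the online network every 100 iterations.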
            if iteration % 100 == 0:
                target_Q = copy.deepcopy(Q)
            iteration += 1
            timestep += 1
        Rs.append(R)
        reward_trend.append(R)
        average_R = np.mean(Rs)
        print('episode: {} iteration: {} reward: {} average_R: {}'.format(
            episode, iteration, R, average_R))
        if reward_threshold is not None and average_R >= reward_threshold:
            print('Solved {} by getting average reward of '
                  '{} >= {} over 100 consecutive episodes.'.format(
                      'CartPole-v0', average_R, reward_threshold))
            break
    plt.plot(reward_trend)
    plt.savefig("cart0.png")
    plt.show()

if __name__ == '__main__':
    main()
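
main() trains with the plain DQN target (target_type = 'dqn'); passing target_type = 'double_dqn' to update switches to the Double DQN target with no other changes. After training, the greedy policy can also be run without exploration. A minimal sketch of such an evaluation helper; it is not part of the original script, and would be called at the end of main(), e.g. evaluate(env, Q):

def evaluate(env, Q, n_episodes = 5):
    # Run the greedy policy (no exploration) and print the raw episode rewards.
    for episode in range(n_episodes):
        obs = env.reset()
        done = False
        R = 0.0
        while not done:
            env.render()
            action = get_greedy_action(Q, obs)
            obs, reward, done, _ = env.step(action)
            R += reward
        print('evaluation episode: {} reward: {}'.format(episode, R))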




That's all.
