Kers-rl2로 시작된 강화 학습(블록 분할)
45378 단어 PythonTensorFlowtech
개시하다
이번에keras-rl2로 그 얄미운 블록을 강화해서 배우고 싶어요.
이번 블록 분리 프로그램은 수업 시간에 처리된 물건을 파이톤으로 직접 옮기는 것이기 때문에 많은 오류가 있으니 무시하겠습니다.교수님께 불평하다.
메모니까 기대하지 마세요.
그럼 모델부터 만들어요.
모형을 제작하다
이번에 모델을 네트워크로 만들었어요.py에 기술하다.
from tensorflow.keras import models
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, InputLayer
from tensorflow.keras.optimizers import Adam
from rl.memory import SequentialMemory
from rl.policy import BoltzmannQPolicy
from rl.agents.dqn import DQNAgent
class Network():
def __init__(self, load):
self.model = Sequential([Flatten(input_shape=(1,15)),Dense(16,activation='relu'),Dense(16,activation='relu'),Dense(16,activation='relu'),Dense(3,activation='linear')])
self.dqn_agent()
if(load):
self.load()
def dqn_agent(self):
memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannQPolicy()
self.dqn = DQNAgent(model=self.model,nb_actions=3,gamma=0.99,memory=memory,nb_steps_warmup=100,target_model_update=1e-2,policy=policy)
self.dqn.compile(Adam(lr=1e-3), metrics=['mae'])
def fit(self,env,nb_steps):
self.dqn.fit(env,nb_steps=nb_steps,visualize=True,verbose=1)
def test(self,env):
self.dqn.test(env,nb_episodes=10,visualize=True)
def save(self):
self.model.save('weight/model.h5')
self.dqn.save_weights('dqn_weight',overwrite=True)
def load(self):
self.model = models.load_model('weight/model.h5')
self.dqn.load_weights('dqn_weight')
전체적으로 이런 느낌.이번에는 모델에 대한 입력 데이터로 10개의 블록이 아직도 존재합니까?공의 X, Y 좌표, 공의 x, y 방향의 속도, 공을 던진 라켓의 x 좌표를 전달한다.
따라서 입력 데이터는 15개다.
또 정지, 오른쪽, 왼쪽 세 개의 이동을 라켓으로 만들기 위해 데이터를 3개로 출력했다.
self.model = Sequential([Flatten(input_shape=(1,15)),Dense(16,activation='relu'),Dense(16,activation='relu'),Dense(16,activation='relu'),Dense(3,activation='linear')])
DQN의 설정은 이렇다. def dqn_agent(self):
memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannQPolicy()
self.dqn = DQNAgent(model=self.model,nb_actions=3,gamma=0.99,memory=memory,nb_steps_warmup=100,target_model_update=1e-2,policy=policy)
self.dqn.compile(Adam(lr=1e-3), metrics=['mae'])
gamma의 값을 줄이면 보수가 시간에 비해 감소량이 증가한다.또한 출력 데이터가 3개이기 때문에 nb액션스는 3.
게임 환경 조성
env.py에 기술하다
import gym
import numpy as np
from processing_py import *
import processing_py as pp
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import models
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, InputLayer
from rl.memory import SequentialMemory
from rl.policy import BoltzmannQPolicy
from rl.agents.dqn import DQNAgent
makernd = lambda a,b:np.int(np.random.random_sample()*(b-a)+a)
class MyEnv(gym.Env):
def __init__(self):
self.app = App(600,400)
self.reset()
self.actions=np.array([0,20,-20])
def reset(self):
self.ball_x_speed=5
if makernd(0,2)==0:
self.ball_x_speed=-5
self.ball_y_speed=5
self.blocks=np.array([1]*10)
self.score=0
self.racket_x=300
self.ball_x=makernd(0,590)
self.ball_y=makernd(110,150)
self.observation=np.hstack([self.blocks,self.ball_x,self.ball_y,self.ball_x_speed,self.ball_y_speed,self.racket_x])
return self.observation
def step(self,action):
r=0.0
#ラケットが壁から出ないようにする
if self.racket_x<=0 and action==2:
pass
elif self.racket_x>=540 and action==1:
pass
else:
self.racket_x+=self.actions[action]
self.ball_x+=self.ball_x_speed
self.ball_y+=self.ball_y_speed
self.observation=np.hstack([self.blocks,self.ball_x,self.ball_y,self.ball_x_speed,self.ball_y_speed,self.racket_x])
#ボールがラケットに当たったかどうか
if self.ball_x>=self.racket_x and self.ball_x<=self.racket_x+60 and self.ball_y<=350 and self.ball_y>=340:
self.ball_y_speed=-self.ball_y_speed
#ボールがブロックに当たったかどうか
for i in range(10):
if self.blocks[i]==1 and self.ball_x>=i*60 and self.ball_x<=i*60+60 and self.ball_y>=50 and self.ball_y<=110:
self.blocks[i]=0
self.ball_y_speed=-self.ball_y_speed
r+=5
###
if self.ball_x<0 or self.ball_x+10>600:
self.ball_x_speed=-self.ball_x_speed
if self.ball_y<=0:
self.ball_y_speed=-self.ball_y_speed
###
if self.ball_y+10>=400:
return self.observation,np.float32(-50),True,{} #失敗
elif self.is_game_clear():
return self.observation,np.float32(r+100),True,{} #成功
else:
if abs(self.ball_x-self.racket_x-25)>=30:
r-=0.1
else:
r+=0.1
return self.observation,np.float32(r),False,{}
#ブロックが全部消えたかどうかを調べる
def is_game_clear(self):
for i in range(10):
if self.blocks[i]==1:
return False
return True
def render(self,mode):
self.app.background(10,10,10)
self.draw_blocks()
self.draw_racket()
self.draw_ball()
self.app.redraw()
def draw_blocks(self):
for i in range(10):
if self.blocks[i]==1:
self.app.rect(i*60,50,60,60)
def draw_racket(self):
self.app.rect(self.racket_x,350,60,10)
def draw_ball(self):
self.app.circle(self.ball_x,self.ball_y,10)
step는 처리를,render는 화면을 담당한다.observation에서 15개의 입력 데이터를 책임지십시오.
self.observation=np.hstack([self.blocks,self.ball_x,self.ball_y,self.ball_x_speed,self.ball_y_speed,self.racket_x])
step에서 게임 종료 후 진리치로 반환할지 여부는 그 외에 보수를 지불해야 한다.return self.observation,np.float32(r),False,{}
또한 이번 보수는 라켓의 x 좌표가 공의 x 좌표에서 너무 멀거나 게임이 끝난 후에 줄어든다.반대로 가호는 라켓의 x좌표가 공에 가까운 x좌표, 게임 통관, 아니면 공이 공을 명중시키는 것이다.
for i in range(10):
if self.blocks[i]==1 and self.ball_x>=i*60 and self.ball_x<=i*60+60 and self.ball_y>=50 and self.ball_y<=110:
self.blocks[i]=0
self.ball_y_speed=-self.ball_y_speed
r+=5
###
if self.ball_x<0 or self.ball_x+10>600:
self.ball_x_speed=-self.ball_x_speed
if self.ball_y<=0:
self.ball_y_speed=-self.ball_y_speed
###
if self.ball_y+10>=400:
return self.observation,np.float32(-50),True,{} #失敗
elif self.is_game_clear():
return self.observation,np.float32(r+100),True,{} #成功
else:
if abs(self.ball_x-self.racket_x-25)>=30:
r-=0.1
else:
r+=0.1
return self.observation,np.float32(r),False,{}
render는processingpy를 사용합니다.(참조: https://pypi.org/project/processing-py/
main.py로 정리하다
from env import MyEnv
from network import Network
if __name__=='__main__':
env=MyEnv()
env.reset()
net=Network(False)
net.fit(env,100000)
net.test(env)
net.save()
이상 종료.
Reference
이 문제에 관하여(Kers-rl2로 시작된 강화 학습(블록 분할)), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://zenn.dev/antman/articles/64ff6ae2b302cc텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)