[파이톤] LSTM 소라 아가씨에 대한 가위바위보 예측.
카탈로그
개시하다
며칠 전 미스 소라 가위바위보 연구소가 과거 미스 소라의 가위바위보 손 PDF를 공개해 화제를 모았다.그리고 이 연구소는 이 데이터를 분석해 전체 승률이 70% 이상이라는 것을 알게 됐고, 아직 성숙하지는 않지만 기계를 뜯어먹고 공부하는 사람으로서 가만히 있을 수는 없다는 것을 알게 됐다.
이에 따라 우선 LSTM을 이용해 예측 모델을 간단하게 제작하기로 했다.
소스 코드
부기 사항 "다변수 LSTM"을 참고하십시오.import numpy as np
import pandas as pd
from keras.layers import LSTM, Activation, Dense
from keras.models import Sequential
data_file = 'サザエさんじゃんけん.tsv'
look_back = 13 # 遡る時間
res_file = 'lstm'
def shuffle_lists(list1, list2):
'''リストをまとめてシャッフル'''
seed = np.random.randint(0, 1000)
np.random.seed(seed)
np.random.shuffle(list1)
np.random.seed(seed)
np.random.shuffle(list2)
def get_data():
'''データ作成'''
df = pd.read_csv(data_file, sep='\t',
usecols=['rock', 'scissors', 'paper'])
dataset = df.values.astype(np.float32)
X_data, y_data = [], []
for i in range(len(dataset) - look_back - 1):
x = dataset[i:(i + look_back)]
X_data.append(x)
y_data.append(dataset[i + look_back])
# X_data = np.array(X_data)
# y_data = np.array(y_data)
X_data = np.array(X_data[-500:])
y_data = np.array(y_data[-500:])
last_data = np.array([dataset[-look_back:]])
# シャッフル
shuffle_lists(X_data, y_data)
return X_data, y_data, last_data
def get_model():
model = Sequential()
model.add(LSTM(16, input_shape=(look_back, 3)))
model.add(Dense(3))
model.add(Activation('softmax'))
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
return model
def pred(model, X, Y, label):
'''正解率 出力'''
predictX = model.predict(X)
correct = 0
for real, predict in zip(Y, predictX):
if real.argmax() == predict.argmax():
correct += 1
correct = correct / len(Y)
print(label + '正解率 : %02.2f ' % correct)
def main():
# データ取得
X_data, y_data, last_data = get_data()
# データ分割
mid = int(len(X_data) * 0.7)
train_X, train_y = X_data[:mid], y_data[:mid]
test_X, test_y = X_data[mid:], y_data[mid:]
# 学習
model = get_model()
hist = model.fit(train_X, train_y, epochs=50, batch_size=16,
validation_data=(test_X, test_y))
# 正解率出力
pred(model, train_X, train_y, 'train')
pred(model, test_X, test_y, 'test')
# 来週の手
next_hand = model.predict(last_data)
print(next_hand[0])
hands = ['グー', 'チョキ', 'パー']
print('来週の手 : ' + hands[next_hand[0].argmax()])
if __name__ == '__main__':
main()
간단히 설명하다
데이터는 tsv 파일에 미리 집합됩니다.year month day rock scissors paper
1991 11 10 0 1 0
1991 11 17 1 0 0
1991 11 24 1 0 0
1991 12 1 0 0 1
...
그리고 데이터를 다음과 같이 정형한다(look back = 2시)[[[0. 0. 1.]
[0. 1. 0.]]
[[1. 0. 0.]
[0. 1. 0.]]
[[1. 0. 0.]
[1. 0. 0.]]
...
결실
매개 변수를 바꾸어 보아라. 가장 좋은 결과는 바로 이렇다.
부기 사항 "다변수 LSTM"을 참고하십시오.
import numpy as np
import pandas as pd
from keras.layers import LSTM, Activation, Dense
from keras.models import Sequential
data_file = 'サザエさんじゃんけん.tsv'
look_back = 13 # 遡る時間
res_file = 'lstm'
def shuffle_lists(list1, list2):
'''リストをまとめてシャッフル'''
seed = np.random.randint(0, 1000)
np.random.seed(seed)
np.random.shuffle(list1)
np.random.seed(seed)
np.random.shuffle(list2)
def get_data():
'''データ作成'''
df = pd.read_csv(data_file, sep='\t',
usecols=['rock', 'scissors', 'paper'])
dataset = df.values.astype(np.float32)
X_data, y_data = [], []
for i in range(len(dataset) - look_back - 1):
x = dataset[i:(i + look_back)]
X_data.append(x)
y_data.append(dataset[i + look_back])
# X_data = np.array(X_data)
# y_data = np.array(y_data)
X_data = np.array(X_data[-500:])
y_data = np.array(y_data[-500:])
last_data = np.array([dataset[-look_back:]])
# シャッフル
shuffle_lists(X_data, y_data)
return X_data, y_data, last_data
def get_model():
model = Sequential()
model.add(LSTM(16, input_shape=(look_back, 3)))
model.add(Dense(3))
model.add(Activation('softmax'))
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
return model
def pred(model, X, Y, label):
'''正解率 出力'''
predictX = model.predict(X)
correct = 0
for real, predict in zip(Y, predictX):
if real.argmax() == predict.argmax():
correct += 1
correct = correct / len(Y)
print(label + '正解率 : %02.2f ' % correct)
def main():
# データ取得
X_data, y_data, last_data = get_data()
# データ分割
mid = int(len(X_data) * 0.7)
train_X, train_y = X_data[:mid], y_data[:mid]
test_X, test_y = X_data[mid:], y_data[mid:]
# 学習
model = get_model()
hist = model.fit(train_X, train_y, epochs=50, batch_size=16,
validation_data=(test_X, test_y))
# 正解率出力
pred(model, train_X, train_y, 'train')
pred(model, test_X, test_y, 'test')
# 来週の手
next_hand = model.predict(last_data)
print(next_hand[0])
hands = ['グー', 'チョキ', 'パー']
print('来週の手 : ' + hands[next_hand[0].argmax()])
if __name__ == '__main__':
main()
간단히 설명하다
데이터는 tsv 파일에 미리 집합됩니다.year month day rock scissors paper
1991 11 10 0 1 0
1991 11 17 1 0 0
1991 11 24 1 0 0
1991 12 1 0 0 1
...
그리고 데이터를 다음과 같이 정형한다(look back = 2시)[[[0. 0. 1.]
[0. 1. 0.]]
[[1. 0. 0.]
[0. 1. 0.]]
[[1. 0. 0.]
[1. 0. 0.]]
...
결실
매개 변수를 바꾸어 보아라. 가장 좋은 결과는 바로 이렇다.
year month day rock scissors paper
1991 11 10 0 1 0
1991 11 17 1 0 0
1991 11 24 1 0 0
1991 12 1 0 0 1
...
[[[0. 0. 1.]
[0. 1. 0.]]
[[1. 0. 0.]
[0. 1. 0.]]
[[1. 0. 0.]
[1. 0. 0.]]
...
매개 변수를 바꾸어 보아라. 가장 좋은 결과는 바로 이렇다.
train正解率 : 0.60
test正解率 : 0.59
[0.6323466 0.12851575 0.23913768]
来週の手 : グー
최후
유감스럽게도 이번에는 미스 소라 가위바위보 연구소의 승률에는 못 미치지만 사용 성격과 모델에 따라 개선할 여지가 있다.나는 다시 시간이 있을 때 도전하고 싶다.여러분도 꼭 도전해 보세요.
나는 이번 주 소라 아가씨를 기대하고 있다.
추기
(18/08/18)
코드가 더러워서 정리했어요.
나는 다른 모델에 도전했다.
→ 신경망의 소라 아가씨의 가위바위보 예측에 의하면
(18/08/19)
예쁘다 정답!
참조 링크
소라 아가씨 가위바위보 연구소
다변수 LSTM
신경망의 소라 아가씨의 가위바위보 예측에 의하면
Reference
이 문제에 관하여([파이톤] LSTM 소라 아가씨에 대한 가위바위보 예측.), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다
https://qiita.com/derodero24/items/dbb624cf018041427015
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
우수한 개발자 콘텐츠 발견에 전념
(Collection and Share based on the CC Protocol.)
(18/08/18)
코드가 더러워서 정리했어요.
나는 다른 모델에 도전했다.
→ 신경망의 소라 아가씨의 가위바위보 예측에 의하면
(18/08/19)
예쁘다 정답!
참조 링크
소라 아가씨 가위바위보 연구소
다변수 LSTM
신경망의 소라 아가씨의 가위바위보 예측에 의하면
Reference
이 문제에 관하여([파이톤] LSTM 소라 아가씨에 대한 가위바위보 예측.), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다
https://qiita.com/derodero24/items/dbb624cf018041427015
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
우수한 개발자 콘텐츠 발견에 전념
(Collection and Share based on the CC Protocol.)
Reference
이 문제에 관하여([파이톤] LSTM 소라 아가씨에 대한 가위바위보 예측.), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://qiita.com/derodero24/items/dbb624cf018041427015텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)