[파이톤] LSTM 소라 아가씨에 대한 가위바위보 예측.

카탈로그

  • 개시하다
  • 소스 코드
  • 간단히 설명하다
  • 결실
  • 최후
  • 참조 링크
  • 개시하다


    며칠 전 미스 소라 가위바위보 연구소가 과거 미스 소라의 가위바위보 손 PDF를 공개해 화제를 모았다.그리고 이 연구소는 이 데이터를 분석해 전체 승률이 70% 이상이라는 것을 알게 됐고, 아직 성숙하지는 않지만 기계를 뜯어먹고 공부하는 사람으로서 가만히 있을 수는 없다는 것을 알게 됐다.
    이에 따라 우선 LSTM을 이용해 예측 모델을 간단하게 제작하기로 했다.

    소스 코드


    부기 사항 "다변수 LSTM"을 참고하십시오.
    import numpy as np
    import pandas as pd
    from keras.layers import LSTM, Activation, Dense
    from keras.models import Sequential
    
    data_file = 'サザエさんじゃんけん.tsv'
    look_back = 13  # 遡る時間
    res_file = 'lstm'
    
    
    def shuffle_lists(list1, list2):
        '''リストをまとめてシャッフル'''
        seed = np.random.randint(0, 1000)
        np.random.seed(seed)
        np.random.shuffle(list1)
        np.random.seed(seed)
        np.random.shuffle(list2)
    
    
    def get_data():
        '''データ作成'''
        df = pd.read_csv(data_file, sep='\t',
                         usecols=['rock', 'scissors', 'paper'])
        dataset = df.values.astype(np.float32)
    
        X_data, y_data = [], []
        for i in range(len(dataset) - look_back - 1):
            x = dataset[i:(i + look_back)]
            X_data.append(x)
            y_data.append(dataset[i + look_back])
    
        # X_data = np.array(X_data)
        # y_data = np.array(y_data)
        X_data = np.array(X_data[-500:])
        y_data = np.array(y_data[-500:])
        last_data = np.array([dataset[-look_back:]])
    
        # シャッフル
        shuffle_lists(X_data, y_data)
    
        return X_data, y_data, last_data
    
    
    def get_model():
        model = Sequential()
        model.add(LSTM(16, input_shape=(look_back, 3)))
        model.add(Dense(3))
        model.add(Activation('softmax'))
        model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
        return model
    
    
    def pred(model, X, Y, label):
        '''正解率 出力'''
        predictX = model.predict(X)
        correct = 0
        for real, predict in zip(Y, predictX):
            if real.argmax() == predict.argmax():
                correct += 1
        correct = correct / len(Y)
        print(label + '正解率 : %02.2f ' % correct)
    
    
    def main():
        # データ取得
        X_data, y_data, last_data = get_data()
    
        # データ分割
        mid = int(len(X_data) * 0.7)
        train_X, train_y = X_data[:mid], y_data[:mid]
        test_X, test_y = X_data[mid:], y_data[mid:]
    
        # 学習
        model = get_model()
        hist = model.fit(train_X, train_y, epochs=50, batch_size=16,
                         validation_data=(test_X, test_y))
    
        # 正解率出力
        pred(model, train_X, train_y, 'train')
        pred(model, test_X, test_y, 'test')
    
        # 来週の手
        next_hand = model.predict(last_data)
        print(next_hand[0])
        hands = ['グー', 'チョキ', 'パー']
        print('来週の手 : ' + hands[next_hand[0].argmax()])
    
    
    if __name__ == '__main__':
        main()
    

    간단히 설명하다


    데이터는 tsv 파일에 미리 집합됩니다.
    year    month   day rock    scissors    paper
    1991    11  10  0   1   0
    1991    11  17  1   0   0
    1991    11  24  1   0   0
    1991    12  1   0   0   1
    ...
    
    그리고 데이터를 다음과 같이 정형한다(look back = 2시)
    [[[0. 0. 1.]
      [0. 1. 0.]]
     [[1. 0. 0.]
      [0. 1. 0.]]
     [[1. 0. 0.]
      [1. 0. 0.]]
     ...
    

    결실


    매개 변수를 바꾸어 보아라. 가장 좋은 결과는 바로 이렇다.
  • 출력
  • train正解率 : 0.60
    test正解率 : 0.59
    [0.6323466  0.12851575 0.23913768]
    来週の手 : グー
    
  • 차트

  • 최후


    유감스럽게도 이번에는 미스 소라 가위바위보 연구소의 승률에는 못 미치지만 사용 성격과 모델에 따라 개선할 여지가 있다.나는 다시 시간이 있을 때 도전하고 싶다.여러분도 꼭 도전해 보세요.
    나는 이번 주 소라 아가씨를 기대하고 있다.

    추기


    (18/08/18)
    코드가 더러워서 정리했어요.
    나는 다른 모델에 도전했다.
    신경망의 소라 아가씨의 가위바위보 예측에 의하면
    (18/08/19)
    예쁘다 정답!

    참조 링크


    소라 아가씨 가위바위보 연구소
    다변수 LSTM
    신경망의 소라 아가씨의 가위바위보 예측에 의하면

    좋은 웹페이지 즐겨찾기