【Keras 입문(6)】간단한 RNN 모델 정의(최종 출력만 사용)

마지막 기사 「【Keras 입문(5)】간단한 RNN 모델 정의」에서는 RNN을 사용하여 하나의 입력 값에 대해 다음 값을 예측했습니다. 이번에는 10개의 입력값에 대해 하나의 출력을 하는 모델로 합니다.
이것은 문장에 대한 네거티브 포지티브 예측(네거티브/포지티브)이나 문서 분류 등에 사용할 수 있습니다.

이하의 시리즈로 하고 있습니다.

- 【Keras 입문(1)】 단순한 딥 러닝 모델 정의
- 【Keras 입문(2)】훈련 모델 보존(Keras 모델과 SavedModel)
- 【Keras 입문(3)】TensorBoard로 보이기
- 【Keras 입문(4)】Keras의 평가 함수(Metrics)
- 【Keras 입문(5)】간단한 RNN 모델 정의
- 【Keras 입문(6)】간단한 RNN 모델 정의(최종 출력만 사용) <- 본 기사
- 【Keras 입문(7)】간단한 Seq2Seq 모델 정의

사용한 Python 패키지



Google 공동체에서 설치된 다음 패키지와 버전을 사용하고 있습니다. Keras는 TensorFlow에 통합된 것을 사용하기 때문에 순수한 Keras는 사용하지 않습니다. 파이썬은 3.6입니다.
  • tensorflow: 1.14.0
  • Numpy: 1.16.4
  • matplotlib: 3.0.3

  • 처리 개요



    등차 열이 증가할지 또는 감소하는지를 결정합니다.


    1st
    2nd
    3rd
    4th
    5th
    6th
    7th
    8th
    9th
    10th
    등차수 열


    0
    1
    2
    3
    4
    5
    6
    7
    8
    9
    증가(0)

    0
    -1
    -2
    -3
    -4
    -5
    -6
    -7
    -8
    -9
    감소(0)


    처리 프로그램



    전체 프로그램은 GitHub을 참조하십시오.

    1. 라이브러리 가져오기



    이전부터 추가하여 난수를 발생시키는 랜덤도 로드하고 있습니다.
    from random import randint
    
    import numpy as np
    import matplotlib.pyplot as plt
    
    # TensorFlowに統合されたKerasを使用
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense, SimpleRNN
    

    2. 전처리



    등차수열의 배열을 증가·감소 패턴을 교대로 만들고 있습니다.
    NUM_RNN = 10
    NUM_DATA = 200
    
    # 空の器を作成
    x_train = np.empty((0, NUM_RNN))
    y_train = np.empty((0, 1))
    
    for i in range(NUM_DATA):
        num_random = randint(-20, 20)
        if i % 2 == 1:  # 奇数の場合
            x_train = np.append(x_train, np.linspace(num_random, num_random+NUM_RNN-1, num=NUM_RNN).reshape(1, NUM_RNN), axis=0)
            y_train = np.append(y_train, np.zeros(1).reshape(1, 1), axis=0)
        else: # 偶数の場合
            x_train = np.append(x_train, np.linspace(num_random, num_random-NUM_RNN+1, num=NUM_RNN).reshape(1, NUM_RNN), axis=0)
            y_train = np.append(y_train, np.ones(1).reshape(1, 1), axis=0)
    
    x_train = x_train.reshape(NUM_DATA, NUM_RNN, 1)
    y_train = y_train.reshape(NUM_DATA, 1)
    

    3. 모델 정의



    이번 RNN 모델은 이전과 달리 최종 출력만 사용합니다. 그 때문에 아래 그림과 같은 모델입니다.


    참고로, 전회의 모델(최종 출력 이외도 사용)은 이었습니다.


    SimpleRNN 함수return_sequences 값을 False로 사용하지 마십시오. 또한 마지막 전체 결합층은 1차원으로 하여 이진분류입니다.
    NUM_DIM = 16  # 中間層の次元数
    
    model = Sequential()
    
    # return_sequenceがFalseなので最後のRNN層のみが出力を返す
    model.add(SimpleRNN(NUM_DIM, batch_input_shape=(None, NUM_RNN, 1), return_sequences=False))
    model.add(Dense(1, activation='sigmoid'))  #全結合層
    model.compile(loss='binary_crossentropy', optimizer='adam')
    
    model.summary()
    
    summary 함수는 다음과 같은 모델 요약을 제공합니다.
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    simple_rnn (SimpleRNN)       (None, 8)                 80        
    _________________________________________________________________
    dense (Dense)                (None, 1)                 9         
    =================================================================
    Total params: 89
    Trainable params: 89
    Non-trainable params: 0
    _________________________________________________________________
    

    4. 훈련 실행



    fit 함수 사용하여 훈련 실행입니다. 30epoch 정도로 적당히 좋은 정밀도가 나옵니다.
    history = model.fit(x_train, y_train, epochs=30, batch_size=8)
    loss = history.history['loss']
    
    plt.plot(np.arange(len(loss)), loss) # np.arangeはlossの連番数列を生成(今回はepoch数の0から29)
    plt.show()
    



    5. 테스트



    마지막으로 테스트입니다.

    5.1. 테스트 실행



    훈련 데이터의 처음 10건을 테스트 데이터로 합니다.
    predict 함수을 사용하여 테스트 데이터에서 예측값을 출력합니다.
    # データ数(10回)ループ
    for i in range(10):
        y_pred = model.predict(x_train[i].reshape(1, NUM_RNN, 1))
        print(y_pred[0], ':', x_train[i].reshape(NUM_RNN))
    

    본 한 전문 정답입니다(4건만 발췌).
    [0.9673256] : [ -6.  -7.  -8.  -9. -10. -11. -12. -13. -14. -15.]
    [0.01308651] : [-6. -5. -4. -3. -2. -1.  0.  1.  2.  3.]
    [0.9779032] : [12. 11. 10.  9.  8.  7.  6.  5.  4.  3.]
    [0.01788531] : [-2. -1.  0.  1.  2.  3.  4.  5.  6.  7.]
    

    좋은 웹페이지 즐겨찾기