TensorFlow2.0을 사용하여 Fashion-MNIST를 컨벌루션 뉴럴 네트워크(CNN)로 학습(Data Augmentation의 효과 확인)

소개



지난번 , Data Augmentation의 방법을 이해했기 때문에 전회 의 모델에 대해 Data Augmentation을 실시하는 것으로 어떻게 될까를 보고 싶습니다.

환경


  • Google Colaboratory
  • TensorFlow 2.0 Alpha

  • 코드



    여기 입니다.
    왠지 GitHub에서 잘 열지 못했습니다. GitHub URL은 여기입니다.

    코드 해설



    Data Augmentation 설정


    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    
    #datagen = ImageDataGenerator()
    datagen = ImageDataGenerator(rotation_range = 10, horizontal_flip = True, zoom_range = 0.1)
    

    이번에는 회전, 좌우 반전, 줌을 실시하기로 했습니다. 코멘트행은 Data Augmentation를 실시하지 않는 경우와 정밀도 비교를 하기 위한 것으로, 어느 쪽인지를 코멘트 아웃 해 실행하도록(듯이) 했습니다.

    학습


    import time
    
    num_epoch = 80
    start_time = time.time()
    
    #train_accuracies = []
    #test_accuracies = []
    train_accuracies_with_da = []
    test_accuracies_with_da = []
    
    
    for epoch in range(num_epoch):    
        for image, label in dataset_train:
            for _image, _label in datagen.flow(image, label, batch_size = batch_size):
                train_step(_image, _label)
                break
    
        for test_image, test_label in dataset_test:
            test_step(test_image, test_label)
    
        #train_accuracies.append(train_accuracy.result())
        #test_accuracies.append(test_accuracy.result())
        train_accuracies_with_da.append(train_accuracy.result())
        test_accuracies_with_da.append(test_accuracy.result())
    
    
        template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}, spent_time: {} min'
        spent_time = time.time() - start_time
        print(template.format(epoch + 1, train_loss.result(), train_accuracy.result() * 100, test_loss.result(), test_accuracy.result() * 100, spent_time / 60))
    

    나중에 그래프를 그리기 위해 train_accuracies_with_da 또는 test_accuracies_with_da에 결과를 저장합니다. Data Augmentation하지 않으려면 주석 처리를 반전시키고 train_accuracies 또는 test_accuracies에 결과를 저장합니다.

    결과





    80 에포크 훈련한 결과, 간신히 Data Augmentation 하는 것이 정밀도가 좋아지고 Data Augmentation 있는 정밀도는 92.2%였습니다. Train Accuracy 쪽은 Data Augmentation 하는지 여부로 과학습도에 차이가 있는 것 같습니다.
    이번에는 그다지 Data Augmentation의 효과는 크지 않았지만, 얼마나 효과를 발휘하는지는 원래 데이터 세트의 수나 특성에 기인한다고 생각합니다.

    다음 이니셔티브



    ResNet에 노력하고 싶습니다.

    좋은 웹페이지 즐겨찾기