tf.data.Dataset을 입력에 ImageDataGenerator를 사용해 Data Augmentation(물 증가)을 실시한다

소개



지난번 , Fashion-MNIST를 CNN으로 분류하고 Test Accuracy: 91.3. 그래서 다음은 Data Augmentation을 실시하는 것으로, 얼마나 정밀도가 오르는지 확인해 보려고 생각했습니다만, tf.data.Dataset를 Data Augmentation하는 방법을 잘 모르기 때문에, 우선은 그것을 조사해 본다 결정했습니다.

환경


  • Google Colaboratory
  • TensorFlow 2.0 Alpha

  • 코드



    여기 입니다.

    코드 해설



    데이터 세트 획득


    import tensorflow_datasets as tfds
    
    batch_size = 64
    
    dataset = tfds.load('cifar10')
    dataset = dataset['train']
    dataset = dataset.batch(batch_size)
    

    이번에는 CIFAR-10을 소재로 해 보았습니다. 사실은 'cats_vs_dogs' 로 하려고 했습니다만, 배치 단위에 종횡의 사이즈를 맞추지 않으면 Data Augmentation에서 에러가 되어 버렸기 때문에 포기했습니다.

    Data Augmentation 설정



    세상에는 데이터가 많이 있지만 라벨이 붙은 데이터는 그렇게 많지 않습니다. 새롭게 라벨을 붙이려고 하면 비용이 듭니다만, 지금 라벨이 붙어 있는 데이터를 조금 가공해 하면 간단하게 라벨 첨부의 데이터를 늘릴 수 있습니다. 그렇게 하여 늘린 보다 많은 데이터를 학습함으로써 일반화 성능이 높아져 정밀도가 올라갈 것으로 기대할 수 있습니다. 데이터의 물 증가 등이라고도합니다. 이것을 Keras의 ImageDataGenerator를 사용해 실현합니다.
    import tensorflow as tf
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    
    datagen = ImageDataGenerator(rotation_range = 30, horizontal_flip = True, zoom_range = 0.2)
    

    이번에는 회전, 좌우 반전, 줌을 실시하기로 했습니다. 사실은 cutout도 실현하고 싶었습니다만, 공식 문서 을 읽는 한 실현할 수 없는 것 같은 생각이 듭니다.

    Data Augmentation, 결과 표시


    import math
    import numpy as np
    import matplotlib.pyplot as plt
    
    column_size = math.floor(math.sqrt(batch_size))
    row_size = math.ceil(batch_size / column_size)
    fig, ax = plt.subplots(row_size, column_size, figsize = (row_size * 2, column_size * 2), subplot_kw = {'xticks': (), 'yticks': ()})
    
    for axis in ax:
        for a in axis:
            a.set_axis_off()
    
    for data_list in dataset:    
        image_list = data_list['image']
        label_list = data_list['label']
    
        row, column = 0, 0
        for x, y in datagen.flow(image_list, label_list, batch_size = batch_size):
            print(x.shape)
            for _x in x:
                _x = tf.cast(_x, tf.uint8)
                ax[row, column].imshow(_x)
                if column == column_size - 1:
                    column = 0
                    row += 1
                else:
                    column += 1
            break
        break
    
    data_list 에는 batch_size 의 이미지와 레이블 세트가 있습니다. image_listlabel_list 는 ndarray 이고 image_list 의 shape 는 (64, 32, 32, 3) 입니다. 이것을 datagen.flow 의 인수에 주는 것으로, 입력의 이미지를 랜덤으로 회전, 좌우 반전, 줌 한 이미지가 생성되어 x 에 대입됩니다. image_listx 는 일대일 관계입니다. 나머지는 결과만 표시합니다. matplotlib로 테두리나 눈금을 지우는 방법을 알아내는데 수수한 시간이 걸렸습니다만・・・. datagen.flow 는 무한 루프하므로 첫 번째 break 는 필수입니다.

    출력 결과




    왠지 회전이나 줌은 알 수 있을까 생각합니다. 좌우 반전은 모르겠네요・・・.

    마지막으로



    tf.data.Dataset을 입력에 ImageDataGenerator를 사용하고 있는 코드가 별로 세상에 없는 것 같았기 때문에 작성해 보았습니다. 누군가의 도움을 받으면 다행입니다. 다음은 이것을 사용하여 얼마나 정밀도가 오르는지 확인해 보겠습니다.

    좋은 웹페이지 즐겨찾기