ImageDataGenerator를 확장하고 cutout 구현

소개



이전 에 ImageDataGenerator를 사용해 Data Augmentation(수증)을 실시했습니다만, ImageDataGenerator가 가지고 있지 않은 물 증가 방법도 사용하고 싶었습니다. 이번에 그것을 실현해 보았습니다.

환경


  • Google Colaboratory
  • TensorFlow 2.0 Alpha

  • 코드



    여기 입니다.

    코드 해설


    import numpy as np
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    
    class CustomImageDataGenerator(ImageDataGenerator):
        def __init__(self, cutout_mask_size = 0, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.cutout_mask_size = cutout_mask_size
    
        def cutout(self, x, y):
            return np.array(list(map(self._cutout, x))), y
    
        def _cutout(self, image_origin):
            # 最後に使うfill()は元の画像を書き換えるので、コピーしておく
            image = np.copy(image_origin)
            mask_value = image.mean()
    
            h, w, _ = image.shape
            # マスクをかける場所のtop, leftをランダムに決める
            # はみ出すことを許すので、0以上ではなく負の値もとる(最大mask_size // 2はみ出す)
            top = np.random.randint(0 - self.cutout_mask_size // 2, h - self.cutout_mask_size)
            left = np.random.randint(0 - self.cutout_mask_size // 2, w - self.cutout_mask_size)
            bottom = top + self.cutout_mask_size
            right = left + self.cutout_mask_size
    
            # はみ出した場合の処理
            if top < 0:
                top = 0
            if left < 0:
                left = 0
    
            # マスク部分の画素値を平均値で埋める
            image[top:bottom, left:right, :].fill(mask_value)
            return image
    
        def flow(self, *args, **kwargs):
            batches = super().flow(*args, **kwargs)
    
            # 拡張処理
            while True:
                batch_x, batch_y = next(batches)
    
                if self.cutout_mask_size > 0:
                    result = self.cutout(batch_x, batch_y)
                    batch_x, batch_y = result                        
    
                yield (batch_x, batch_y)     
    
    datagen = CustomImageDataGenerator(rotation_range=10, horizontal_flip=True, zoom_range=0.1, cutout_mask_size=16)
    

    ImageDataGenerator를 상속받은 클래스를 만들고 flow 메서드를 재정의하고 거기에서 cutout을 호출합니다.

    출력 결과




    제대로 cutout이 들어있는 것 같습니다.

    참고로 한 페이지


  • Keras의 ImageDataGenerator를 상속하여 Mix-up과 Random Cropping을 할 수 있는 독자적인 제너레이터를 만든다
  • ImageDataGenerator를 확장하여 데이터 확장 패턴 추가
  • 좋은 웹페이지 즐겨찾기