새로운 Data Augmentation 기법 인 CutMix를 사용해 보았습니다.

소개



CutMix라는 새로운 Data Augumentation 기법이 간단한 기법이었기 때문에 시도했습니다.
이번에는 Cifar10 등을 이용하여 이 수법의 유효성을 확인하고 있지 않습니다.
검증에 대해서는 이 기사에 추기, 혹은 새롭게 기사를 일으키고 싶습니다.

CutMix



먼저 CutMix의 이름의 기원으로 Cutout + Mixup에서 왔습니다. 그 유래대로 Cutout과 Mixup의 기술 각각을 맞춘 것 같은 수법이 되고 있습니다.
이하 CutOut과 Mixup, CutMix 각각의 수법의 차이가 비교되고 있는 그림이 논문에 올랐으므로 이쪽에도 게재합니다.



구체적인 처리의 흐름은 이미지와 라벨의 페어 $(x_a, y_a)$, $(x_b, y_b)$로부터, $(x, y)$라고 하는 새로운 데이터와 라벨의 페어를 만듭니다.
여기서 $\lambda\in [0, 1]$은 베타 분포 $Beta(\alpha,\alpha)$에서 샘플링하여 얻고 $\alpha$는 하이퍼파라미터가 됩니다.
또 이미지의 폭과 높이 각각의 균일 분포된 난수를 꺼내, 잘라내는 기점의 좌표 $(r_x, r_y)$로 합니다. 잘라낼 때의 폭과 높이는 얼마만큼 계산한 $\lambda$를 사용해, $\sqrt(1 -\lambda)$로 계산해 그것들을 $(r_w, r_h)$로 합니다.
그 후 잘라내는 기점 $(r_x, r_y)$, $(r_w, r_h)$를 사용해 한쪽의 이미지를 잘라내고, 다른 한쪽의 이미지에 잘라낸 부분을 붙여 넣습니다.이것으로 새로운 이미지는 완성입니다. 라벨 쪽은 mixup의 처리와 같고 $\lambda$와 $(1-\lambda)$를 라벨 각각에 걸어 맞추어 그들을 더하면 처리는 끝이 됩니다.

라고 그다그다라고 썼습니다만 논문에 CutMix의 의사 코드가 써 있었으므로 올려 둡니다.



구현



의사 코드를 구현해 보았습니다.
import numpy as np


def get_rand_bbox(image, l):
    width = image.shape[0]
    height = image.shape[1]
    r_x = np.random.randint(width)
    r_y = np.random.randint(height)
    r_l = np.sqrt(1 - l)
    r_w = np.int(width * r_l)
    r_h = np.int(height * r_l)
    bb_x_1 = np.int(np.clip(r_x - r_w, 0, width))
    bb_y_1 = np.int(np.clip(r_y - r_h, 0, height))
    bb_x_2 = np.int(np.clip(r_x + r_w, 0, width))
    bb_y_2 = np.int(np.clip(r_y + r_h, 0, height))
    return bb_x_1, bb_y_1, bb_x_2, bb_y_2

def main():
    image_path_1 = "image_1.jpg"
    image_path_2 = "image_2.jpg"
    # 説明用にラベルを簡単化しています
    label_1 = np.array([1, 0])
    label_2 = np.array([0, 1])
    image_1 = Image.open(image_path_1).resize((224, 224))
    image_2 = Image.open(image_path_2).resize((224, 224))
    beta = 0.5
    l_param  = np.random.beta(beta, beta)
    img_1 = np.array(image_1)
    img_2 = np.array(image_2)
    bx1, by1, bx2, by2 = get_rand_bbox(img_1, l_param)
    img_2[bx1:bx2, by1:by2, :] = img_1[bx1:bx2, by1:by2, :]
    new_label = l_param * label_2 + (1 - l_param) * label_1



이상이 됩니다.

참고


  • Github - CutMix-PyTorch
  • arXivTimes - CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
  • arXiv - CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
  • Qiita - 새로운 data augmentation 기법 mixup을 시도했습니다.
  • 좋은 웹페이지 즐겨찾기