transforms, dataset, dataloader

7592 단어 딥러닝딥러닝

transforms, dataset, dataloader 3가지를 예전부터 사용하면서 한번 정리해야지 했는데, 요즘 코드 리뷰를 작성하면서 그 필요성이 더욱 와닿았다. 3가지 모두 데이터셋이 모델의 입력으로 들어갈 때, 데이터를 읽어오고 전처리하고 증가시키는 과정에서 사용하는 모듈이다.

transforms

transforms는 기존 이미지의 다양한 변환 기능을 제공하는 역할을 한다. 보통 이미지의 크기를 바꾸거나(ResizedCrop), 반전(Flip) 또는 중심을 기준으로 크롭(CenterCrop)하여 데이터의 양을 늘리는 data augmentation에 많이 사용된다. 다음은 관련 예제이다.

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(256),	# (256x256)
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

먼저 transforms.compose는 transforms에 속해있는 함수들을 한번에 사용할 수 있게 한다.
그 후에 앞에서 소개한 여러가지 방법으로 데이터의 수를 증가시킨다.

또한 ToTensor와 Normalizer를 사용하여 PIL 이미지를 tensor 형식으로 바꾸고, 0~255 범위를 갖는 (h, w, c) 형식을 0~1 범위의 (c, h, w) 형식으로 변환한다. 그리고 주어진 평균, 표준편차 값으로 데이터를 정규화한다. 위의 예제에서는 데이터가 갖는 채널이 rgb 3개이므로 각각의 값을 입력으로 주었음을 알 수 있다.

dataset

dataset은 무엇을 데이터셋으로 주는지 결정하는 역할을 한다. 즉, 전체 데이터셋을 구성하는 역할이다. 이를 위해 직접 dataset 함수를 구현할 때가 많은데, 이때 꼭 필요한 것이 __init__, __len__, __getitem__ 이다. 또한 torchvision의 datasets은 mnist나 cifar10 등 원하는 데이터를 다운받거나, ImageFolder를 이용하여 원하는 형태의 데이터셋을 만들 수 있다.

image_data = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}

위의 transforms 예시에 이어서 torchvision의 datasets.ImageFolder를 설명하면, 데이터셋을 transforms을 거쳐서 나온 형태로 만들고자 할 때, 기존 데이터의 경로와 사용할 transforms을 입력에 넣어주면 원하는 dataset의 형태로 만들어져 사용할 수 있다.

dataloader

이제 완성된 데이터셋을 dataLoader에 넘겨주기만 하면 된다. 전체 데이터셋을 바로 사용하지 않는 이유는, 엄청난 양의 데이터를 한번에 불러오면 메모리와 ram이 이를 감당하지 못하기 때문이다. 그래서 dataloader를 이용하여 데이터를 적절한 크기로 분배하여 모델에 제공한다.

dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], 
                                   batch_size=4, 
                                   shuffle=True, 
                                   num_workers=4) for x in ['train', 'val']
    }

위의 코드를 보면 dataloader에서 정해줘야 하는 주요한 값은 데이터셋과 batch size이다. batch는 한번에 제공되는 데이터의 수를 의미한다. 만약 전체 데이터가 100개이고 batch size가 4 일때, 반복을 뜻하는 iteration은 100/4 = 25가 된다. iteration은 mini batch의 수이기도 하다. 이를 그림으로 표현하면 아래와 같다.

좋은 웹페이지 즐겨찾기