[pytorch] 😎 Custom ImageFolder Class!


😁 Hello! Today we'll take a quick look at how to use the ImageFolder class, which comes in very handy when building vision models,
😊 and then use it to build a custom class of our own :)

ImageFolder

😉 ImageFolder is a kind of Dataset class: given just the path to your data, it builds a Dataset object for you. It expects one sub-folder per class under that path and uses each sub-folder name as the label.

https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html?highlight=imagefolder#torchvision.datasets.ImageFolder

๐Ÿถ ๊ทธ๋ฆฌ๊ณ  ์ €๋Š”, ์ง€๋‚œ๋ฒˆ์— ์‚ฌ์šฉํ–ˆ๋˜ Stanford Dog Dataset์„ ํ™œ์šฉํ•˜๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. :)

# ubuntu linux
wget http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar
tar -xvf images.tar

์œ„์˜ ๋ช…๋ น์–ด๋ฅผ ์ˆ˜ํ–‰ํ•˜๋ฉด, images๋ผ๋Š” ํด๋” ์ดํ•˜์— ๊ท€์—ฌ์šด ๊ฐ•์•„์ง€ ์‚ฌ์ง„ ๋ฐ์ดํ„ฐ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.


😋 I checked that the photos of my favorite golden retrievers downloaded correctly!
😃 Now let's put the ImageFolder class to work. :)

import torch
import torchvision
from torchvision import transforms

size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

dog_transform = transforms.Compose([
                transforms.RandomResizedCrop(
                    (size, size), scale=(0.5, 1.0)),  
                transforms.RandomHorizontalFlip(), 
                transforms.ToTensor(),  # convert to a tensor
                transforms.Normalize(mean, std)  # normalize with the ImageNet mean/std above
            ])

dog_dataset = torchvision.datasets.ImageFolder('Images/', transform=dog_transform)

dog_dataset.class_to_idx
>>>{'n02085620-Chihuahua': 0,
 'n02085782-Japanese_spaniel': 1,
 'n02085936-Maltese_dog': 2,
 ...

📝 Checking dataset.class_to_idx shows how the Dataset object maps each class (folder name) to an integer index.
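Sometimes you need to go the other way, for example to turn a predicted index back into a breed name. In that case the mapping can simply be inverted. ImageFolder's `classes` attribute already holds the names in index order. The sketch below works on a hand-copied excerpt of the mapping above, and the `pred` value is made up.

```python
from typing import Dict, List

# Hypothetical excerpt: the first three entries of dog_dataset.class_to_idx above.
class_to_idx: Dict[str, int] = {
    "n02085620-Chihuahua": 0,
    "n02085782-Japanese_spaniel": 1,
    "n02085936-Maltese_dog": 2,
}

# index -> folder name, e.g. to decode a model's predicted class index
idx_to_class: Dict[int, str] = {idx: name for name, idx in class_to_idx.items()}

# The same information as a list ordered by index (what ImageFolder exposes as `classes`)
classes: List[str] = sorted(class_to_idx, key=class_to_idx.__getitem__)

pred = 2  # hypothetical predicted index
folder = idx_to_class[pred]
breed = folder.split("-", 1)[1]  # drop the WordNet ID prefix
print(folder, breed)  # n02085936-Maltese_dog Maltese_dog
```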

data_loader = torch.utils.data.DataLoader(dog_dataset,
                                          batch_size=16,
                                          shuffle=True,
                                          num_workers=2)

images, labels = next(iter(data_loader))  # fetch one batch
images.shape, labels.shape
>>> (torch.Size([16, 3, 224, 224]), torch.Size([16]))

✨ And with the DataLoader class, we've confirmed that batches come out as expected :)
😜 Pretty easy, right?

Why Custom Class?

😒 You might be thinking, "Wait, why build a custom class at all? It already works fine!"
😀 I thought so too, but there was one blind spot in using the class above!

"How class_to_idx is generated": that was it!

🤔 The ImageFolder class is simple and powerful, but class_to_idx assigns indices in "alphabetical order": the sub-folder names are passed through Python's sorted(), and the i-th name gets index i.
✔ For three labels apple / banana / cider, you get {"apple" : 0, "banana" : 1, "cider" : 2}.
😃 Alphabetical order is a perfectly common rule, of course. In real projects, though, I ran into cases where the model had to be trained with a class_to_idx that does not follow alphabetical label order.
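The rule is easy to reproduce with the standard library. The sketch below copies the core of torchvision's find_classes: sort the sub-folder names, then enumerate them. `default_class_to_idx` is a hypothetical name. Note that sorted() compares strings by Unicode code point, so a capitalized folder name lands before every lowercase one.

```python
import os
import tempfile
from typing import Dict, List, Tuple


def default_class_to_idx(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Same rule as torchvision's find_classes: sorted sub-folder names get 0, 1, 2, ..."""
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    return classes, {name: idx for idx, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    # Create the folders in a deliberately scrambled order.
    for name in ["cider", "apple", "banana", "Zucchini"]:
        os.makedirs(os.path.join(root, name))
    _, class_to_idx = default_class_to_idx(root)
    print(class_to_idx)  # {'Zucchini': 0, 'apple': 1, 'banana': 2, 'cider': 3}
```

The order in which folders were created on disk does not matter. Only the names do.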

Providing a new class_to_idx

😋 For testing, I created a new class_to_idx object :)

import os

label_list = os.listdir('Images/')

custom_class_to_idx = {label : idx for idx, label in enumerate(label_list)}

# a dict whose indices are NOT assigned in alphabetical order
custom_class_to_idx
>>> {'n02111500-Great_Pyrenees': 0,
 'n02111889-Samoyed': 32,
 'n02112018-Pomeranian': 1,
 'n02112137-chow': 97,
 ...
 

😎 We'll use this mapping to build our own Dataset class :)
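One caveat about label_list: os.listdir returns entries in an arbitrary, file-system-dependent order, and it also includes plain files. The same code can therefore produce a different mapping on another machine. If the custom order must be reproducible, it is safer to keep it in a file under version control and check it against the folders on disk. The sketch below does this; `build_class_to_idx`, `read_class_list`, and the file `classes.txt` are hypothetical.

```python
import os
from typing import Dict, List


def build_class_to_idx(root: str, class_list: List[str]) -> Dict[str, int]:
    """Index classes in the order given by `class_list` (hypothetical helper).

    Raises ValueError if the list has duplicates or does not match the
    sub-folders of `root` exactly.
    """
    if len(set(class_list)) != len(class_list):
        raise ValueError("class_list contains duplicate names")
    with os.scandir(root) as it:
        on_disk = {entry.name for entry in it if entry.is_dir()}
    missing = sorted(set(class_list) - on_disk)
    unlisted = sorted(on_disk - set(class_list))
    if missing or unlisted:
        raise ValueError(f"missing folders: {missing}; folders not in list: {unlisted}")
    return {name: idx for idx, name in enumerate(class_list)}


def read_class_list(path: str) -> List[str]:
    """Read one folder name per line, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]
```

Usage would look like `custom_class_to_idx = build_class_to_idx('Images/', read_class_list('classes.txt'))`.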

Analyzing the source of the ImageFolder & DatasetFolder classes

https://pytorch.org/vision/stable/_modules/torchvision/datasets/folder.html#ImageFolder

✨ The link above points to ImageFolder's source code.
😎 ImageFolder adds very little logic of its own: it simply inherits and reuses the logic of the DatasetFolder class!

class ImageFolder(DatasetFolder):
    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples

😁 Then let's take a look at the source code of DatasetFolder as well.
🤣 Whoa! It inherits from yet another class, VisionDataset! Do we have to analyze that one too?
😎 No need. All we're after is changing the 'class_to_idx' attribute.

class DatasetFolder(VisionDataset):
    """
    (docstring omitted)
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``dir`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)       

✨ Aha, self.class_to_idx is determined by the self.find_classes method.
😋 So modifying just the find_classes method should get us what we want!

😊 And the module-level find_classes function that the find_classes method calls looks like this :)

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

👌 OK! It returns the classes and class_to_idx that the Dataset class will use.
(classes is a list of the class names in index order.)

Writing the Custom ImageFolder & DatasetFolder classes!

# from https://pytorch.org/vision/0.11/_modules/torchvision/datasets/folder.html
################################################################################
################################################################################
# copied from folder.py

import os
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import find_classes  # only used by make_dataset when class_to_idx is None

def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    return filename.lower().endswith(extensions)

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                if is_valid_file(fname):
                    path = os.path.join(root, fname)
                    item = path, class_index
                    instances.append(item)

                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances

# copied from folder.py END

################################################################################
################################################################################


class CustomDatasetFolder(VisionDataset):
    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        class_list: List[str],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self.find_classes(class_list)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:

        if class_to_idx is None:
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def find_classes(self, class_list: List[str]) -> Tuple[List[str], Dict[str, int]]:
        return class_list, {label : idx for idx, label in enumerate(class_list)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)


class CustomImageFolder(CustomDatasetFolder):
    def __init__(
        self,
        root: str,
        class_list: List[str],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            class_list,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples

🤣 Phew... it got rather long as I wrote it ;<

custom_dog_dataset = CustomImageFolder('Images/', label_list, transform=dog_transform)

custom_dog_dataset.class_to_idx
>>> {'n02111500-Great_Pyrenees': 0,
 'n02111889-Samoyed': 32,
 'n02112018-Pomeranian': 1,
 'n02112137-chow': 97,
 ...

✌ And with that, we have a Dataset class built on exactly the class_to_idx dict I wanted!
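One check is worth doing: confirm that the new mapping actually reached the labels, not just the class_to_idx attribute. To do that, compare every (path, target) pair in samples with the top-level folder the file sits in. The sketch below does this with a hypothetical helper, `check_targets`.

```python
import os
from typing import Dict, List, Tuple


def check_targets(root: str, samples: List[Tuple[str, int]],
                  class_to_idx: Dict[str, int]) -> int:
    """Check each sample's label against the top-level folder under `root`.

    Returns the number of samples checked; raises ValueError on the first mismatch.
    """
    for path, target in samples:
        # First path component below root is the class folder, even for nested files.
        folder = os.path.relpath(path, root).split(os.sep)[0]
        expected = class_to_idx[folder]
        if target != expected:
            raise ValueError(f"{path}: target {target}, expected {expected} ({folder})")
    return len(samples)
```

For example, call `check_targets('Images/', custom_dog_dataset.samples, custom_dog_dataset.class_to_idx)`. The samples themselves are still listed in folder-name order, because make_dataset iterates over sorted(class_to_idx.keys()). Only the label values change.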

Wrapping up

😂 Looking back, today's post leaves me a little unsatisfied.
🤦‍♂️ I figured I could build a custom class quickly with plain inheritance and overriding..! The concept was simple, but the code ended up quite long.
🤔 There may be gaps, too, so I think a second pass is in order: dig further into the source and see whether there is a simpler, cleaner way to do this.
😋 Well... I got the result I wanted, so it worked out anyway, right?! (Ha. Ha. Ha.)
😘 Thanks for reading, and see you next time!

์ข‹์€ ์›นํŽ˜์ด์ง€ ์ฆ๊ฒจ์ฐพ๊ธฐ