[pytorch] Custom ImageFolder Class!
Hello! Today I'll briefly go over how to use the ImageFolder class, which comes in very handy when building vision models, and then use it to build a custom class of my own :)
ImageFolder
ImageFolder is a kind of Dataset class. As long as the images are arranged in one subfolder per class, you only need to give it the data path and it builds a Dataset object for you.
For the data, I'll reuse the Stanford Dogs Dataset from my previous post. :)
```bash
# ubuntu linux
wget http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar
tar -xvf images.tar
```
Running the commands above gives you adorable dog photos in a folder named `Images`, with one subfolder per breed.
I checked that the photos of golden retrievers, my favorite, downloaded fine!
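Before handing the folder to ImageFolder, it can be worth checking what the extraction actually produced. Below is a small standard-library sketch. `count_images_per_class` is a hypothetical helper of mine, not a torchvision API, and the extension tuple is assumed to match the one torchvision 0.11 uses for ImageFolder.

```python
import os
from typing import Dict, Tuple

# Assumption: mirrors torchvision 0.11's IMG_EXTENSIONS (lowercase).
IMAGE_EXTS: Tuple[str, ...] = (
    ".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp",
)


def count_images_per_class(root: str, exts: Tuple[str, ...] = IMAGE_EXTS) -> Dict[str, int]:
    """Return {class folder name: number of image files inside it} for each subfolder of root."""
    with os.scandir(root) as it:
        # Only directories count as classes; stray files (e.g. the .tar) are skipped.
        class_dirs = sorted((e for e in it if e.is_dir()), key=lambda e: e.name)
    counts: Dict[str, int] = {}
    for entry in class_dirs:
        n = 0
        for _dirpath, _dirnames, filenames in os.walk(entry.path):
            # Case-insensitive extension check, like torchvision's has_file_allowed_extension
            n += sum(1 for f in filenames if f.lower().endswith(exts))
        counts[entry.name] = n
    return counts


# Usage, assuming the archive was extracted into ./Images:
# counts = count_images_per_class("Images")
# print(len(counts), "classes,", sum(counts.values()), "images")
```

A class reporting 0 here would later make ImageFolder raise a `FileNotFoundError`, so this is a cheap way to catch a broken extraction early.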
Now let's put the ImageFolder class to use. :)
```python
import torch
import torchvision
from torchvision import transforms

size = 224
mean = (0.485, 0.456, 0.406)  # ImageNet channel means
std = (0.229, 0.224, 0.225)   # ImageNet channel standard deviations

dog_transform = transforms.Compose([
    transforms.RandomResizedCrop((size, size), scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),            # convert to a tensor
    transforms.Normalize(mean, std),  # normalize
])

dog_dataset = torchvision.datasets.ImageFolder('Images/', transform=dog_transform)
dog_dataset.class_to_idx
# {'n02085620-Chihuahua': 0,
#  'n02085782-Japanese_spaniel': 1,
#  'n02085936-Maltese_dog': 2,
#  ...
```
Checking `dataset.class_to_idx` shows how the Dataset object maps each class to an index.
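`class_to_idx` goes from folder name to index. To turn a model's integer predictions back into breed names you need the reverse direction. A minimal sketch follows, using a hand-copied subset of the mapping printed above. `invert_class_to_idx` is a hypothetical helper. An ImageFolder dataset also exposes a `classes` list in index order, so `dog_dataset.classes[i]` should give the same name.

```python
from typing import Dict


def invert_class_to_idx(class_to_idx: Dict[str, int]) -> Dict[int, str]:
    """Build the reverse mapping so integer predictions can be turned back into folder names."""
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    # Two names sharing one index would silently collapse in the dict above.
    if len(idx_to_class) != len(class_to_idx):
        raise ValueError("class_to_idx maps two classes to the same index")
    return idx_to_class


# A hand-copied subset of the mapping printed above
class_to_idx = {
    "n02085620-Chihuahua": 0,
    "n02085782-Japanese_spaniel": 1,
    "n02085936-Maltese_dog": 2,
}
idx_to_class = invert_class_to_idx(class_to_idx)
predicted = [2, 0]  # pretend model outputs
print([idx_to_class[i] for i in predicted])
# ['n02085936-Maltese_dog', 'n02085620-Chihuahua']
```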
```python
data_loader = torch.utils.data.DataLoader(dog_dataset,
                                          batch_size=16,
                                          shuffle=True,
                                          num_workers=2)

images, labels = next(iter(data_loader))  # pull a single batch
images.shape, labels.shape
# (torch.Size([16, 3, 224, 224]), torch.Size([16]))
```
✨ Using the DataLoader class, I've also confirmed that batches come out correctly :)
Pretty easy, right?
Why Custom Class?
"Wait, why build a custom class at all? It already works fine!" If that's what you're thinking, I thought the same at first.
But this class turned out to have one blind spot:
the rule by which `class_to_idx` is built!
ImageFolder is simple and powerful, but its `class_to_idx` assigns indices in sorted order of the class folder names, which is essentially alphabetical order.
For three labels apple / banana / cider, for example, you get `{"apple": 0, "banana": 1, "cider": 2}`.
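Strictly speaking, the order comes from Python's `sorted()` on the folder names (see the `find_classes` source later in this post), which compares strings by Unicode code point. The sketch below imitates that rule with a hypothetical `default_class_to_idx` helper. It shows two things: the apple / banana / cider case, and that uppercase names sort before lowercase ones while "class10" sorts before "class2". For Stanford Dogs the folder names start with WordNet IDs such as `n02085620`, so the "alphabetical" order is really the order of those IDs.

```python
from typing import Dict, Iterable


def default_class_to_idx(folder_names: Iterable[str]) -> Dict[str, int]:
    """Imitate the ordering rule of torchvision's find_classes: sort the names, then enumerate."""
    classes = sorted(folder_names)
    return {name: idx for idx, name in enumerate(classes)}


print(default_class_to_idx(["cider", "apple", "banana"]))
# {'apple': 0, 'banana': 1, 'cider': 2}

# sorted() compares code points, so this is not "alphabetical" in the everyday sense:
print(default_class_to_idx(["class10", "class2", "Zebra", "apple"]))
# {'Zebra': 0, 'apple': 1, 'class10': 2, 'class2': 3}
```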
Alphabetical order is a fairly universal rule, of course, but at work I ran into cases where a model had to be trained with a `class_to_idx` that did not follow the alphabetical order of the labels.
Providing a new class_to_idx
For testing, I created a new class_to_idx object :)
```python
import os

label_list = os.listdir('Images/')
# a dict whose indices follow os.listdir() order rather than alphabetical order
custom_class_to_idx = {label: idx for idx, label in enumerate(label_list)}
custom_class_to_idx
# {'n02111500-Great_Pyrenees': 0,
#  'n02111889-Samoyed': 32,
#  'n02112018-Pomeranian': 1,
#  'n02112137-chow': 97,
#  ...
```
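One caveat: `os.listdir` returns entries in arbitrary order (the Python docs say so explicitly), so the mapping above can differ between machines or file systems. It may also pick up stray files that are not class folders. If the goal is a fixed, non-alphabetical order, one option is to take the order as an explicit list and check it against the folders on disk. A minimal sketch; `build_class_to_idx` is a hypothetical helper, not part of torchvision.

```python
import os
from typing import Dict, Sequence


def build_class_to_idx(root: str, class_order: Sequence[str]) -> Dict[str, int]:
    """Map each name in class_order to its position, after checking the folders exist under root."""
    with os.scandir(root) as it:
        on_disk = {entry.name for entry in it if entry.is_dir()}
    missing = [name for name in class_order if name not in on_disk]
    if missing:
        raise FileNotFoundError(f"No class folder under {root!r} for: {missing}")
    if len(set(class_order)) != len(class_order):
        raise ValueError("class_order contains duplicate names")
    # Index = position in the explicit list, independent of file-system order
    return {name: idx for idx, name in enumerate(class_order)}


# Usage (the order list is hypothetical; keep it in a file or config so every run agrees):
# custom_class_to_idx = build_class_to_idx("Images/", my_class_order)
```

Folders that are on disk but absent from `class_order` get no index. torchvision's `make_dataset` only visits the keys of `class_to_idx`, so such classes would simply be left out of the dataset.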
I'll build a Dataset class around this mapping :)
Analyzing the source of the ImageFolder & DatasetFolder classes
https://pytorch.org/vision/stable/_modules/torchvision/datasets/folder.html#ImageFolder
✨ The link above is the source code of ImageFolder.
ImageFolder contains very little code of its own. It simply inherits the logic of the DatasetFolder class!
```python
class ImageFolder(DatasetFolder):
    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples
```
So let's take a look at the DatasetFolder source as well.
And it inherits from yet another class, VisionDataset! Do we need to analyze that one too?
No need: all we're after is changing the `class_to_idx` attribute.
```python
class DatasetFolder(VisionDataset):
    """
    (docstring omitted)
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   ├── ...
            │   └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                ├── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``dir`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)
```
✨ Aha: `self.class_to_idx` is determined by the `self.find_classes` method.
So overriding `find_classes` alone should get us what we want!
For reference, the module-level `find_classes` function that this method calls looks like this :)
```python
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
```
OK! It returns the `classes` and `class_to_idx` that the Dataset class uses.
(`classes` is a list of the class names in index order.)
Writing custom ImageFolder & DatasetFolder classes!
```python
# from https://pytorch.org/vision/0.11/_modules/torchvision/datasets/folder.html
################################################################################
################################################################################
# copied from folder.py
import os
from typing import Any, Callable, cast, Dict, List, Optional, Tuple

from PIL import Image
from torchvision.datasets import VisionDataset
# module-level helper that make_dataset() below falls back to when class_to_idx is None
from torchvision.datasets.folder import find_classes


def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    return filename.lower().endswith(extensions)


IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                if is_valid_file(fname):
                    path = os.path.join(root, fname)
                    item = path, class_index
                    instances.append(item)

                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
# copied from folder.py END
################################################################################
################################################################################


class CustomDatasetFolder(VisionDataset):
    """DatasetFolder variant whose class_to_idx follows the order of class_list."""

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        class_list: List[str],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        # The key change: classes come from class_list instead of sorted folder names.
        classes, class_to_idx = self.find_classes(class_list)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        if class_to_idx is None:
            raise ValueError("The class_to_idx parameter cannot be None.")
        # Delegates to the module-level make_dataset copied above.
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def find_classes(self, class_list: List[str]) -> Tuple[List[str], Dict[str, int]]:
        # Index = position of the label in class_list
        return class_list, {label: idx for idx, label in enumerate(class_list)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)


class CustomImageFolder(CustomDatasetFolder):
    def __init__(
        self,
        root: str,
        class_list: List[str],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            class_list,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples
```
Phew... it got rather long as I wrote it ;<
```python
custom_dog_dataset = CustomImageFolder('Images/', label_list, transform=dog_transform)
custom_dog_dataset.class_to_idx
# {'n02111500-Great_Pyrenees': 0,
#  'n02111889-Samoyed': 32,
#  'n02112018-Pomeranian': 1,
#  'n02112137-chow': 97,
#  ...
```
With that, we have a Dataset class built on exactly the class_to_idx dict I wanted!
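To double-check that result beyond eyeballing the printed dict, one option is to confirm that every `(path, target)` pair in `samples` agrees with `class_to_idx`. A minimal sketch, assuming a `root/<class folder>/...` layout like Stanford Dogs. `check_targets` is a hypothetical helper, not part of torchvision.

```python
import os
from typing import Dict, List, Tuple


def check_targets(root: str, samples: List[Tuple[str, int]], class_to_idx: Dict[str, int]) -> int:
    """Raise ValueError if any sample's target disagrees with class_to_idx; return the number checked."""
    for path, target in samples:
        # The first path component below root is the class folder name.
        class_name = os.path.relpath(path, root).split(os.sep)[0]
        expected = class_to_idx[class_name]
        if target != expected:
            raise ValueError(f"{path}: target {target}, but class_to_idx[{class_name!r}] is {expected}")
    return len(samples)


# Usage with the dataset built above:
# check_targets("Images/", custom_dog_dataset.samples, custom_dog_dataset.class_to_idx)
```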
Wrapping up
Today's post left me a little unsatisfied.
I figured that simple inheritance and overriding would give me a custom class easily, but although the concept was simple, the code ended up quite long.
There may be things I missed, so I think I need a second pass: dig further into the code and see whether there is a cleaner, more concise way to do this.
Well... I still got the result I wanted, so it worked out anyway, right?! (Ha. Ha. Ha.)
Thanks for reading, and see you next time!
Author And Source
For more on this topic ([pytorch] Custom ImageFolder Class!), see the original post: https://velog.io/@gtpgg1013/pytorch-Custom-ImageFolder-Class. The author information is included in the original URL, and the copyright belongs to the original author.