PyTorch: 데이터 읽기 1 - Datasets

- 유자 껍질 -

Datasets란 무엇입니까?


입력 라인에서 데이터를 준비하는 코드는 이렇게 적혀 있다data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True) datasets.CIFAR10는 바로 Datasets자류이고 data는 이 종류의 실례이다.

Datasets를 정의하는 이유

PyTorch는 도구 함수torch.utils.data.DataLoader를 제공했다.이 유형을 통해 우리는 데이터를 미니-batch로 바꿀 수 있고 준비mini-batch할 때 다중 스레드를 병행 처리할 수 있어 데이터 준비의 속도를 가속화할 수 있다.Datasets는 바로 이 종류를 구축하는 실례적인 매개 변수 중의 하나이다.DataLoader []。
- 유자 껍질 -
 

Datasets 사용자 정의


프레임

datasettorch.utils.data.Dataset。 내부에서 두 가지 함수를 계승해야 한다. 하나는 __lent__로 전체 데이터 집합의 크기를 얻는 것이고, 하나는 __getitem__로 데이터로부터 하나의 데이터 세션을 얻는 것이다item.
import torch.utils.데이터 as data class Custom Dataset(data.Dataset): # 데이터를 상속합니다.Dataset     """Custom data.Dataset compatible with data.DataLoader."""
    def __init__(self, filename, data_info, oth_params):         """Reads source and target sequences from txt files."""        # # # Initialize file path or list of file names.         self.file = open(filename,'r') pass # # 또는 외부 데이터 구조에서 데이터info에서 데이터를 읽기self.all_texts = data_info['all_texts']         self.all_labels = data_info['all_labels']         self.vocab = data_info['vocab']
    def __getitem__(self, index):         """Returns one data pair (source and target)."''# # # 파일에서 # 1을 읽습니다.Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).         # 2. Preprocess the data(e.g. torchvision.Transform 또는word2id 등).        # 3. Return a data pair(source and target) (e.g. image and label). pass # # 또는 item 직접 읽기info = {             "text": self.all_texts[index],             "label": self.all_labels[index]         }         return item_info
    def __len__(self):         # You should change 0 to the total size of your dataset.         # return 0         return len(self.all_texts)

작은 예


class Dataset(torch.utils.data.Dataset):     def __init__(self, filepath=None,dataLen=None):         self.file = filepath         self.dataLen = dataLen              def __getitem__(self, index):         A,B,path,hop= linecache.getline(self.file, index+1).split('\t')         return A,B,path.split(' '),int(hop)
    def __len__(self):         return self.dataLen

공식 MNIST의 예.


(코드가 축소되어 중요한 부분만 남았다):
class MNIST(data.Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return 60000
        else:
            return 10000

from: - 유자 껍질 -
ref: [pytorch 학습 노트(6): 사용자 정의 Datasets]

좋은 웹페이지 즐겨찾기