PyTorch: 데이터 읽기 1 - Datasets
Datasets란 무엇입니까?
입력 라인에서 데이터를 준비하는 코드는 이렇게 적혀 있다
data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)
datasets.CIFAR10
는 바로 Datasets
자류이고 data
는 이 종류의 실례이다.Datasets를 정의하는 이유
PyTorch
는 도구 함수torch.utils.data.DataLoader
를 제공했다.이 유형을 통해 우리는 데이터를 미니-batch로 바꿀 수 있고 준비mini-batch
할 때 다중 스레드를 병행 처리할 수 있어 데이터 준비의 속도를 가속화할 수 있다.Datasets
는 바로 이 종류를 구축하는 실례적인 매개 변수 중의 하나이다.DataLoader []。
- 유자 껍질 -
Datasets 사용자 정의
프레임
dataset
는 torch.utils.data.Dataset。
내부에서 두 가지 함수를 계승해야 한다. 하나는 __lent__
로 전체 데이터 집합의 크기를 얻는 것이고, 하나는 __getitem__
로 데이터로부터 하나의 데이터 세션을 얻는 것이다item
.import torch.utils.데이터 as data class Custom Dataset(data.Dataset): # 데이터를 상속합니다.Dataset """Custom data.Dataset compatible with data.DataLoader."""
def __init__(self, filename, data_info, oth_params): """Reads source and target sequences from txt files.""" # # # Initialize file path or list of file names. self.file = open(filename,'r') pass # # 또는 외부 데이터 구조에서 데이터info에서 데이터를 읽기self.all_texts = data_info['all_texts'] self.all_labels = data_info['all_labels'] self.vocab = data_info['vocab']
def __getitem__(self, index): """Returns one data pair (source and target)."''# # # 파일에서 # 1을 읽습니다.Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). # 2. Preprocess the data(e.g. torchvision.Transform 또는word2id 등). # 3. Return a data pair(source and target) (e.g. image and label). pass # # 또는 item 직접 읽기info = { "text": self.all_texts[index], "label": self.all_labels[index] } return item_info
def __len__(self): # You should change 0 to the total size of your dataset. # return 0 return len(self.all_texts)
작은 예
class Dataset(torch.utils.data.Dataset): def __init__(self, filepath=None,dataLen=None): self.file = filepath self.dataLen = dataLen def __getitem__(self, index): A,B,path,hop= linecache.getline(self.file, index+1).split('\t') return A,B,path.split(' '),int(hop)
def __len__(self): return self.dataLen
공식 MNIST의 예.
(코드가 축소되어 중요한 부분만 남았다):
class MNIST(data.Dataset):
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
if download:
self.download()
if not self._check_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
if self.train:
self.train_data, self.train_labels = torch.load(
os.path.join(root, self.processed_folder, self.training_file))
else:
self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
def __getitem__(self, index):
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
if self.train:
return 60000
else:
return 10000
from: - 유자 껍질 -
ref: [pytorch 학습 노트(6): 사용자 정의 Datasets]
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
Pytorch에서 shuffle을 사용하여 데이터를 흐트러뜨리는 작업우선 내가 너에게 한 가지 알려줘야 할 것은 바로pytorch의tensor이다. 만약에random을 직접 사용한다면.shuffle는 데이터를 흐트러뜨리거나 아래의 방식을 사용하여 직접 쓰기를 정의합니다. 그러면 혼란...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.