PyTorch Dataset에서 인덱스 가져오기
DatasetWithIndex 클래스의 구현
필자가 조사한 바에 의하면 표준 PyTorch의 기능만으로는 실현할 수 없기 때문에 다음과 같은 데이터 집합의 포장을 실현할 것이다.
DatasetWithIndex.py
class DatasetWithIndex:
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, index):
data, label = self.dataset[index]
return data, label, index
def __len__(self):
return len(self.dataset)
@property
def classes(self):
return self.dataset.classes
코드 다운로드
이 글의 코드는 아래 페이지(Giithub)에서 다운로드할 수 있습니다. 또한, MIT 허가증에 공개됩니다. 자유롭게 수정, 공개 등.
아래에 설명된 데모 실행 명령은 다음과 같습니다. 데모를 위해 현재 디렉토리에 MNIST가 다운로드됩니다.
python DatasetWithIndex.py
파이톤 3.8.5, 다음 프로그램 라이브러리에서 동작 확인을 합니다.torch 1.7.1
torchvision 0.8.2
데모 1
다음과 같은 방법으로 DatasetWithIndex 클래스의 인스턴스를 매개변수로 생성합니다.
그리고 일반 Dataset과 같이 사용할 수 있습니다.
DatasetWithOriginIndex.py
from torch.utils.data import DataLoader
from torchvision import transforms as tt
from torchvision.datasets import MNIST
dataset = MNIST(root='./', train=True, download=True,
transform=tt.Compose([tt.ToTensor()]))
dataset_with_index = DatasetWithIndex(dataset) # ★データセットをラップしている
data_loader = DataLoader(dataset_with_index, batch_size=4, shuffle=True)
# デモンストレーション1
## 一部データを取得し,あとで取得したインデックスで同じデータにアクセスできるか調べる.
input_list, label_list, index_list = [], [], []
for i, data in enumerate(data_loader):
inputs, labels, indices = data
input_list.extend(inputs)
label_list.extend(labels)
index_list.extend(indices)
if i >= 3:
break
for input, label, index in zip(input_list, label_list, index_list):
data = dataset_with_index[index]
# indexの辻褄があっているかを確認
assert (input == data[0]).all()
assert data[1] == label
print("label1 = {}, label2= {}".format(data[1], label))
print("len(dataset_with_index) = {}".format(len(dataset_with_index)))
print("dataset_with_index.classes = {}".format(dataset_with_index.classes))
실행 결과는 다음과 같습니다: 첫 번째 순환에서 MNIST의 이미지, 탭, 인덱스를 얻었습니다. 두 번째 순환에서 첫 번째 순환에서 얻은 인덱스를 사용하여 데이터 set에 접근하여 같은 이미지와 탭을 얻을 수 있는지 확인하십시오. (코드 아래 4, 5줄의 assert)label1 = 8, label2 = 8
label1 = 9, label2 = 9
label1 = 5, label2 = 5
label1 = 4, label2 = 4
label1 = 6, label2 = 6
label1 = 6, label2 = 6
label1 = 9, label2 = 9
label1 = 2, label2 = 2
label1 = 2, label2 = 2
label1 = 8, label2 = 8
label1 = 1, label2 = 1
label1 = 1, label2 = 1
label1 = 6, label2 = 6
label1 = 6, label2 = 6
label1 = 8, label2 = 8
label1 = 3, label2 = 3
len(dataset_with_index) = 60000
dataset_with_index.classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
보기 위한 태그만 표시되며 표시 결과는 무작위 수에 따라 다릅니다.프레젠테이션 2: Subset과의 조합
Subset과 결합할 수도 있습니다.
랩 순서에 따라 Subset에서 색인을 가져오거나 원래 색인을 가져오는 것은 다릅니다.
DatasetWithOriginIndex.py
# デモンストレーション2: Subset
from torch.utils.data import Subset
## Subset上のインデックスを取得する
## SubsetをDatasetWithIndexでラップする
subset1 = Subset(dataset, indices=[2, 1, 3, 5, 4])
subset_with_index = DatasetWithIndex(subset1)
print('index on a subset = {}'.format(subset_with_index[0][2]))
## 元のデータセットのインデックスを取得する
## DatasetWithIndexをSubsetでラップする
subset_with_raw_index = Subset(dataset_with_index, [2, 1, 3, 5, 4])
print('index on a raw dataset = {}'.format(subset_with_raw_index[0][2]))
실행 결과는 다음과 같습니다. 두 개 모두 같은 데이터, 탭에 접근했지만, 되돌아오는 색인이 다르기 때문에 좀 번거로울 수 있습니다.index on a subset = 0
label = 4
index on a raw dataset = 2
label = 4
끝말
데이터와 라벨뿐만 아니라 인덱스도 얻을 수 있는 Dataset의 포장사인 DatasetWithIndex류의 실현에 대해 소개했다.
나는 원시 알고리즘을 실현할 때 사용할 수 있다고 생각한다.
Reference
이 문제에 관하여(PyTorch Dataset에서 인덱스 가져오기), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://zenn.dev/hidetoshi/articles/20210619_pytorch-dataset-with-index텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)