PyTorch Dataset에서 인덱스 가져오기

PyTorch의 딥러닝 코드에는 Dataset에서 데이터와 태그를 얻는 방법은 물론 색인 획득 방법도 기재돼 있다.

DatasetWithIndex 클래스의 구현


필자가 조사한 바에 의하면 표준 PyTorch의 기능만으로는 실현할 수 없기 때문에 다음과 같은 데이터 집합의 포장을 실현할 것이다.
DatasetWithIndex.py
class DatasetWithIndex:
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        data, label = self.dataset[index]
        return data, label, index

    def __len__(self):
        return len(self.dataset)

    @property
    def classes(self):
        return self.dataset.classes

코드 다운로드


이 글의 코드는 아래 페이지(Giithub)에서 다운로드할 수 있습니다. 또한, MIT 허가증에 공개됩니다. 자유롭게 수정, 공개 등.
https://github.com/HidetoshiKawaguchi/tech-blog-codes/tree/main/20210619_pytorch-dataset-with-index
아래에 설명된 데모 실행 명령은 다음과 같습니다. 데모를 위해 현재 디렉토리에 MNIST가 다운로드됩니다.
python DatasetWithIndex.py
파이톤 3.8.5, 다음 프로그램 라이브러리에서 동작 확인을 합니다.
torch             1.7.1
torchvision       0.8.2

데모 1


다음과 같은 방법으로 DatasetWithIndex 클래스의 인스턴스를 매개변수로 생성합니다.
그리고 일반 Dataset과 같이 사용할 수 있습니다.
DatasetWithOriginIndex.py
from torch.utils.data import DataLoader
from torchvision import transforms as tt
from torchvision.datasets import MNIST

dataset = MNIST(root='./', train=True, download=True,
                transform=tt.Compose([tt.ToTensor()]))
dataset_with_index = DatasetWithIndex(dataset) # ★データセットをラップしている
data_loader = DataLoader(dataset_with_index, batch_size=4, shuffle=True)

# デモンストレーション1
## 一部データを取得し,あとで取得したインデックスで同じデータにアクセスできるか調べる.
input_list, label_list, index_list = [], [], []
for i, data in enumerate(data_loader):
    inputs, labels, indices = data
    input_list.extend(inputs)
    label_list.extend(labels)
    index_list.extend(indices)
    if i >= 3:
        break
for input, label, index in zip(input_list, label_list, index_list):
    data = dataset_with_index[index]
    # indexの辻褄があっているかを確認
    assert (input == data[0]).all()
    assert data[1] == label
    print("label1 = {}, label2= {}".format(data[1], label))
print("len(dataset_with_index) = {}".format(len(dataset_with_index)))
print("dataset_with_index.classes = {}".format(dataset_with_index.classes))
실행 결과는 다음과 같습니다: 첫 번째 순환에서 MNIST의 이미지, 탭, 인덱스를 얻었습니다. 두 번째 순환에서 첫 번째 순환에서 얻은 인덱스를 사용하여 데이터 set에 접근하여 같은 이미지와 탭을 얻을 수 있는지 확인하십시오. (코드 아래 4, 5줄의 assert)
label1 = 8, label2 = 8
label1 = 9, label2 = 9
label1 = 5, label2 = 5
label1 = 4, label2 = 4
label1 = 6, label2 = 6
label1 = 6, label2 = 6
label1 = 9, label2 = 9
label1 = 2, label2 = 2
label1 = 2, label2 = 2
label1 = 8, label2 = 8
label1 = 1, label2 = 1
label1 = 1, label2 = 1
label1 = 6, label2 = 6
label1 = 6, label2 = 6
label1 = 8, label2 = 8
label1 = 3, label2 = 3
len(dataset_with_index) = 60000
dataset_with_index.classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
보기 위한 태그만 표시되며 표시 결과는 무작위 수에 따라 다릅니다.

프레젠테이션 2: Subset과의 조합


Subset과 결합할 수도 있습니다.
랩 순서에 따라 Subset에서 색인을 가져오거나 원래 색인을 가져오는 것은 다릅니다.
DatasetWithOriginIndex.py
# デモンストレーション2: Subset
from torch.utils.data import Subset

## Subset上のインデックスを取得する
## SubsetをDatasetWithIndexでラップする
subset1 = Subset(dataset, indices=[2, 1, 3, 5, 4])
subset_with_index = DatasetWithIndex(subset1)
print('index on a subset = {}'.format(subset_with_index[0][2]))

## 元のデータセットのインデックスを取得する
## DatasetWithIndexをSubsetでラップする
subset_with_raw_index = Subset(dataset_with_index, [2, 1, 3, 5, 4])
print('index on a raw dataset = {}'.format(subset_with_raw_index[0][2]))
실행 결과는 다음과 같습니다. 두 개 모두 같은 데이터, 탭에 접근했지만, 되돌아오는 색인이 다르기 때문에 좀 번거로울 수 있습니다.
index on a subset = 0
label = 4
index on a raw dataset = 2
label = 4

끝말


데이터와 라벨뿐만 아니라 인덱스도 얻을 수 있는 Dataset의 포장사인 DatasetWithIndex류의 실현에 대해 소개했다.
나는 원시 알고리즘을 실현할 때 사용할 수 있다고 생각한다.

좋은 웹페이지 즐겨찾기