심층 학습 (Pytorch)을 이용한 Kaggle Titanic 연습 PART 4 ​​(Pytorch dataloader 및 dataset)

6926 단어 PyTorchDataset
이 장에서는 Pytorch의 Dataset과 DataLoader에 대해 설명합니다.
이 장은 h tps // 고츠치 얀. 하테나 bぉg. 코m/엔트리/2020/04/21/182937 을 참고로 기술되어 있습니다.
Pytorch에서는 Dataset과 DataLoader를 사용하여 쉽게 미니 배치를 할 수 있습니다.

Dataset 구현



DataSet을 구현할 때는 클래스의 멤버 함수로서 len()과 getitem()을 반드시 만듭니다.

len()은 len()을 사용할 때 호출되는 함수입니다.
getitem()은 array[i]와 같이 [ ]를 사용하여 요소를 참조할 때 호출되는 함수입니다. 이것이 불려 갈 때는, 반드시 뭔가의 index가 지정되고 있으므로, 인수에 index의 정보를 취합니다. 또한 I/O 쌍을 반환하도록 설계합니다.

이상을 바탕으로 Dataset을 작성해 봅시다.
class DataSet:
    def __init__(self):
        self.X = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # 入力
        self.t = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1] # 出力

    def __len__(self):
        return len(self.X) # データ数(10)を返す

    def __getitem__(self, index):
        # index番目の入出力ペアを返す
        return self.X[index], self.t[index]

글쎄, 실제로이 DataSet이 어떻게 작동하는지 시도해 보겠습니다.
dataset = DataSet()
print('全データ数:',len(dataset))  # 全データ数: 10
print('3番目のデータ:',dataset[3]) # 3番目のデータ: (3, 1)
print('5~6番目のデータ:',dataset[5:7]) # 5~6番目のデータ: ([5, 6], [1, 0])

DataLoader 구현



배치 사이즈를 2, 훈련시의 데이터의 셔플을 False로 한 구현은 이하와 같습니다.
# さっき作ったDataSetクラスのインスタンスを作成
dataset = DataSet()
# datasetをDataLoaderの引数とすることでミニバッチを作成.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False)

이제 미니 배치 학습을 할 준비가 되었습니다.
미니 배치용 데이터는 for 문으로 검색할 수 있습니다.
for data in dataloader:
    print(data)

'''
出力:
[tensor([0, 1]), tensor([0, 1])]
[tensor([2, 3]), tensor([0, 1])]
[tensor([4, 5]), tensor([0, 1])]
[tensor([6, 7]), tensor([0, 1])]
[tensor([8, 9]), tensor([0, 1])]
'''

위의 dataloader를 사용하여 10epoch 학습을 하는 경우에는 다음과 같이 쓸 수 있습니다.
epoch = 10
model = #何かしらのモデル
for _ in range(epoch):
    for data in dataloader:
        X = data[0]
        t = data[1]
        y = model(X)
        # lossの計算とか

Pytorch의 Dataset과 Dataloader에 대한 설명은 이상입니다.

좋은 웹페이지 즐겨찾기