[PyTorch] Lab04.2 - Loading Data

10200 단어 PyTorchPyTorch

📌 학습 목표


  • Minibatch Gardient Desecent
  • PyTorch Dataset & DataLoader

Minibatch Gardient Desecent

실제 세계의 데이터의 양은 방대하고, 학습시킬 때 여러 정보들을 이용한다.

하지만, 이 방대한 양의 데이터들을 한번에 학습시킬 수 없다.

그래서 데이터를 균일하게 나눠 학습하는 Minibatch를 이용한다.

위 그림과 같이 Minibatch로 학습하게 되면 Gradient Descent를 각 Minibatch가 끝날 때마다 수행한다.

Minibatch마다 Cost 양이 전체 데이터에 비해 적기 때문에 업데이트가 빨리 이루어진다.

하지만, 데이터가 분할되어 잘못된 방향으로 학습될 수 있다는 점을 주의해야한다.

PyTorch Dataset & DataLoader

Dataset을 Minibatch로 쪼개는데 사용되는 Dataset과 DataLoader를 살펴본다.

  • PyTorch Dataset

    torch.utils.data.Dataset을 상속한다.

    __len__()__getitem__() 메소드를 구현한다.

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self):
        self.x_data = [[73, 80, 75],
                      [93, 88, 93],
                      [89, 91, 90],
                      [96, 98, 100],
                      [73, 66, 70]]
        self.y_data = [[152], [185], [180], [196], [142]]
        
    def __len__(self):
        return len(self.x_data)
    
    def __getitem__(self, idx):
        x = torch.FloatTensor(self.x_data[idx])
        y = torch.FloatTensor(self.y_data[idx])
        
        return x, y
    
dataset = CustomDataset()
  • PyTorch Dataset

    torch.utils.data.DataLoader 를 이용한다

    • batch_size 는 각 minibatch의 크기를 의미한다.
    • shuffle=TrueEpoch마다 데이터셋을 섞어준다.
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
)
  • full code with Dataset & DataLoader

    • enumerate(dataloader)

      minibatch 인덱스와 데이터를 받는다.

    • len(dataloader)

      한 epoch당 minibatch 개수이다.(batch_size는 2이므로 3이다.)

epochs = 20
for epoch in range(epochs+1):
    for batch_idx, samples in enumerate(dataloader):
        x_train, y_train = samples
        
        prediction = model(x_train)
        
        cost = F.mse_loss(prediction, y_train)
        
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()
        
        print('Epoch : {:4d}/{} Batch {}/{} Cost: {:.6f}'.format(
            epoch, epochs, batch_idx+1, len(dataloader), cost.item()
        ))
'''
Epoch :    0/20 Batch 1/3 Cost: 13.613876
Epoch :    0/20 Batch 2/3 Cost: 5.766063
Epoch :    0/20 Batch 3/3 Cost: 4.246419
Epoch :    1/20 Batch 1/3 Cost: 17.660908
...
Epoch :   19/20 Batch 3/3 Cost: 2.818326
Epoch :   20/20 Batch 1/3 Cost: 13.919250
Epoch :   20/20 Batch 2/3 Cost: 6.494051
Epoch :   20/20 Batch 3/3 Cost: 3.017432
'''

좋은 웹페이지 즐겨찾기