[PyTorch] Lab04.2 - Loading Data
📌 학습 목표
- Minibatch Gardient Desecent
- PyTorch Dataset & DataLoader
Minibatch Gardient Desecent
실제 세계의 데이터의 양은 방대하고, 학습시킬 때 여러 정보들을 이용한다.
하지만, 이 방대한 양의 데이터들을 한번에 학습시킬 수 없다.
그래서 데이터를 균일하게 나눠 학습하는 Minibatch를 이용한다.
위 그림과 같이 Minibatch로 학습하게 되면 Gradient Descent를 각 Minibatch가 끝날 때마다 수행한다.
매 Minibatch마다 Cost 양이 전체 데이터에 비해 적기 때문에 업데이트가 빨리 이루어진다.
하지만, 데이터가 분할되어 잘못된 방향으로 학습될 수 있다는 점을 주의해야한다.
PyTorch Dataset & DataLoader
Dataset을 Minibatch로 쪼개는데 사용되는 Dataset과 DataLoader를 살펴본다.
-
PyTorch Dataset
torch.utils.data.Dataset
을 상속한다.__len__()
와__getitem__()
메소드를 구현한다.
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self):
self.x_data = [[73, 80, 75],
[93, 88, 93],
[89, 91, 90],
[96, 98, 100],
[73, 66, 70]]
self.y_data = [[152], [185], [180], [196], [142]]
def __len__(self):
return len(self.x_data)
def __getitem__(self, idx):
x = torch.FloatTensor(self.x_data[idx])
y = torch.FloatTensor(self.y_data[idx])
return x, y
dataset = CustomDataset()
-
PyTorch Dataset
torch.utils.data.DataLoader
를 이용한다batch_size
는 각 minibatch의 크기를 의미한다.shuffle=True
는 Epoch마다 데이터셋을 섞어준다.
from torch.utils.data import DataLoader
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=True,
)
-
full code with Dataset & DataLoader
-
enumerate(dataloader)
minibatch 인덱스와 데이터를 받는다.
-
len(dataloader)
한 epoch당 minibatch 개수이다.(batch_size는 2이므로 3이다.)
-
epochs = 20
for epoch in range(epochs+1):
for batch_idx, samples in enumerate(dataloader):
x_train, y_train = samples
prediction = model(x_train)
cost = F.mse_loss(prediction, y_train)
optimizer.zero_grad()
cost.backward()
optimizer.step()
print('Epoch : {:4d}/{} Batch {}/{} Cost: {:.6f}'.format(
epoch, epochs, batch_idx+1, len(dataloader), cost.item()
))
'''
Epoch : 0/20 Batch 1/3 Cost: 13.613876
Epoch : 0/20 Batch 2/3 Cost: 5.766063
Epoch : 0/20 Batch 3/3 Cost: 4.246419
Epoch : 1/20 Batch 1/3 Cost: 17.660908
...
Epoch : 19/20 Batch 3/3 Cost: 2.818326
Epoch : 20/20 Batch 1/3 Cost: 13.919250
Epoch : 20/20 Batch 2/3 Cost: 6.494051
Epoch : 20/20 Batch 3/3 Cost: 3.017432
'''
Author And Source
이 문제에 관하여([PyTorch] Lab04.2 - Loading Data), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://velog.io/@gun1yun/PyTorch-Lab04.2-Loading-Data저자 귀속: 원작자 정보가 원작자 URL에 포함되어 있으며 저작권은 원작자 소유입니다.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)