pytorch 변분 자동 인코더 구현
3513 단어 pytorch
이 예는 MNIST 데이터 세트를 예로 사용합니다.
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 12 11:42:19 2018
@author: www
"""
import os
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image
im_tfs = tfs.Compose([
tfs.ToTensor(),
tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) #
])
train_set = MNIST('E:\data', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20) # mean
self.fc22 = nn.Linear(400, 20) # var
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)
def encode(self, x):
h1 = F.relu(self.fc1(x))
return self.fc21(h1), self.fc22(h1)
def reparametrize(self, mu, logvar):
std = logvar.mul(0.5).exp_()
eps = torch.FloatTensor(std.size()).normal_()
if torch.cuda.is_available():
eps = Variable(eps.cuda())
else:
eps = Variable(eps)
return eps.mul(std).add_(mu)
def decode(self, z):
h3 = F.relu(self.fc3(z))
return F.tanh(self.fc4(h3))
def forward(self, x):
mu, logvar = self.encode(x) #
z = self.reparametrize(mu, logvar) #
return self.decode(z), mu, logvar # ,
net = VAE() #
if torch.cuda.is_available():
net = net.cuda()
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
x = x.cuda()
x = Variable(x)
_, mu, var = net(x)
print(mu)
# , , ,
#
reconstruction_function = nn.MSELoss(size_average=False)
def loss_function(recon_x, x, mu, logvar):
"""
recon_x: generating images
x: origin images
mu: latent mean
logvar: latent log variance
"""
MSE = reconstruction_function(recon_x, x)
# loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
KLD = torch.sum(KLD_element).mul_(-0.5)
# KL divergence
return MSE + KLD
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
def to_img(x):
'''
'''
x = 0.5 * (x + 1.)
x = x.clamp(0, 1)
x = x.view(x.shape[0], 1, 28, 28)
return x
for e in range(100):
for im, _ in train_data:
im = im.view(im.shape[0], -1)
im = Variable(im)
if torch.cuda.is_available():
im = im.cuda()
recon_im, mu, logvar = net(im)
loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (e + 1) % 20 == 0:
print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.item()))
save = to_img(recon_im.cpu().data)
if not os.path.exists('./vae_img'):
os.mkdir('./vae_img')
save_image(save, './vae_img/image_{}.png'.format(e + 1))
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
정확도에서 스케일링의 영향데이터셋 스케일링은 데이터 전처리의 주요 단계 중 하나이며, 데이터 변수의 범위를 줄이기 위해 수행됩니다. 이미지와 관련하여 가능한 최소-최대 값 범위는 항상 0-255이며, 이는 255가 최대값임을 의미합니다. 따...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.