Pytorch는 0에서 1로 인코더에서 분리됩니다. (10)

21011 단어 Pytorchdeeplearning
개편
이번에는 변분 인코더에 대해 얘기해 봅시다.변분 인코더도 매우 흔히 볼 수 있는 네트워크 구조의 하나다.그것의 작용은 GAN과 약간 유사한데, 모두 우리를 위해 가짜로 진짜를 어지럽힐 수 있는 그림을 만들어 주는 것이다.그러나 VAE는 GAN과 달리 생성기와 구분기를 구분하지 않고 한 네트워크에서 전체 과정을 완성한다.우리는 먼저 그림을 입력하고 그에 대한 인코딩을 한 다음에 우리의 네트워크 구조를 통해 인코딩의 방차와 균일치를 생성한 다음에 인코딩을 해서 그림을 생성한다. 여기서 가장 중요한 것은 이 방차와 균일치의 생성이다.자신이 방금 한 번 재현해 보니 이곳은 아직도 이해하고 파악해야 할 부분이 많고 좋은 디자인 사고방식이기도 하다.자세한 내용은 VAE 설명서를 참조하십시오.여기 코드만 얘기할게요.
VAE 디코더
라이브러리 도입
import os
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torchvision import transforms
import torch
from torchvision.utils import save_image

장치 설정 및 그림 저장 주소 설정
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sample_dir = 'sample_dir'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

하이퍼매개변수 정의
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

여기 설명, hdim은 첫 번째 숨겨진 층, 즉 그림을 입력한 후 첫 번째로 지나간 숨겨진 층의 출력 특징 사이즈를 가리킨다.z_dim는 방차와 균일치를 예측하는 네트워크층의 출력 특징 사이즈를 나타낸다. 여러분은 우리의 균일치와 방차의 사이즈 크기로 이해할 수 있습니다.데이터 준비 및 로드
data = torchvision.datasets.MNIST(root = '../../data/',
                                  download = True,
                                  train = True,
                                  transform = transforms.ToTensor())

data_loader = torch.utils.data.DataLoader(dataset = data,
                                          shuffle = True,
                                          batch_size = batch_size)

VAE 모델 구축
class VAE(nn.Module):
    def __init__(self,image_size,h_dim = 400,z_dim = 20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size,h_dim)
        self.fc2 = nn.Linear(h_dim,z_dim)
        self.fc3 = nn.Linear(h_dim,z_dim)
        self.fc4 = nn.Linear(z_dim,h_dim)
        self.fc5 = nn.Linear(h_dim,image_size)
    def encode(self,x):
        h = F.relu(self.fc1(x))
        return self.fc2(h),self.fc3(h)
    def reparameterize(self,mu,log_var):
        std = torch.exp(log_var / 2)
        ep = torch.randn_like(std)
        return mu + ep * std

    def decode(self,z):
        h = F.relu(self.fc4(z))
        return F.relu(self.fc5(h))
    def forward(self, x):
        mu,log_var = self.encode(x)
        z = self.reparameterize(mu,log_var)
        x_reconst = self.decode(z)
        return x_reconst,mu,log_var

코드 조작은 균일치와 방차를 계산한 다음에 이를 Reparameterize 함수에 전송하여 방차와 균일치를 한 단계 처리하여 방차와 균일치로 구성된 식을 얻어 decode 함수에 전달하면 그림을 디코딩할 수 있다.모델 및 최적화기 정의
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(),lr = learning_rate)

훈련 모형
for epoch in range(num_epochs):
    for i,(images,_) in enumerate(data_loader):
        x = images.to(device).view(-1,image_size)
        x_reconst,mu,log_var = model(x)
        loss_reconst = F.binary_cross_entropy(x_reconst,x,size_average=True)
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = x_reconst + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 10 == 0:
            print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, len(data_loader), loss_reconst.item(), kl_div.item()))

우리의 손실은 우리가 다시 생성한 그림과 기존의 실제 그림의 이분 교차 엔트로피 손실로 구성될 뿐만 아니라 KL산도로 구성된다. 이 공식은 위에서 공유한 링크를 참조하고 내부에 상세한 설명이 있다.양자 가화는 우리의 손실을 얻은 다음에 손실을 역방향으로 전파하고 최적화하여 훈련을 완성했다.테스트 모델
with torch.no_grad():
    # Save the sampled images
    #             
    z = torch.randn(batch_size, z_dim).to(device)
    out = model.decode(z).view(-1, 1, 28, 28)
    save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1)))

    # Save the reconstructed
    out, _, _ = model(x)
    x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
    save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))

무작위로 한 그룹의 정적 분포수를 생성하는데 균일치는 0이고 방차는 1이다.그리고 이 인코딩을 우리 모델에 넣고 생성된 그림을 저장한 다음에 실제 그림을 모델에 넣고 둘을 연결하여 비교하기 편리하게 한다.진실한 것을 다시 저장해라.
총결산
VAE는 GAN의 기능과 매우 비슷한 네트워크 구조이다. 우리는 GAN에 대한 이해를 빌려 VAE를 잘 이해할 수 있다. 주로 손실 함수의 유래와 계산 방차와 균일치의 의미를 파악해야 한다.

좋은 웹페이지 즐겨찾기