Conditional Generative Model 구현 (feat. pytorch)

GAN 이란?

GAN(Generative Adversarial Network)

  • Unsupervised Learning(비지도학습)의 대표적인 알고리즘

  • 서로 대립하는 역할의 두 모델이 경쟁하여 학습하는 방법론

  • G : Generative

    • GAN은 생성모델로 이미지, 음성, sequential data 등 원하는 형태의 데이터를 만드는 모델입니다. 이전에 배운 CNN이나 RNN은 데이터를 분석하는 모델이지 생성하는 모델이 아닙니다.
  • A : Adversarial

    • GAN은 서로 대립관계에 있는 두 개의 모델을 생성해 적대적으로 경쟁시키면서 발전시키는 것이 핵심입니다.

    • Generator(생성자) : 가짜 이미지 생성

    • Discriminator(판별자) : 이미지의 진위 여부 판별

    • generator 모델은 discriminator 모델을 속이기 위해서 진짜 같은 데이터를 만들고, discriminator은 반대로 가짜 데이터와 진짜 데이터를 감별하려고 실력을 키웁니다. 이후 discriminator의 예측 결과에 따라 각 모델의 loss가 결정되고, 서로 학습을 반복합니다. 이런 경쟁구도 속에서 두 모델의 능력이 상호 발전됩니다.

  • N : Network

    • GAN은 인공신경망 모델로, generator와 discriminator 모두 신경망 기반 모델입니다.

기존 GAN의 문제점

기존 GAN의 문제점은 내가 만들고 싶은 데이터를 만들어내지 못한다는 것에 있습니다. 예를 들어 MNIST 데이터로 학습시킨 GAN이 있다고 해봅시다. 이때 GAN을 이용해서 내가 만들고 싶은 숫자를 만들어내지 못합니다. 결국 내가 원하는 숫자가 나올 때까지 입력이 되는 Noise를 계속해서 바꿔줘야만 하는 것입니다.

이러한 문제를 해결하기 위해 Conditional GAN이라는 모델이 등장했습니다. 단순히 어떤 추가적인 정보를 넣어주기만 하면 내가 원하는 데이터를 만들어줄 수 있는 것이죠.

그냥 Discriminator와 Generator에 어떠한 정보 y만 넣어주면 내가 원하는 데이터를 만들어 낼 수 있습니다. 이때 주목해야할 점은 기존의 GAN과 목적함수가 달라진다는 것 입니다. y라는 추가적인 정보가 들어갔으므로, 조건부확률이 된다는 것만 주의하시면 됩니다.


Loss function

기존의 GAN은 다음 식 처럼 minimax 게임을 G와 D가 진행하게 됩니다. 수식에 대한 자세한 설명은 이곳을 참고하세요.


CGAN은 D와 G에 들어가는 input이 단지 조건부로 바뀌기만 하면 됩니다.


  • generator 손실 함수

    • generator은 위 손실 함수는 최소화하는 방향으로 학습을 진행합니다. generator이 생성한 가짜 이미지를 discriminator이 진짜라고 판단하면 D(G(z)) = 1이 됩니다. 따라서 generator은 D(G(z))=1이 되도록 학습을 진행하며, D(G(z))=1이 된다면 위 손실함수는 최소값을 갖습니다.


  • discriminator 손실 함수

    • discriminator은 위 손실 함수를 최대화 하는 방향으로 학습을 진행합니다. 진짜 데이터 x를 discriminator가 진짜라고 판단하면 D(x) = 1의 값을 출력합니다. 반대로 가짜 데이터 G(z)를 discriminator가 가짜라고 판단하면 D(G(z)) = 0의 값을 출력합니다. 즉, 위 손실 함수를 최대화하는 방향으로 학습하는 것은 가짜 데이터를 가짜 데이터로 식별하고, 진짜 데이터는 진짜 데이터를 식별할 수 있도록 파라미터를 갱신합니다.

  • 만약 generator이 진짜같은 가짜 이미지를 생성하게 된다면 진짜와 가짜를 구별할 수 없어 D(x) = D(G(z)) = 1/2 값을 갖습니다.


구현

모듈 및 데이터 가져오기

import json
import os
import numpy as np
import csv
import easydict
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable

from PIL import Image

from tqdm import tqdm
import matplotlib.pyplot as plt
from time import sleep
from torchvision import datasets


transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
])


train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=32, shuffle=True)
    
    
def normal_init(m, mean, std):
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

generator

class Generator(nn.Module):
    """Generator, 논문에 따르면 100개의 noise를 hypercube에서 동일한 확률값으로 뽑고
       z를 200개, y를 1000개의 뉴런으로 전달합니다. 이후 1200차원의 ReLU layer로 결합하고
       Sigmoid를 통해 숫자를 만들어냅니다."""
    def __init__(self):
        super().__init__()
        self.num_classes = 10 # 클래스 수, 10
        self.nz = 100 # 노이즈 수, 100
        self.input_size = (1,28,28)

        self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)

        # # noise와 label을 결합할 용도인 label embedding matrix를 생성합니다.
        self.label_emb = nn.Embedding(10, 10)
        # 임베딩 파라미터를 선언한 후 forward 메소드를 수행하면 (입력차원, 임베딩차원) 크기를 가진 텐서가 출력
        # 이때 forward 메소드의 입력텐서는 임베딩 벡터를 추출할 범주의 인덱스이므로 무조건 정수타입(LongTensor)이 들어가야된다.
        # ex) forward 메소드에 (2,4) 크기의 텐서가 입력으로 들어가면 (2,4,10) 크기의 텐서가 출력
        # https://hongl.tistory.com/244
        
        self.model = nn.Sequential(
            nn.Linear(110, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh()
        )


    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, input, label):
        z = input.view(input.size(0), 100)  # 노이즈 (batch_size, 100)
        c = self.label_emb(label)  # 라벨 (10,10)
        x = torch.cat([z, c], 1)
        out = self.model(x)
        return out.view(x.size(0), 28, 28)

noise의 차원은 100이고, forward를 보면 noise와 label이 input으로 들어갑니다. 그래서 concatenate를 사용하여 100+10차원의 값들을 첫번째 layer에서 처리하도록 합니다. 이미지는 28x28 크기이므로, 마지막 layer에서 view를 이용해 shape을 맞춰줍니다.


discriminator

class Discriminator(nn.Module):
    """Discriminator, 논문에 따르면 maxout을 사용하지만
       여기서는 그냥 Fully-connected와 LeakyReLU를 사용하겠습니다.
       논문에서는 Discriminator의 구조는 그렇게 중요하지 않다고 말합니다"""
    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(10, 10)
        # 벡터화의 한 과정
        
        self.model = nn.Sequential(
            nn.Linear(794, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, input, label):
        x = input.view(input.size(0), 784)
        c = self.label_emb(label)
        x = torch.cat([x, c], 1)
        out = self.model(x)
        return out.squeeze()

이때 첫번째 layer를 보면 784+10차원의 값을 받는다는 것을 볼 수 있습니다. 위에서 보셨듯 이미지 하나는 784의 값을 가지고 label은 10의 값을 가지므로, 이들을 함께 넣어주기 위해서 784+10의 값을 가지는 것입니다. 마지막 층 전에는 활성화 함수로 LeakyReLU를 사용하며 Dropout을 사용했습니다.

Discriminator는 어떠한 데이터가 진짜인지 가짜인지 판단해야하므로 한 개의 확률값을 만들어내야만 합니다. 따라서 마지막에는 1개의 값으로 만들어주고 이를 0~1사이의 확률값으로 만들어주기 위해 Sigmoid를 사용했습니다.


Training

generator = Generator().cuda()
discriminator = Discriminator().cuda()

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)


def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion):
    g_optimizer.zero_grad()
    z = Variable(torch.randn(batch_size, 100)).cuda()
    
    fake_labels = Variable(torch.LongTensor(np.random.randint(0, 10, batch_size))).cuda()  # 원-핫 벡터 아님
    fake_images = generator(z, fake_labels)  # 가짜 이미지 생성
    
    validity = discriminator(fake_images, fake_labels)  
    # discriminator에 가짜 이미지를 넣어서 결과를 출력. 가짜가 라벨과 같다라고 하면 1, 아니면 0
    
    g_loss = criterion(validity, Variable(torch.ones(batch_size)).cuda())
    # dis에서 나온 출력을 1로 채워져 있는 label과 비교.
    # 만약 dis가 가짜 이미지를 라벨과 같다 판단해서 1을 출력했다면은 loss는 0에 가까워진다.
    # 이를 generator가 학습. 
    # 즉, generator는 discriminator를 잘 속이는 방향(validity 가 1이 나오게)으로 학습
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()
    
    
    
def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion, real_images, labels):
    d_optimizer.zero_grad()

    # train with real images
    # 진짜 이미지와 label을 discriminator에 넣는다.
    real_validity = discriminator(real_images, labels)
    real_loss = criterion(real_validity, Variable(torch.ones(batch_size)).cuda())
    # D가 진짜 이미지를 진짜 라고 맞추면 1을 출력하고 real_loss는 0이 됨
    
    # train with fake images
    z = Variable(torch.randn(batch_size, 100)).cuda()  # 임의의 Noise 생성
    fake_labels = Variable(torch.LongTensor(np.random.randint(0, 10, batch_size))).cuda()
    
    fake_images = generator(z, fake_labels) # Noise와 임의의 라벨을 input으로 넣음
    
    fake_validity = discriminator(fake_images, fake_labels)
    # generator가 생성한 이미지와 임의의 라벨을 discriminator에 넣어서 결과 출력 (0~1)
    
    fake_loss = criterion(fake_validity, Variable(torch.zeros(batch_size)).cuda())
    # real_loss를 계산할 때와는 다르게 torch.zeors를 사용해서 0으로 채워진 label을 줌.
    # discriminator가 img와 label이 같다고 판단하면 1을 출력, 아니면 0을 출력하므로 이를 CE하면
    # discriminator가 잘 맞췄을 때는 cross_entropy(1, 0) 이므로 fake_loss가 커짐
    # ---> discriminator가 진짜 이미지가 아닌 generator가 생성한 가짜 이미지를 진짜라고 판단했으므로 loss가 커짐.
    
    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()
 
 

from torchvision.utils import make_grid

num_epochs = 100
n_critic = 5
display_step = 50
batch_size = 32
for epoch in range(num_epochs):
    print('Starting epoch {}...'.format(epoch), end=' ')
    for i, (images, labels) in enumerate(train_loader):
        
        step = epoch * len(train_loader) + i + 1
        real_images = Variable(images).cuda()
        labels = Variable(labels).cuda()
        generator.train()
        
        d_loss = discriminator_train_step(len(real_images), discriminator,
                                          generator, d_optimizer, criterion,
                                          real_images, labels)
        

        g_loss = generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion)
        
        #writer.add_scalars('scalars', {'g_loss': g_loss, 'd_loss': d_loss}, step)  
        
        if step % display_step == 0:
            generator.eval()
            z = Variable(torch.randn(9, 100)).cuda()
            labels = Variable(torch.LongTensor(np.arange(9))).cuda()
            sample_images = generator(z, labels).unsqueeze(1)
            grid = make_grid(sample_images, nrow=3, normalize=True)
         #   writer.add_image('sample_image', grid, step)
    print('Done!')

결과

images = generator(z, labels).unsqueeze(1)

grid = make_grid(images, nrow=10, normalize=True)

fig, ax = plt.subplots(figsize=(10,10))
ax.imshow(grid.permute(1, 2, 0).data.cpu(), cmap='binary')
ax.axis('off')


one-hot 인코딩 & Embedding

위의 구현 코드는 one-hot 인코딩된 label을 사용하지 않고 Embedding을 사용했습니다.
Embedding이라는 말은 NLP에서 매우 자주 등장하는 단어로 이산적, 범주적인 변수를 sparse한 one-hot 인코딩 대신 연속적인 값을 가지는 벡터로 표현하는 방법을 말합니다.

즉, 수많은 종류를 가진 단어, 문장에 대해 one-hot 인코딩을 수행하면 수치로는 표현이 가능하겠지만 대부분의 값이 0이 되어버려 매우 sparse 해지므로 임의의 길이의 실수 벡터로 밀집되게 표현하는 일련의 방법을 임베딩이라 하고, 각 카테고리가 나타내는 실수 벡터를 임베딩 벡터라고 합니다.

만약 one-hot 인코딩을 사용한 코드가 궁금하시면 여기를 참고하시면 됩니다.


참고

좋은 웹페이지 즐겨찾기