GAN 및 Keras(SRGAN)의 초해상도
선험지식
생성 대항 네트워크(GAN)
GAN은 은 은 고드페로와 그의 친구들이 발명한 신경 네트워크 분야의 기술이다.SRGAN는 우리가 어떠한 이미지 해상도를 높일 수 있는 방법이다.
그것은 기본적으로 두 부분으로 구성되어 있는데 그것이 바로 발생기와 감별기이다.생성기는 주어진 입력 소음에 따라 정확한 출력 데이터를 생성한다.감별기는 두 가지 유형의 데이터를 수신한다. 하나는 실제 세계의 데이터이고, 다른 하나는 생성기가 생성한 출력이다.감별기의 경우 실제 데이터에 라벨'1'이 있고 생성된 데이터에 라벨'0'이 있다.우리는 발전기를 예술가에 비유하고, 감별기를 비평가에 비유할 수 있다.예술가는 예술 형식을 창조하여 평론가가 평가한다.
생성기가 훈련 중에 개선됨에 따라 감별기의 성능은 더욱 나빠질 것이다. 왜냐하면 감별기는 진위를 쉽게 구분할 수 없기 때문이다.이론적으로 감별기는 결국 50퍼센트의 정확도를 동전을 던지는 것처럼 갖게 될 것이다.
따라서 우리의 좌우명은 우리를 평가하고 작품을 주목하는 사람들의 정확성을 떨어뜨리는 것이다.
SRGAN의 패브릭
교대 훈련
생성기와 감별기의 훈련 방식은 다르다.우선, 한 개 이상의 역원을 위해 감별기를 훈련하고, 한 개 이상의 역원을 위해 발생기를 훈련한 다음에 한 주기를 완성했다.프리 트레이닝 VG19 모델은 트레이닝 시 이미지에서 피쳐를 추출하는 데 사용됩니다.
트레이닝 생성기에서 감별기의 매개 변수가 동결됩니다. 그렇지 않으면 모델이 이동 목표에 명중하고 영원히 수렴되지 않습니다.
비밀 번호
필요한 의존항 가져오기
import numpy as np
from keras import Model
from keras.layers import Conv2D, PReLU, BatchNormalization, Flatten
from keras.layers import UpSampling2D, LeakyReLU, Dense, Input, add
일부 필요한 변수lr_ip = Input(shape=(25,25,3))
hr_ip = Input(shape=(100,100,3))
train_lr,train_hr = #training images arrays normalized between 0 & 1
test_lr, test_hr = # testing images arrays normalized between 0 & 1
생성기 정의우리는 고해상도 이미지를 만드는 데 사용되는 생성기 모델을 되돌려 주는 함수를 정의해야 한다.잔차 블록은 하나의 함수로, 그 중에서 하나의 입력 층과 마지막 층의 덧셈을 되돌려준다.
# Residual block
def res_block(ip):
res_model = Conv2D(64, (3,3), padding = "same")(ip)
res_model = BatchNormalization(momentum = 0.5)(res_model)
res_model = PReLU(shared_axes = [1,2])(res_model)
res_model = Conv2D(64, (3,3), padding = "same")(res_model)
res_model = BatchNormalization(momentum = 0.5)(res_model)
return add([ip,res_model])
# Upscale the image 2x
def upscale_block(ip):
up_model = Conv2D(256, (3,3), padding="same")(ip)
up_model = UpSampling2D( size = 2 )(up_model)
up_model = PReLU(shared_axes=[1,2])(up_model)
return up_model
num_res_block = 16
# Generator Model
def create_gen(gen_ip):
layers = Conv2D(64, (9,9), padding="same")(gen_ip)
layers = PReLU(shared_axes=[1,2])(layers)
temp = layers
for i in range(num_res_block):
layers = res_block(layers)
layers = Conv2D(64, (3,3), padding="same")(layers)
layers = BatchNormalization(momentum=0.5)(layers)
layers = add([layers,temp])
layers = upscale_block(layers)
layers = upscale_block(layers)
op = Conv2D(3, (9,9), padding="same")(layers)
return Model(inputs=gen_ip, outputs=op)
감별기 정의이 코드는 감별기 모델의 구조와 실제 이미지를 구분하고 이미지를 생성하는 모든 층을 정의합니다.우리가 깊이 들어가면서, 두 층마다 필터의 수가 두 배로 증가할 것이다.
#Small block inside the discriminator
def discriminator_block(ip, filters, strides=1, bn=True):
disc_model = Conv2D(filters, (3,3), strides, padding="same")(ip)
disc_model = LeakyReLU( alpha=0.2 )(disc_model)
if bn:
disc_model = BatchNormalization( momentum=0.8 )(disc_model)
return disc_model
# Discriminator Model
def create_disc(disc_ip):
df = 64
d1 = discriminator_block(disc_ip, df, bn=False)
d2 = discriminator_block(d1, df, strides=2)
d3 = discriminator_block(d2, df*2)
d4 = discriminator_block(d3, df*2, strides=2)
d5 = discriminator_block(d4, df*4)
d6 = discriminator_block(d5, df*4, strides=2)
d7 = discriminator_block(d6, df*8)
d8 = discriminator_block(d7, df*8, strides=2)
d8_5 = Flatten()(d8)
d9 = Dense(df*16)(d8_5)
d10 = LeakyReLU(alpha=0.2)(d9)
validity = Dense(1, activation='sigmoid')(d10)
return Model(disc_ip, validity)
VGG19 모델이 코드 블록에서는 매개변수가 업데이트되지 않도록 이미지 넷 데이터베이스에서 훈련된 VG19 모델을 사용하여 피쳐를 추출합니다.
from keras.applications import VGG19
# Build the VGG19 model upto 10th layer
# Used to extract the features of high res imgaes
def build_vgg():
vgg = VGG19(weights="imagenet")
vgg.outputs = [vgg.layers[9].output]
img = Input(shape=hr_shape)
img_features = vgg(img)
return Model(img, img_features)
조합 모형지금, 우리는 생성기와 감별기 모형을 동봉합니다.이로부터 얻은 모형은 발전기 모형을 훈련하는 데만 쓰인다.이 조합 모형을 훈련할 때, 우리는 반드시 모든 역원에서 감별기를 동결해야 한다.
# Attach the generator and discriminator
def create_comb(gen_model, disc_model, vgg, lr_ip, hr_ip):
gen_img = gen_model(lr_ip)
gen_features = vgg(gen_img)
disc_model.trainable = False
validity = disc_model(gen_img)
return Model([lr_ip, hr_ip],[validity,gen_features])
신고 모형그리고 우리는 생성기, 감별기, vgg 모형을 성명했다.이 모델들은 조합 모델의 매개 변수로 사용될 것이다.
조합 모델 내부의 비교적 작은 모델의 어떤 변경도 외부 모델에 영향을 줄 수 있다.예를 들어 권중 업데이트, 모델 동결 등이다.
generator = create_gen(lr_ip)
discriminator = create_disc(hr_ip)
discriminator.compile(loss="binary_crossentropy", optimizer="adam",
metrics=['accuracy'])
vgg = build_vgg()
vgg.trainable = False
gan_model = create_comb(generator, discriminator, vgg, lr_ip, hr_ip)
gan_model.compile(loss=["binary_crossentropy","mse"], loss_weights=
[1e-3, 1], optimizer="adam")
훈련 데이터에 대해 소량의 샘플링을 진행하다훈련 집합이 너무 크기 때문에, 우리는 자원 소모의 오류를 피하기 위해 그림을 소량으로 샘플링해야 한다.RAM과 같은 자원으로는 모든 이미지를 동시에 훈련할 수 없습니다.
batch_size = 20
train_lr_batches = []
train_hr_batches = []
for it in range(int(train_hr.shape[0] / batch_size)):
start_idx = it * batch_size
end_idx = start_idx + batch_size
train_hr_batches.append(train_hr[start_idx:end_idx])
train_lr_batches.append(train_lr[start_idx:end_idx])
train_lr_batches = np.array(train_lr_batches)
train_hr_batches = np.array(train_hr_batches)
훈련 모델이 모듈은 전체 프로그램의 핵심이다.여기서 우리는 상술한 교체 방법으로 감별기와 발생기를 훈련한다.현재 감별기가 동결되었으니 감별기를 훈련하기 전에 동결하는 것을 잊지 말고 감별기를 훈련한 후에 동결하는 것을 아래 코드에서 제시하십시오.
epochs = 100
for e in range(epochs):
gen_label = np.zeros((batch_size, 1))
real_label = np.ones((batch_size,1))
g_losses = []
d_losses = []
for b in range(len(train_hr_batches)):
lr_imgs = train_lr_batches[b]
hr_imgs = train_hr_batches[b]
gen_imgs = generator.predict_on_batch(lr_imgs)
#Dont forget to make the discriminator trainable
discriminator.trainable = True
#Train the discriminator
d_loss_gen = discriminator.train_on_batch(gen_imgs,
gen_label)
d_loss_real = discriminator.train_on_batch(hr_imgs,
real_label)
discriminator.trainable = False
d_loss = 0.5 * np.add(d_loss_gen, d_loss_real)
image_features = vgg.predict(hr_imgs)
#Train the generator
g_loss, _, _ = gan_model.train_on_batch([lr_imgs, hr_imgs],
[real_label, image_features])
d_losses.append(d_loss)
g_losses.append(g_loss)
g_losses = np.array(g_losses)
d_losses = np.array(d_losses)
g_loss = np.sum(g_losses, axis=0) / len(g_losses)
d_loss = np.sum(d_losses, axis=0) / len(d_losses)
print("epoch:", e+1 ,"g_loss:", g_loss, "d_loss:", d_loss)
평가 모델여기에서, 우리는 데이터 집합 계산 생성기의 성능을 테스트합니다.훈련 데이터 집합을 사용했을 때보다 손실이 클 수도 있지만 차이가 적으면 걱정하지 않아도 된다.
label = np.ones((len(test_lr),1))
test_features = vgg.predict(test_hr)
eval,_,_ = gan_model.evaluate([test_lr, test_hr], [label,test_features])
생산량을 예측하다우리는 생성기 모델을 사용하여 고해상도 이미지를 생성할 수 있다.
test_prediction = generator.predict_on_batch(test_lr)
출력이 굉장한데..너는 나의github 파일에서 나의implementation를 찾을 수 있다. 이것은 구글 colab에서 훈련된 것이다.
프롬프트
도구책
제이슨 브라운리.2019. 파이톤의 생성적 대항 네트워크
https://arxiv.org/pdf/1609.04802 . SRGAN에 관한 논문
Reference
이 문제에 관하여(GAN 및 Keras(SRGAN)의 초해상도), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://dev.to/manishdhakal/super-resolution-with-gan-and-keras-srgan-38ma텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)