실전 - CycleGAN 학습 파이프라인 구성하기
이전 포스트들에서는 전체적인 그림에 해당하는 CycleGAN 모델과, 모델들을 학습하기 위해 손실을 계산하고 역전파 알고리즘을 수행하는 코드를 구현하였습니다. 그 다음에 필요한 것은 학습에 필요한 데이터셋일 것입니다. 이 포스트에서는 학습을 하기 위한 재료인 데이터셋을 구성하고, 논문에 적혀 있는 세부적인 사항들을 구현해보겠습니다.
학습 데이터셋으로 DataLoader 구성하기
논문에 사용되는 데이터셋은 여기서 받을 수 있습니다. 그 중 apple2orange.zip
파일의 압축을 풀면 다음과 같은 디렉토리 구조를 볼 수 있습니다.
apple2orange
├── trainA
| └── images...
├── trainB
├── testA
└── testB
trainA, trainB 두 폴더의 이미지를 토대로 미니 배치를 구성하게 됩니다. 여기서 unpaired한 데이터셋을 구성하기 위해 두 스타일의 데이터를 랜덤으로 구성하는 작업이 필요합니다. 이는 다음과 같은 구현을 사용합니다.
class UnpairedDataset(DataLoader):
def __init__(self, dataset_dir, styles, transforms):
self.dataset_dir = dataset_dir
self.styles = styles
self.image_path_A = glob.glob(os.path.join(dataset_dir, styles[0]) + "/*")
self.image_path_B = glob.glob(os.path.join(dataset_dir, styles[1]) + "/*")
self.transform = transforms
def __getitem__(self, index_A):
index_B = random.randint(0, len(self.image_path_B) - 1)
item_A = self.transform(Image.open(self.image_path_A[index_A]))
item_B = self.transform(Image.open(self.image_path_B[index_B]))
return [item_A, item_B]
def __len__(self):
return len(self.image_path_A)
- 생성자에서는 root 폴더 내의 두 스타일에 해당하는 학습 폴더들의 리스트를 저장하게 됩니다. 또한, 이미지를 바로 입력에 넣지 않고 텐서로 만들고 정규화하는 작업들을
transform
으로 정의하게 되는데, 여기서 사용하는 변형들은 다음과 같습니다.
transform = transforms.Compose(
[
transforms.Resize((args.size, args.size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
)
Generator의 제일 뒷단 활성함수가 -1에서 1로 값을 조정하는 tanh이기 때문에, 입력도 -1에서 1로 맞추어주기 위해 Normalize
함수를 구성하게 됩니다. PIL로 불러온 이미지를 ToTensor
에 통과하면 0에서 1 사이이므로, 해당하는 범위로 바꾸어 주는 값을 인수로 지정합니다.
그 외에 데이터 증강 기법으로 무작위하게 가로로 뒤집기도 사용합니다. 가로와 달리 세로 뒤집기를 했을 때는 누가 보아도 티가 나기 때문에 사용하지 않습니다.
-
__getitem__
메서드에서 unpaired 이미지들을 반환하게 됩니다. trainA은 인덱스로 반환하고, trainB를 랜덤으로 선택해서 반환하도록 합니다. -
__len__
는 trainA를 인덱스로 반환하므로, trainA의 크기를 반환하도록 합니다.
이를 사용하기 위해 Dataset과 DataLoader를 train.py에 정의하게 되는데, 이는 다음과 같습니다.
dataloader = DataLoader(
UnpairedDataset(args.dataset_path, ["trainA", "trainB"], transform)
)
스케쥴링 적용하기
그 다음은 스케쥴링입니다. 논문에서 언급하는 부분은 다음과 같습니다.
예를 들어서 처음의 LR이 0.5이고, 500 epoch를 학습한다고 했을 때, 101 epoch에서 0.4, 201 epoch에 0.3이 되는 방식으로 볼 수 있습니다. PyTorch에서는 linear하게 감소하는 스케쥴러를 따로 정의해둔 것은 아니라 별도로 lambda 함수로 LambdaLR
로 다음과 같이 정의합니다.
lr_lambda = lambda epoch: 1 - ((epoch - 1) // 100) / (args.epoch / 100)
scheduler_G = optim.lr_scheduler.LambdaLR(optimizer=optim_G, lr_lambda=lr_lambda)
scheduler_D_A = optim.lr_scheduler.LambdaLR(
optimizer=optim_D_A, lr_lambda=lr_lambda
)
scheduler_D_B = optim.lr_scheduler.LambdaLR(
optimizer=optim_D_B, lr_lambda=lr_lambda
)
epoch가 주어졌을 때 원래 LR에 곱해주는 값을 반환하는 lambda 함수를 만들고, 이를 스케쥴러에 넣는 방식입니다. 세 개의 optimizer를 사용하므로, 스케쥴러도 각각 생성하게 됩니다.
이미지 버퍼 사용하기
논문 내에서 안정적인 학습을 돕기 위해 history 기반 학습을 사용합니다.
즉, 생성한 이미지를 한 번만 사용하는 것이 아니라 버퍼에 저장해서 가져오는 방식으로 보면 되겠습니다. 이는 공식 코드에 사용된 ImagePool
로 사용합니다.
pool_fake_A = ImagePool(args.pool_size)
pred_fake_A = netD_B(pool_fake_A.query(fake_A))
2번째 줄에 fake_A를 버퍼에 넣으면 query
함수에서 버퍼에 저장되어 있는 무작위의 이미지를 반환해줍니다.
가중치 초기화하기
논문 내에서는 다음과 같이 언급되어 있습니다.
자세한 방법은 여기에 언급되어 있습니다. 네 모델에 다음과 같이 적용합니다.
netG_A2B.apply(init_weight)
netG_B2A.apply(init_weight)
netD_A.apply(init_weight)
netD_B.apply(init_weight)
그 외 도움이 되는 구현들
전체적으로 결과 재현에 필요한 학습 파이프라인에 필요한 대부분을 구현하였습니다. 여기서는 부가적으로 구현한 부분들을 간단하게 언급하겠습니다.
- argparse 사용하기: identity loss 사용, dataset path 지정 등 사용하는 유저나 설정 등에 따라 달라질 수 있는데 이를 코드 내에서 다루기보다는 터미널에서 처리하고 싶을 때 사용합니다.
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=500)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--dataset_path", type=str, default="datasets/apple2orange")
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--size", type=int, default=256)
parser.add_argument("--lambda_ide", type=float, default=10)
parser.add_argument("--lr", type=float, default=2e-4)
parser.add_argument("--pool_size", type=int, default=50)
parser.add_argument("--identity", action="store_true")
args = parser.parse_args()
- checkpoint 불러오기: 한 번에 모든 학습을 진행하면 좋겠지만, 갑자기 꺼질 경우 등에 대비하기 위해 구현합니다. 가중치를 불러오는 경우는 argparse로 지정하도록 합니다. 불러올 때는
load_state_dict
, 새로 학습할 때에는 가중치 초기화를 하는 코드를 구현합니다.
if args.checkpoint_path is not None:
checkpoint = torch.load(args.checkpoint_path, map_location=device)
netG_A2B.load_state_dict(checkpoint["netG_A2B_state_dict"])
netG_B2A.load_state_dict(checkpoint["netG_B2A_state_dict"])
netD_A.load_state_dict(checkpoint["netD_A_state_dict"])
netD_B.load_state_dict(checkpoint["netD_B_state_dict"])
epoch = checkpoint["epoch"]
else:
netG_A2B.apply(init_weight)
netG_B2A.apply(init_weight)
netD_A.apply(init_weight)
netD_B.apply(init_weight)
epoch = 0
- Loss 기록하기: 손실이 줄어드는 것을 확인하면서 학습이 되고 있는지 확인하고 싶을 때 기록하는 코드를 구현해야 합니다. 평균 손실을 기록하고 저장할 때 PyTorch 내의
AverageMeter
를 사용하는 것이 편한데, 여기서는 3개의 손실을 보여줄 것이기 때문에 적은 코드를 사용하기 위해LossDisplayer
를 따로 구현합니다. 또한 파일로 내보내고, 웹 상에서도 확인하기 위해 TensorBoard 기반 함수들을 활용합니다.
class LossDisplayer:
def __init__(self, name_list):
self.count = 0
self.name_list = name_list
self.loss_list = [0] * len(self.name_list)
def record(self, losses):
self.count += 1
for i, loss in enumerate(losses):
self.loss_list[i] += loss.item()
def get_avg_losses(self):
return [loss / self.count for loss in self.loss_list]
def display(self):
for i, total_loss in enumerate(self.loss_list):
avg_loss = total_loss / self.count
print(f"{self.name_list[i]}: {avg_loss:.4f} ", end="")
def reset(self):
self.count = 0
self.loss_list = [0] * len(self.name_list)
disp = LossDisplayer(["G_GAN", "G_recon", "D"])
summary = SummaryWriter()
# Record and display loss
avg_losses = disp.get_avg_losses()
summary.add_scalar("loss_G_GAN", avg_losses[0], epoch)
summary.add_scalar("loss_G_recon", avg_losses[1], epoch)
summary.add_scalar("loss_D", avg_losses[2], epoch)
disp.display()
disp.reset()
학습 전체 코드는 여기서 확인할 수 있습니다.
학습하기
코드를 모두 작성했으면 이제 학습할 시간입니다. 학습을 수행하기 위해 다음과 같이 명령어를 입력하면 됩니다.
CUDA_VISIBLE_DEVICES=0 python3 train.py --identity --dataset_path apple2orange
이는 0번 CUDA를 지원하는 GPU로 identity loss를 사용하고, dataset 경로로 현재 폴더에 있는 apple2orange 폴더를 사용하는 것으로 볼 수 있습니다. 이를 실행하면 다음과 같이 출력됩니다.
cuda
Epoch 1
G_GAN: 0.7555 G_recon: 8.2069 D: 0.6294
Epoch 2
G_GAN: 0.7297 G_recon: 6.4134 D: 0.4776
Epoch 3
G_GAN: 0.7417 G_recon: 5.9897 D: 0.4722
Epoch 4
G_GAN: 0.7261 G_recon: 5.5554 D: 0.4716
Epoch 5
G_GAN: 0.7799 G_recon: 5.3572 D: 0.4590
Epoch 6
G_GAN: 0.7470 G_recon: 5.1694 D: 0.4494
...
이렇게 터미널로 확인할 수도 있지만, 좀 더 쉽게 그래프로 확인하기 위해 기록한 손실들을 불러와서 matplotlib으로 보이도록 해보겠습니다.
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import numpy as np
import matplotlib.pyplot as plt
event_acc = EventAccumulator("runs/[만들어진 로그 파일]")
event_acc.Reload()
loss_G_GAN = np.array([[s.step, s.value] for s in event_acc.Scalars('loss_G_GAN')])
loss_G_recon = np.array([[s.step, s.value] for s in event_acc.Scalars('loss_G_recon')])
loss_D =np.array([[s.step, s.value] for s in event_acc.Scalars('loss_D')])
plt.subplot(131)
plt.title("loss_G_GAN")
plt.plot(loss_G_GAN[:, 0], loss_G_GAN[:, 1])
plt.subplot(132)
plt.title("loss_G_recon")
plt.plot(loss_G_recon[:, 0], loss_G_recon[:, 1])
plt.subplot(133)
plt.title("loss_D")
plt.plot(loss_D[:, 0], loss_D[:, 1])
plt.show()
일단 재구성하는 손실(cycle, identity)가 낮아지는 것으로 보았을 때 학습이 되는 것으로 볼 수 있습니다. 그런데, generator의 GAN 손실은 증가하는 것을 확인할 수 있습니다. 반면 Discriminator는 잘 학습되는 것으로 볼 수 있습니다.
Author And Source
이 문제에 관하여(실전 - CycleGAN 학습 파이프라인 구성하기), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://velog.io/@tjdtnsu/PyTorch-실전-CycleGAN-학습-파이프라인-구성하기저자 귀속: 원작자 정보가 원작자 URL에 포함되어 있으며 저작권은 원작자 소유입니다.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)