빠른 스타일 이동 원본 분석
9793 단어 풍격이 변하다.
빠른 스타일 이동 원본 분석
간단한 소개
스타일 이동의 원리는 관련 논문을 참고할 수 있고 틈이 나면 원리에 대한 소개 코드 출처를 한 편 쓴다.https://github.com/lengstrom/fast-style-transfer
style.py 해석
style.py는 새로운 스타일을 훈련할 때 입구 파일로 그 기능은 주로 파라미터를 수신하고 훈련 함수를 호출하는 것이다.
변수 수신
def build_parser():
parser = ArgumentParser()
parser.add_argument('--checkpoint-dir', type=str,
dest='checkpoint_dir', help='dir to save checkpoint in',
metavar='CHECKPOINT_DIR', required=True)
parser.add_argument('--style', type=str,
dest='style', help='style image path',
metavar='STYLE', required=True)
parser.add_argument('--train-path', type=str,
dest='train_path', help='path to training images folder',
metavar='TRAIN_PATH', default=TRAIN_PATH)
parser.add_argument('--test', type=str,
dest='test', help='test image path',
metavar='TEST', default=False)
parser.add_argument('--test-dir', type=str,
dest='test_dir', help='test image save dir',
metavar='TEST_DIR', default=False)
parser.add_argument('--slow', dest='slow', action='store_true',
help='gatys\' approach (for debugging, not supported)',
default=False)
parser.add_argument('--epochs', type=int,
dest='epochs', help='num epochs',
metavar='EPOCHS', default=NUM_EPOCHS)
parser.add_argument('--batch-size', type=int,
dest='batch_size', help='batch size',
metavar='BATCH_SIZE', default=BATCH_SIZE)
parser.add_argument('--checkpoint-iterations', type=int,
dest='checkpoint_iterations', help='checkpoint frequency',
metavar='CHECKPOINT_ITERATIONS',
default=CHECKPOINT_ITERATIONS)
parser.add_argument('--vgg-path', type=str,
dest='vgg_path',
help='path to VGG19 network (default %(default)s)',
metavar='VGG_PATH', default=VGG_PATH)
parser.add_argument('--content-weight', type=float,
dest='content_weight',
help='content weight (default %(default)s)',
metavar='CONTENT_WEIGHT', default=CONTENT_WEIGHT)
parser.add_argument('--style-weight', type=float,
dest='style_weight',
help='style weight (default %(default)s)',
metavar='STYLE_WEIGHT', default=STYLE_WEIGHT)
parser.add_argument('--tv-weight', type=float,
dest='tv_weight',
help='total variation regularization weight (default %(default)s)',
metavar='TV_WEIGHT', default=TV_WEIGHT)
parser.add_argument('--learning-rate', type=float,
dest='learning_rate',
help='learning rate (default %(default)s)',
metavar='LEARNING_RATE', default=LEARNING_RATE)
return parser
변수 이름
역할
checkpoint-dir
훈련된 모형을 저장하는 경로
style
스타일 맵의 경로
train-path
트레이닝 맵의 경로(COCO2017)
test
테스트 맵 경로
test-dir
테스트 맵 폴더
slow
gatys 사용 방법, debug 사용
epochs
epochs 수량
batch_size
batch_사이즈 싱글 카드 20 정도
checkpoint-iterations
몇 개의 step에 스냅샷을 저장합니까
vgg-path
vgg 모델 파일 경로
content-weight
내용이 무겁다
style-weight
풍격이 중요하다
tv-weight
총변분권이 무겁다?
learning-rate
학습률
함수 호출
style_target = get_img(options.style)
이루어지다
def get_img(src, img_size=False):
img = scipy.misc.imread(src, mode='RGB') # misc.imresize(, (256, 256, 3))
if not (len(img.shape) == 3 and img.shape[2] == 3):
img = np.dstack((img,img,img))
if img_size != False:
img = scipy.misc.imresize(img, img_size)
return img
content_targets = _get_files(options.train_path)
구현:
def _get_files(img_dir):
files = list_files(img_dir)
return [os.path.join(img_dir,x) for x in files]
def list_files(in_path):
files = []
for (dirpath, dirnames, filenames) in os.walk(in_path):
files.extend(filenames)
break