빠른 스타일 이동 원본 분석

9793 단어 풍격이 변하다.

빠른 스타일 이동 원본 분석


간단한 소개


스타일 이동의 원리는 관련 논문을 참고할 수 있고 틈이 나면 원리에 대한 소개 코드 출처를 한 편 쓴다.https://github.com/lengstrom/fast-style-transfer

style.py 해석


style.py는 새로운 스타일을 훈련할 때 입구 파일로 그 기능은 주로 파라미터를 수신하고 훈련 함수를 호출하는 것이다.

변수 수신

def build_parser():
    parser = ArgumentParser()
    parser.add_argument('--checkpoint-dir', type=str,
                        dest='checkpoint_dir', help='dir to save checkpoint in',
                        metavar='CHECKPOINT_DIR', required=True)

    parser.add_argument('--style', type=str,
                        dest='style', help='style image path',
                        metavar='STYLE', required=True)

    parser.add_argument('--train-path', type=str,
                        dest='train_path', help='path to training images folder',
                        metavar='TRAIN_PATH', default=TRAIN_PATH)

    parser.add_argument('--test', type=str,
                        dest='test', help='test image path',
                        metavar='TEST', default=False)

    parser.add_argument('--test-dir', type=str,
                        dest='test_dir', help='test image save dir',
                        metavar='TEST_DIR', default=False)

    parser.add_argument('--slow', dest='slow', action='store_true',
                        help='gatys\' approach (for debugging, not supported)',
                        default=False)

    parser.add_argument('--epochs', type=int,
                        dest='epochs', help='num epochs',
                        metavar='EPOCHS', default=NUM_EPOCHS)

    parser.add_argument('--batch-size', type=int,
                        dest='batch_size', help='batch size',
                        metavar='BATCH_SIZE', default=BATCH_SIZE)

    parser.add_argument('--checkpoint-iterations', type=int,
                        dest='checkpoint_iterations', help='checkpoint frequency',
                        metavar='CHECKPOINT_ITERATIONS',
                        default=CHECKPOINT_ITERATIONS)

    parser.add_argument('--vgg-path', type=str,
                        dest='vgg_path',
                        help='path to VGG19 network (default %(default)s)',
                        metavar='VGG_PATH', default=VGG_PATH)

    parser.add_argument('--content-weight', type=float,
                        dest='content_weight',
                        help='content weight (default %(default)s)',
                        metavar='CONTENT_WEIGHT', default=CONTENT_WEIGHT)

    parser.add_argument('--style-weight', type=float,
                        dest='style_weight',
                        help='style weight (default %(default)s)',
                        metavar='STYLE_WEIGHT', default=STYLE_WEIGHT)

    parser.add_argument('--tv-weight', type=float,
                        dest='tv_weight',
                        help='total variation regularization weight (default %(default)s)',
                        metavar='TV_WEIGHT', default=TV_WEIGHT)

    parser.add_argument('--learning-rate', type=float,
                        dest='learning_rate',
                        help='learning rate (default %(default)s)',
                        metavar='LEARNING_RATE', default=LEARNING_RATE)

    return parser

변수 이름
역할
checkpoint-dir
훈련된 모형을 저장하는 경로
style
스타일 맵의 경로
train-path
트레이닝 맵의 경로(COCO2017)
test
테스트 맵 경로
test-dir
테스트 맵 폴더
slow
gatys 사용 방법, debug 사용
epochs
epochs 수량
batch_size
batch_사이즈 싱글 카드 20 정도
checkpoint-iterations
몇 개의 step에 스냅샷을 저장합니까
vgg-path
vgg 모델 파일 경로
content-weight
내용이 무겁다
style-weight
풍격이 중요하다
tv-weight
총변분권이 무겁다?
learning-rate
학습률

함수 호출

  • get_img () 이 함수는 스타일맵을 가져오는 데 비교적 간단합니다.호출:
  • style_target = get_img(options.style)
    

    이루어지다
    def get_img(src, img_size=False):
       img = scipy.misc.imread(src, mode='RGB') # misc.imresize(, (256, 256, 3))
       if not (len(img.shape) == 3 and img.shape[2] == 3):
           img = np.dstack((img,img,img))
       if img_size != False:
           img = scipy.misc.imresize(img, img_size)
       return img
    
  • _get_files () 내용도 호출에 사용:
  • content_targets = _get_files(options.train_path)

    구현:
    def _get_files(img_dir):
        files = list_files(img_dir)
        return [os.path.join(img_dir,x) for x in files]
    
    def list_files(in_path):
        files = []
        for (dirpath, dirnames, filenames) in os.walk(in_path):
            files.extend(filenames)
            break
  • optimize()
  • 좋은 웹페이지 즐겨찾기