tensorflow 예비 훈련권 다시 가져오기

15987 단어 tensorflow

tensorflow 예비 훈련권 다시 가져오기


1. 코드

def load(data_path, session):
    """
    load the VGG16_pretrain parameters file
    :param data_path:
    :param session:
    :return:
    """
    data_dict = np.load(data_path, encoding='latin1',allow_pickle=True).item()

    keys = sorted(data_dict.keys())
    for key in keys:
        with tf.variable_scope(key, reuse=True):
            for subkey, data in zip(('weights', 'biases'), data_dict[key]):
                session.run(tf.get_variable(subkey).assign(data))


def load_with_skip(data_path, session, skip_layer):
    """
    Only load some layer parameters
    :param data_path:
    :param session:
    :param skip_layer:
    :return:
    """
    data_dict = np.load(data_path, encoding='latin1',allow_pickle=True).item()

    for key in data_dict:
        if key not in skip_layer:
            with tf.variable_scope(key, reuse=True):
                for subkey, data in zip(('weights', 'biases'), data_dict[key]):
                    session.run(tf.get_variable(subkey).assign(data))

이것은 검증으로 한 장의 그림을 입력하여 vgg16의 출력 종류를 판단합니다
import tensorflow as tf
import VGG16 as vgg
from PIL import Image

data_path ='/opt/..../vgg16.npy'

input_maps = tf.placeholder(tf.float32, [None, 224, 224, 3])
prediction,_ = vgg.inference_op(input_maps,1.0)

image = Image.open('weasel.png')
image = image.convert('RGB')
image = image.resize((224,224))
img_raw = image.tobytes()
image = tf.reshape(tf.decode_raw(img_raw,out_type=tf.uint8),[1,224,224,3])
image = tf.cast(image, tf.float32)

# image = tf.read_file('cat.jpg')
# image = tf.image.decode_jpg(image)
# image = tf.image.convert_image_dtype(image,dtype=tf.float32)
# image = tf.image.resize_images(image, size=[224,224])
# image = tf.reshape(image,[1,224,224,3])

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    # vgg16.npy 
    vgg.load(data_path, sess)
    image = sess.run(image)
    test_prediction = sess.run([prediction],feed_dict={input_maps:image})
    print(test_prediction)

2. 해석


이전에 네트워크 구조를 정의할 때tf.get_variable()를 사용하여weights와biases를 정의하고 이름은 vgg16입니다.npy의 이름은 상대적이다.변수 이름 공간은 tf.name_scope() 또는 tf.variable_scope() 을 통과할 수 있지만 사용 방법은 다음과 같습니다.
with tf.variable_scope(name):
  kernel = tf.get_variable('weights',shape=[kh,kw,n_input,n_out], dtype=tf.float32,
                   initializer=tf.contrib.layers.xavier_initializer_conv2d())
with tf.name_scope(name) as scope:
  kernel = tf.get_variable(scope+'weights', shape=[n_input, n_out], dtype=tf.float32,
                   initializer=tf.contrib.layers.xavier_initializer_conv2d())
# load 
with tf.variable_scope(key, reuse=True):
                for subkey, data in zip(('weights', 'biases'), data_dict[key]):
                    session.run(tf.get_variable(subkey).assign(data))
tf.name_scope 생성된 변수의 이름에만 영향을 미칠 수 있고 tf.Variabel() 생성된 변수의 이름에는 영향을 미치지 않기 때문이다.또한 tf.get_variabel() 생성된 변수만 공유할 수 있습니다.따라서 Reuse가 True로 설정된 후에 모델이 불러올 때 tf.get_variabel() 협조tf.get_variabel()를 사용하여 이미 훈련된 변수 파라미터를 불러올 수 있습니다.
참조:https://www.jianshu.com/p/14662e980fc0

좋은 웹페이지 즐겨찾기