Tensorflow: 트레이닝 모형 저장 및 복구

4864 단어 tensorflowAI

우선 명확하게 말하자면tensorflow는 무엇을 보존하고 있습니까?
모델이 저장되면 다음과 같은 네 개의 파일이 생성됩니다.
|--models
|    |--checkpoint
|    |--.meta
|    |--.data
|    |--.index

그 가운데.meta는 그림의 구조를 저장하고 checkpoint 파일은 텍스트 파일로 저장된 최신 checkpoint 파일과 기타 checkpoint 파일 목록을 기록합니다.데이터 및.index는 변수 값을 저장합니다.
즉, 모델이 저장하는 것은 그림의 구조와 변수 값이다.

하나의 실례


다음은 tensorflow를 사용하여 간단한 선형 모델을 구현하는 것입니다.
# 
x = np.random.randn(10000,1)
y = 0.03*x+0.8

# 
Weights = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='Weights')
bias = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='bias')


xx = tf.placeholder(tf.float32,shape=(None,1),name='xx')
yy = tf.placeholder(tf.float32,shape=(None,1),name='yy')

# 
y_predict = tf.add(Weights*xx,bias,name='preds')

# 
loss = tf.reduce_mean(tf.square(yy-y_predict))

# 
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

# 
batchsize = 20
samplesize = 100
start = 0
end = 0
with tf.Session() as sess:
    init_var = tf.global_variables_initializer()
    sess.run(init_var)
    print('before training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))
    for i in (range(5000)):
#         start = (i*batchsize)%100
        if end == samplesize:
            start = 0
        end = np.minimum(start+batchsize,samplesize)
#         try:
#             end = np.min(start+batchsize,samplesize)
#         except:
#             print(end)
        sess.run(optimizer,feed_dict={xx:x[start:end],yy:y[start:end]})
        if (i+1)%1000 == 0:
            print(sess.run(loss,feed_dict={xx:x[start:end],yy:y[start:end]}))
        start += batchsize
    print('after training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))

 

모델 저장


다음 절차에 따라 저장할 수 있습니다.
saver = tf.train.Saver()
saver.save(session,dir[,global_step])

save에서 첫 번째 파라미터는session이고 두 번째 파라미터는 모델이 저장된 위치이며 세 번째 파라미터는 모델이 교체할 때마다 몇 걸음씩 저장되는지 설명한다.
1의 모델을 저장하고 1000 단계마다 저장하도록 설정합니다.
# 
x = np.random.randn(10000,1)
y = 0.03*x+0.8

# 
Weights = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='Weights')
bias = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='bias')


xx = tf.placeholder(tf.float32,shape=(None,1),name='xx')
yy = tf.placeholder(tf.float32,shape=(None,1),name='yy')

# 
y_predict = tf.add(Weights*xx,bias,name='preds')

# 
loss = tf.reduce_mean(tf.square(yy-y_predict))

# 
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

# 
batchsize = 20
samplesize = 100
start = 0
end = 0
with tf.Session() as sess:
    init_var = tf.global_variables_initializer()
    sess.run(init_var)
    saver = tf.train.Saver()

    print('before training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))
    for i in (range(5000)):
#         start = (i*batchsize)%100
        if end == samplesize:
            start = 0
        end = np.minimum(start+batchsize,samplesize)
#         try:
#             end = np.min(start+batchsize,samplesize)
#         except:
#             print(end)
        sess.run(optimizer,feed_dict={xx:x[start:end],yy:y[start:end]})

        # 1000 
        if (i+1)%1000 == 0:
            saver.save(sess,'models\ckp',1000)
            print(sess.run(loss,feed_dict={xx:x[start:end],yy:y[start:end]}))
        start += batchsize
    print('after training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))

다음 코드는 1000걸음마다 모델을 저장합니다
if (i+1)%1000 == 0:
    saver.save(sess,'models\ckp',1000)

뜻밖의 경우(예를 들어 훈련 중 갑자기 전기가 끊기는 경우)를 막기 위해 다음 훈련은 처음부터 시작해야 한다.
저장된 디렉토리 구조는 다음과 같습니다.
|--models
|    |--checkpoint
|    |--ckp-1000.meta
|    |--ckp-1000.data-00000-of-00001
|    |--ckp-1000.index

 

삼모형 복구


저장된 meta 파일을 먼저 로드합니다.
saver = tf.train.import_meta_graph(file_name)

복구 매개 변수는session에 의존합니다.dir는 모델이 저장한 디렉터리 경로를 표시합니다. 이 때 모든 장량의 값은session에 있습니다
saver.restore(session,tf.train.latest_checkpoint(dir))

복구된 매개 변수를 가져옵니다. varname은 복구된 매개 변수의 이름을 표시하기 때문에 모든 매개 변수에name 속성을 추가하는 것을 권장합니다.
graph = sess.graph #sess , 
graph.get_tensor_by_name(varname)

다음은 회귀 모델의 회복을 제시하고 훈련된 모델을 이용하여 예측한다.
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('models\ckp-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('models'))
    graph = tf.get_default_graph()

    # 
    xx = graph.get_tensor_by_name('xx:0')


    # 
    preds = graph.get_tensor_by_name('preds:0')
    print('predict values:%s' % sess.run(preds,feed_dict={xx:x}))

 
 

 


좋은 웹페이지 즐겨찾기