tensorflow2의 모델 저장 및 로드(save weights,save 및saved model.save)

8458 단어 TensorFlow2.×

1、save/load weights


네트워크의 매개 변수만 저장하고 다른 상태를 막론하고 이런 모델은 코드에 대한 명확한 인식에 적합하다
사용 절차는 다음과 같습니다.
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')  #        

# Restore the weights
model = create_model()  #       
model.load_weights('./checkpoints/my_checkpoint')

loss, acc = model.evaluate(test_images, test_labels)  #   accuracy    
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

예:
network.save_weights('weights.ckpt')
print('saved weights.')
del network

network = Sequential([layers.Dense(256, activation='relu'),
                     layers.Dense(128, activation='relu'),
                     layers.Dense(64, activation='relu'),
                     layers.Dense(32, activation='relu'),
                     layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(lr=0.01),
		loss=tf.losses.CategoricalCrossentropy(from_logits=True),
		metrics=['accuracy']
	)
network.load_weights('weights.ckpt')
network.evaluate(ds_val)

2、save/load entire model


이런 방법은 가장 간단하고 거칠며 모든 모델과 상태를 보존하여 완벽한 회복을 할 수 있다
사용법은 다음과 같습니다.
network.save('model.h5')
print('saved total model.')
del network

print('loaded model from file.')
network = tf.keras.models.load_model('model.h5', compile=False)  #          

network.evaluate(ds_val)

3、saved_model


모델의 저장 형식은pytorch의 ONNX와 대응한다. 즉, 훈련된 하나의 모델을 공장의 생산 환경에 맡길 때 이 모델을 사용자에게 직접 맡겨 배치할 수 있고 원본 코드나 관련 정보를 주지 않아도 된다. 이 모델에 포함된 모든 이런 정보를 포함한다.예를 들어python을 통해 쓴 원본 파일을 c++로 해석하고 읽을 수 있습니다.
사용법은 다음과 같습니다.
tf.saved_model.save(m, '/tmp/saved_model')

imported = tf.saved_model.load(path)
f = imported.signatures["serving_default"]
print(f(x = tf.ones([1, 28, 28, 3])))

좋은 웹페이지 즐겨찾기