tensorflow2의 모델 저장 및 로드(save weights,save 및saved model.save)
8458 단어 TensorFlow2.×
1、save/load weights
네트워크의 매개 변수만 저장하고 다른 상태를 막론하고 이런 모델은 코드에 대한 명확한 인식에 적합하다
사용 절차는 다음과 같습니다.
# Save the weights
model.save_weights('./checkpoints/my_checkpoint') #
# Restore the weights
model = create_model() #
model.load_weights('./checkpoints/my_checkpoint')
loss, acc = model.evaluate(test_images, test_labels) # accuracy
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
예:
network.save_weights('weights.ckpt')
print('saved weights.')
del network
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
network.load_weights('weights.ckpt')
network.evaluate(ds_val)
2、save/load entire model
이런 방법은 가장 간단하고 거칠며 모든 모델과 상태를 보존하여 완벽한 회복을 할 수 있다
사용법은 다음과 같습니다.
network.save('model.h5')
print('saved total model.')
del network
print('loaded model from file.')
network = tf.keras.models.load_model('model.h5', compile=False) #
network.evaluate(ds_val)
3、saved_model
모델의 저장 형식은pytorch의 ONNX와 대응한다. 즉, 훈련된 하나의 모델을 공장의 생산 환경에 맡길 때 이 모델을 사용자에게 직접 맡겨 배치할 수 있고 원본 코드나 관련 정보를 주지 않아도 된다. 이 모델에 포함된 모든 이런 정보를 포함한다.예를 들어python을 통해 쓴 원본 파일을 c++로 해석하고 읽을 수 있습니다.
사용법은 다음과 같습니다.
tf.saved_model.save(m, '/tmp/saved_model')
imported = tf.saved_model.load(path)
f = imported.signatures["serving_default"]
print(f(x = tf.ones([1, 28, 28, 3])))