tf.keras 입문(5)모델 저장 및 복원
30476 단어 TensorFlowTensorFlow 학습 노트
모델 진 도 는 훈련 기간 과 그 후에 보존 할 수 있다.이것 은 당신 이 지난번 에 중 단 된 곳 에서 훈련 모델 을 계속 해서 훈련 시간 이 너무 길 지 않도록 할 수 있다 는 것 을 의미한다.또한 저장 할 수 있다 는 것 은 모델 을 공유 할 수 있 고 다른 사람 이 귀하 의 업무 성 과 를 재 창작 할 수 있다 는 것 을 의미 합 니 다.연구 모델 과 관련 기술 을 발표 할 때 대부분의 기계 학습 종사자 들 은 다음 과 같은 내용 을 공유 합 니 다.
다음은 tf.keras 를 사용 하여 모델 을 저장 하고 복원 하 는 방법 을 소개 합 니 다.다른 방법 을 알 아 보 려 면 TensorFlow 저장 및 복구 안내 서 를 참조 하거나 Eager 에 저장 할 수 있 습 니 다.
네트워크 구조
이 절 은 keras 의 저장,복구 모델 을 보 여 주 는 방법 일 뿐 입 니 다.전 접속 층 을 사용 합 니 다.
인터페이스 해석
times=1
checkpoint_path = "training_{}/cp.ckpt".format(times) #
checkpoint_dir = os.path.dirname(checkpoint_path) #
# Create checkpoint callback
cp_callback = keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True,
verbose=1)
model.fit(train_images,train_labels, epochs =10 ,
validation_data = (test_images, test_labels),
callbacks = [cp_callback])
# TensorFlow , :
model.load_weights(checkpoint_path)
# , 5 :
# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt" #
checkpoint_dir = os.path.dirname(checkpoint_path) #
# Create checkpoint callback
cp_callback = keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True,
verbose=1,
period=5)# Save weights, every 5-epochs.
model = create_model()
# model.fit(train_images,train_labels, epochs =50 ,
# validation_data = (test_images, test_labels),
# callbacks = [cp_callback],verbose=0)
# Sort the checkpoints by modification time.
checkpoints = pathlib.Path(checkpoint_dir).glob("*.index")
checkpoints = sorted(checkpoints, key=lambda cp:cp.stat().st_mtime)
checkpoints = [cp.with_suffix('') for cp in checkpoints]
latest = str(checkpoints[-1])
# print(checkpoints,'
',latest)
# TensorFlow 5 。
print(latest)
model = create_model()
model.load_weights(latest) #
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
# save_weights
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')
전체 모델 저장
전체 모델 은 하나의 파일 에 저장 할 수 있 습 니 다.그 중에서 가중치,모델 설정,심지어 최적화 기 설정 을 포함 합 니 다.이렇게 하면 모델 에 검사 점 을 설정 하고 나중에 똑 같은 상태 에서 훈련 을 계속 할 수 있 으 며 원본 코드 에 접근 하지 않 아 도 된다.Keras 에 정상적으로 사용 할 수 있 는 모델 을 저장 하 는 것 은 매우 유용 합 니 다.그러면 TensorFlow.js 에서 불 러 온 다음 웹 브 라 우 저 에서 훈련 하고 실행 할 수 있 습 니 다.
Keras 는 HDF 5 표준 을 사용 하여 기본 저장 형식 을 제공 합 니 다.저 장 된 모델 을 바 이 너 리 blob 로 볼 수 있 습 니 다.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
model.save('my_model.h5')
Recreate the exact same model, including weights and optimizer:
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
# ( )
# Keras 。 , TensorFlow ( tf.train)。
작은 매듭
keras.optimizers.Adam()
와tf.train.AdamOptimizer()
두 가지 최적화 기의 차이import os,pathlib
import tensorflow as tf
from tensorflow import keras
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
# -1
print(train_images.shape)
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
def create_model():
model = keras.models.Sequential([
keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(28*28,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10,activation=tf.nn.softmax)
])
model.compile(optimizer=keras.optimizers.Adam(), # tf.train.AdamOptimizer(), tf
# tf load_weights keras (⊙o⊙)…
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
model = create_model()
model.summary()
times=1
checkpoint_path = "training_{}/cp.ckpt".format(times) #
checkpoint_dir = os.path.dirname(checkpoint_path) #
# Create checkpoint callback
cp_callback = keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True,
verbose=1)
# model.fit(train_images,train_labels, epochs =10 ,
# validation_data = (test_images, test_labels),
# callbacks = [cp_callback])
# TensorFlow , :
# , , 。
# ( 10%):
model = create_model()
loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))
# , :
model.load_weights(checkpoint_path)
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
# epoch
# , 5 :
# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt" #
checkpoint_dir = os.path.dirname(checkpoint_path) #
# Create checkpoint callback
cp_callback = keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True,
verbose=1,
period=5)# Save weights, every 5-epochs.
model = create_model()
# model.fit(train_images,train_labels, epochs =50 ,
# validation_data = (test_images, test_labels),
# callbacks = [cp_callback],verbose=0)
# Sort the checkpoints by modification time.
checkpoints = pathlib.Path(checkpoint_dir).glob("*.index")
checkpoints = sorted(checkpoints, key=lambda cp:cp.stat().st_mtime)
checkpoints = [cp.with_suffix('') for cp in checkpoints]
latest = str(checkpoints[-1])
# print(checkpoints,'
',latest)
# TensorFlow 5 。
print(latest)
model = create_model()
model.load_weights(latest) #
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
# save_weights
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
'''
'''
# , 、
# , , , 。
# Keras ,
# TensorFlow.js , 。
# Keras HDF5 。 , blob。
model = create_model()
model.fit(train_images, train_labels, epochs=5)
model.save('my_model.h5')
# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
# ( )
# Keras 。 , TensorFlow ( tf.train)。
# https://www.tensorflow.org/guide/keras
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
EMNIST에서 알파벳 필기 인식EMNIST-letters를 배웠습니다. CODE: DEMO: — mbotsu (@mb_otsu) 은 2017년에 NIST가 공개한 데이터세트입니다. EMNIST ByClass: 814,255 characters. ...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.