Tensorflow 0.9.0에서 개발한 코드가 Tensorflow 1.0.0 버전으로 업그레이드된 경험
3815 단어 TensorFlow
1. tensorfow.models 모듈은 1.0.0 버전에서 단독으로 분리되어 seq2seq 를 호출합니다model.py,seq2seq_model_utils.py와 데이터utils.py 등 파일 인터페이스의 코드는 수정이 필요합니다.
관련 변경 사항은 다음과 같습니다.
1)from tensorflow를 삭제합니다.models.rnn.translate import data_utils 문장,tensorflow.models.rnn.translate의 데이터utils 코드를 복사해서 import을 가져옵니다.tensorflow.models 모듈에 대응하는 데이터utils 코드 링크는https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/data_utils.py 2)
def sampled_loss(inputs, labels):
labels = tf.reshape(labels, [-1, 1])
return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
self.target_vocab_size)
softmax_loss_function = sampled_loss
... 로 바꾸다
4
def sampled_loss(labels, inputs):
labels = tf.reshape(labels, [-1, 1])
return tf.nn.sampled_softmax_loss(w_t, b, labels, inputs, num_samples,
self.target_vocab_size)
softmax_loss_function = sampled_loss
원인은 Sampledsoftmax_loss 방법의 labels와 inputs 매개 변수가 위치를 바꿨습니다3)
single_cell = tf.nn.rnn_cell.GRUCell(size)
if use_lstm:
single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
cell = single_cell
if num_layers > 1:
cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
... 로 바꾸다
single_cell = tf.contrib.rnn.GRUCell(size)
if use_lstm:
single_cell = tf.contrib.rnn.BasicLSTMCell(size)
cell = single_cell
if num_layers > 1:
cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)
tf.nn.rnn_cell.*tf.nn.rnn.*의 대다수 함수(dynamic rnn과 raw rnn 제외)가 1.0 버전에서 tf로 일시적으로 이동합니다.contrib.nn에서 1.1 버전이 다시 이동됩니다.
4)
tf.nn.seq2seq.embedding_attention_seq2seq tf.contrib.legacy_seq2seq.embedding_attention_seq2seq
tf.nn.seq2seq.model_with_buckets tf.contrib.legacy_seq2seq.model_with_buckets
tf.nn.seq2seq.sequence_loss_by_example tf.contrib.legacy_seq2seq.sequence_loss_by_example
Tensorflow1.0.0버전 이후 새로운 seq2seq 인터페이스를 개발하여 tf.contrib.seq2seq에서 원래의 인터페이스를 버리고 낡은 인터페이스를 tf로 옮깁니다.contrib.legacy_새 인터페이스는 동적 전개이고 낡은 인터페이스는 정적 전개이다.
5)
tf.initialize_all_variables() tf.global_variables_initializer()
6) 0.12버전 tensorflow에서 checkpoint 버전을 업데이트하였습니다. 기본적으로 쓰고 읽는 checkpoint는 모두 새로운 V2버전으로 새 버전은restore 과정에서 피크 메모리를 현저히 낮출 수 있습니다.
두 가지 버전 모델은 다음과 같이 저장됩니다.
v1
v2
model.ckpt-66000
model.ckpt-66000.index
model.ckpt-66000.meta
model.ckpt-66000.meta
model.ckpt-66000.data-00000-of-00001
이후 릴리즈의 업데이트에 맞게 모델을 재훈련하여 V2 형식으로 저장합니다.
코드가 다음과 같이 변경되었습니다.
ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
model.saver.restore(session, ckpt.model_checkpoint_path)
else:
print("Created model with fresh parameters.")
session.run(tf.initialize_all_variables())
... 로 바꾸다
checkpoint_file = tf.train.latest_checkpoint(FLAGS.nn_model_dir)
if checkpoint_file is None:
print("Created model with fresh parameters.")
session.run(tf.global_variables_initializer())
else:
print("Reading model parameters from %s" % checkpoint_file)
model.saver.restore(session, checkpoint_file)
7)
tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_)) tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))
1.0.0 버전의 Tensorflow에서는 명명된 매개변수를 사용하여 호출해야 합니다.
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
EMNIST에서 알파벳 필기 인식EMNIST-letters를 배웠습니다. CODE: DEMO: — mbotsu (@mb_otsu) 은 2017년에 NIST가 공개한 데이터세트입니다. EMNIST ByClass: 814,255 characters. ...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.