# RNN으로 손글씨 숫자(MNIST) 인식 구현하기

import tensorflow as tf

# 
from tensorflow.examples.tutorials.mnist import input_data
# Start from an empty default graph before building the model.
tf.reset_default_graph()

# Load MNIST from ./temp/data with one-hot encoded labels.
mnist = input_data.read_data_sets("./temp/data", one_hot=True)

# Train / test splits as arrays.
trainings = mnist.train.images
trainlabels = mnist.train.labels
testimgs = mnist.test.images
testlabels = mnist.test.labels

# Dataset dimensions: example counts, flattened image length, label width.
ntrain = trainings.shape[0]
ntest = testimgs.shape[0]
dim = trainings.shape[1]
nclasses = trainlabels.shape[1]

print(ntrain, ntest, dim, nclasses)
# Expected output: 55000 10000 784 10
#


learing_rate = 0.001
batch_size =100
n_steps = 28
n_inputs = 28
n_hidden_units = 128
n_classes = 10
X_holder = tf.placeholder(tf.float32)
Y_holder = tf.placeholder(tf.float32)

def RNN(X_holder):
    """Build an LSTM classifier over image rows and return its logits.

    Args:
        X_holder: float32 tensor whose elements can be reshaped to
            [batch, n_steps, n_inputs] (e.g. flattened 28x28 images).

    Returns:
        Unscaled logits of shape [batch, n_classes]. No softmax is applied
        here; the caller feeds them to softmax_cross_entropy_with_logits.
    """
    # Treat each image as a sequence of n_steps rows of n_inputs values.
    sequences = tf.reshape(X_holder, [-1, n_steps, n_inputs])
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units)
    # Batch-major (time_major defaults to False):
    # outputs has shape [batch, n_steps, n_hidden_units].
    outputs, _ = tf.nn.dynamic_rnn(lstm_cell, sequences, dtype=tf.float32)
    # Only the final time step feeds the classifier. Slicing it directly is
    # equivalent to transpose + unstack + [-1], without creating n_steps
    # separate tensors.
    last_output = outputs[:, -1, :]
    Weights = tf.Variable(tf.truncated_normal([n_hidden_units, n_classes]))
    biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))
    return tf.matmul(last_output, Weights) + biases
# ---- Model output, loss and training op ------------------------------
predict_Y = RNN(X_holder)
# Per-example softmax cross-entropy between one-hot labels and logits,
# averaged over the batch.
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
    labels=Y_holder, logits=predict_Y)
loss = tf.reduce_mean(cross_entropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learing_rate)
train = optimizer.minimize(loss)

# ---- Evaluation ops --------------------------------------------------
# A prediction is correct when the highest logit matches the one-hot label.
isCorrect = tf.equal(tf.argmax(predict_Y, 1), tf.argmax(Y_holder, 1))
accuracy = tf.reduce_mean(tf.cast(isCorrect, tf.float32))

# ---- Session ---------------------------------------------------------
init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

display_step = 1  # report every `display_step` epochs
epochs = 20
print("Start optimization")
for epoch in range(epochs):
    avg_cost = 0
    # A full pass over the training set would be
    # int(mnist.train.num_examples / batch_size) batches; this run is capped
    # at 100 batches per epoch.
    total_batch = 100
    for _ in range(total_batch):
        X, Y = mnist.train.next_batch(batch_size)
        # One optimizer step on this mini-batch; also fetch its loss so the
        # epoch's mean cost can be reported.
        _, batch_cost = session.run(
            [train, loss], feed_dict={X_holder: X, Y_holder: Y})
        avg_cost += batch_cost / total_batch

    # Report once per epoch, after all of its batches have been trained on,
    # instead of re-evaluating the full test set after every batch.
    if epoch % display_step == 0:
        print("Epoch:%03d/%03d: "%(epoch,epochs))
        print("Average cost:%.9f" % avg_cost)

        # Training accuracy is measured on the last mini-batch only.
        train_acc = session.run(accuracy, feed_dict={X_holder: X, Y_holder: Y})
        print("Training accuracy:%.3f"%(train_acc))

        feeds = {X_holder: testimgs, Y_holder: testlabels}
        test_acc = session.run(accuracy, feed_dict=feeds)
        print("Testing accuracy:%.3f" % (test_acc))




# 좋은 웹페이지 즐겨찾기