RNN은 손으로 쓴 숫자의 실제 식별을 실현한다
2334 단어 tensorflowGoogle실전
import tensorflow as tf
#
from tensorflow.examples.tutorials.mnist import input_data
tf.reset_default_graph()
mnist=input_data.read_data_sets("./temp/data",one_hot=True)
trainings,trainlabels,testimgs,testlabels \
=mnist.train.images,mnist.train.labels,mnist.test.images,mnist.test.labels
ntrain,ntest,dim,nclasses =trainings.shape[0],testimgs.shape[0],trainings.shape[1], \
trainlabels.shape[1]
print(ntrain,ntest,dim,nclasses)
# 55000 10000 784 10
#
learing_rate = 0.001
batch_size =100
n_steps = 28
n_inputs = 28
n_hidden_units = 128
n_classes = 10
X_holder = tf.placeholder(tf.float32)
Y_holder = tf.placeholder(tf.float32)
def RNN(X_holder):
reshape_X = tf.reshape(X_holder, [-1, n_steps, n_inputs])
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units)
outputs, states = tf.nn.dynamic_rnn(lstm_cell, reshape_X, dtype=tf.float32)
cell_list = tf.unstack(tf.transpose(outputs, [1, 0, 2]))
last_cell = cell_list[-1]
Weights = tf.Variable(tf.truncated_normal([n_hidden_units, n_classes]))
biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))
predict_Y = tf.matmul(last_cell, Weights) + biases
return predict_Y
predict_Y = RNN(X_holder)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict_Y, labels=Y_holder))
optimizer = tf.train.AdamOptimizer(learing_rate)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)
isCorrect = tf.equal(tf.argmax(predict_Y, 1), tf.argmax(Y_holder, 1))
accuracy = tf.reduce_mean(tf.cast(isCorrect, tf.float32))
display_step=1
epochs=20
print("Start optimization")
for epoch in range(epochs):
avg_cost=0
# total_batch=int(mnist.train.num_examples/batch_size)
total_batch=100
for i in range(total_batch):
X, Y = mnist.train.next_batch(batch_size)
if epoch%display_step==0:
print("Epoch:%03d/%03d: "%(epoch,epochs))
train_acc=session.run(accuracy,feed_dict={X_holder: X, Y_holder: Y})
print("Training accuracy:%.3f"%(train_acc))
feeds = {X_holder: testimgs, Y_holder: testlabels}
test_acc = session.run(accuracy, feed_dict=feeds)
print("Testing accuracy:%.3f" % (test_acc))
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
Mediapipe를 사용한 맞춤형 인간 포즈 분류OpenCV의 도움으로 Mediapipe를 사용하여 사용자 지정 포즈 분류 만들기 Yoga Pose Dataset을 사용하여 사용자 정의 인간 포즈 분류를 생성하겠습니다. 1. 리포지토리 복제: 데이터세트 다운로드:...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.