Logistic Regression Classification on the MNIST Dataset with TensorFlow

Import the required packages:
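# Load the MNIST dataset
# Note: these are TensorFlow 1.x APIs; tensorflow.examples.tutorials was removed in TensorFlow 2.x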
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets('data/', one_hot=True)

Output:
Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
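With one_hot=True, each label is returned as a 10-element vector. The entry for the digit is 1 and every other entry is 0. The hypothetical helpers one_hot and from_one_hot below mimic that encoding in plain Python. They are an illustration only, not part of input_data.

```python
def one_hot(label, num_classes=10):
    """Return a list with 1.0 at position `label` and 0.0 everywhere else."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} out of range for {num_classes} classes")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def from_one_hot(vec):
    """Recover the integer label as the index of the largest entry."""
    return max(range(len(vec)), key=lambda i: vec[i])


print(one_hot(3))
print(from_one_hot(one_hot(7)))
```

Recovering the digit with an argmax is the same idea the accuracy computation later in the post relies on.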

Set the parameters:
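# Number of classes (digits 0-9)
numClasses = 10
# Input size: each image is 28*28*1 = 784 pixels
inputSize = 784
# 50,000 training iterations
trainingIterations = 50000
# Mini-batch size: 64 samples per iteration
batchSize = 64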
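A quick way to read these settings is to count how many passes over the training data they imply. This short calculation rests on one assumption: read_data_sets holds out 5,000 of the 60,000 MNIST training images for validation by default, which leaves 55,000 in mnist.train.

```python
# Values from the parameter cell above
trainingIterations = 50000
batchSize = 64
# Assumption: read_data_sets keeps 5,000 of the 60,000 MNIST training images
# as a validation split by default, leaving 55,000 in mnist.train
trainExamples = 55000

samplesSeen = trainingIterations * batchSize  # images fed to the optimizer in total
epochs = samplesSeen / trainExamples          # full passes over the training split
print(samplesSeen, round(epochs, 1))
```

The loop therefore feeds 3,200,000 images to the optimizer, or roughly 58 passes over the training split.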

Define the placeholders X and y and their shapes:
# None leaves the batch dimension open, so batches of any size can be fed
X = tf.placeholder(tf.float32, shape=[None, inputSize])
y = tf.placeholder(tf.float32, shape=[None, numClasses])
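The second dimension of X is 784 because each 28*28 image is unrolled row by row into a single vector. The None in the shape is the batch dimension: the same placeholder accepts 64 rows during training, or any other number of rows later. read_data_sets already returns images flattened this way by default. The plain-Python sketch below therefore only illustrates the layout, using hypothetical dummy images.

```python
def flatten(image):
    """Flatten a 28x28 image (a list of rows) into one 784-element row, row by row."""
    return [pixel for row in image for pixel in row]


# Hypothetical dummy images: all-zero 28x28 grids with one bright pixel on the diagonal
images = []
for k in range(3):
    img = [[0.0] * 28 for _ in range(28)]
    img[k][k] = 1.0
    images.append(img)

# A batch of shape [3, 784]; the 3 plays the role of None in the placeholder shape
batch = [flatten(img) for img in images]
print(len(batch), len(batch[0]))
```

Pixel (row r, column c) lands at index r*28 + c in the flattened row.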

Initialize the parameters:
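# W1: 784*10 weight matrix drawn from a normal distribution; stddev is the standard deviation (tf.random_normal defaults to 1.0)
W1 = tf.Variable(tf.random_normal([inputSize, numClasses], stddev=0.1))
# B1: one bias per class, initialized to 0.1 (the shape belongs to tf.constant, not to tf.Variable)
B1 = tf.Variable(tf.constant(0.1, shape=[numClasses]))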
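Two details of this initialization can be checked in plain Python. First, with one bias entry per class, W1 and B1 together hold 784*10 + 10 = 7,850 trainable numbers. Second, the per-class shape of the bias matters. Softmax returns the same probabilities when the same constant is added to every logit, so a single shared scalar bias could not change y_pred at all. The sketch below uses a hand-written softmax (not tf.nn.softmax) and made-up logits.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]                                        # made-up logits
shifted = [z + 0.1 for z in logits]                             # same scalar bias on every class
per_class = [z + b for z, b in zip(logits, [0.0, 0.5, 1.0])]    # a different bias per class

print(softmax(logits))
print(softmax(shifted))    # same probabilities as the line above, up to rounding
print(softmax(per_class))  # different probabilities: a per-class bias changes the output

num_params = 784 * 10 + 10  # W1 weights plus one bias per class
print(num_params)
```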

Build the model:
# Softmax regression: the logits X*W1 + B1 are converted into class probabilities
y_pred = tf.nn.softmax(tf.matmul(X, W1) + B1)

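# Loss: mean squared error between the one-hot labels and the predicted probabilities
loss = tf.reduce_mean(tf.square(y - y_pred))
# Plain gradient descent with a learning rate of 0.05
opt = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss)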

# A prediction is correct when the most probable class equals the labelled class
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
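The model computes a mean-squared-error loss and an argmax-based accuracy. Below is a plain-Python sketch of both, using hand-written helpers (not TensorFlow calls) and made-up numbers. The sketch also shows why logistic (softmax) regression is more often trained with cross-entropy. For a one-hot label, the squared error is bounded by 2/numClasses. Cross-entropy, by contrast, keeps growing as the probability assigned to the true class shrinks. Swapping the loss in the TensorFlow code would change the training results reported in this post; TF 1.x provides tf.nn.softmax_cross_entropy_with_logits for that purpose.

```python
import math


def mse(y_true, y_pred):
    """Mean of squared differences over one row, like tf.reduce_mean(tf.square(y - y_pred))."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def cross_entropy(y_true, y_pred, eps=1e-12):
    """Categorical cross-entropy -sum(y * log(p)); eps guards against log(0)."""
    return -sum(t * math.log(p + eps) for t, p in zip(y_true, y_pred))


def argmax(values):
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(labels, preds):
    """Fraction of rows whose predicted class equals the labelled class."""
    correct = [argmax(t) == argmax(p) for t, p in zip(labels, preds)]
    return sum(1.0 for c in correct if c) / len(correct)


# Made-up one-hot label and two made-up (wrong) softmax outputs for 3 classes
label = [0.0, 1.0, 0.0]
confident_wrong = [0.98, 0.01, 0.01]
unsure_wrong = [0.4, 0.3, 0.3]
for name, pred in [("confident_wrong", confident_wrong), ("unsure_wrong", unsure_wrong)]:
    print(name, "mse=%.4f" % mse(label, pred), "cross_entropy=%.4f" % cross_entropy(label, pred))

# Accuracy on a made-up batch of four rows: rows 1, 2 and 4 are predicted correctly
labels = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
preds = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.5, 0.3, 0.2], [0.2, 0.6, 0.2]]
print(accuracy(labels, preds))
```

Going from the unsure mistake to the confident one, the squared error rises from about 0.25 to about 0.65 and can never pass 2/3. The cross-entropy rises from about 1.2 to about 4.6.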

Run the training loop:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
for i in range(trainingIterations):
    # Fetch the next mini-batch of images and their one-hot labels
    batchInput, batchLabels = mnist.train.next_batch(batchSize)
    # One gradient descent step; also fetch the current loss
    _, trainingLoss = sess.run([opt, loss], feed_dict={X: batchInput, y: batchLabels})
    # Every 10,000 steps, report accuracy on the current training batch
    if i % 10000 == 0:
        train_accuracy = accuracy.eval(session=sess, feed_dict={X: batchInput, y: batchLabels})
        print("step %d, training accuracy %g" % (i, train_accuracy))

Output:
step 0, training accuracy 0.125
step 10000, training accuracy 0.90625
step 20000, training accuracy 0.875
step 30000, training accuracy 0.875
step 40000, training accuracy 0.890625
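Each sess.run([opt, loss], ...) call in the loop performs one gradient descent update, W <- W - learning_rate * dL/dW. TensorFlow derives the gradient automatically. Below is a minimal pure-Python sketch of that update for the same softmax + mean-squared-error model, with the chain rule written out by hand. Several choices are assumptions made for the toy example rather than values from the post: the tiny dataset, zero initialization (instead of the random normal init used above), batch_size=2, lr=0.5, and 100 epochs.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def forward(W, b, x):
    """Logits x*W + b followed by softmax, for a single example."""
    k = len(b)
    z = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(k)]
    return softmax(z)


def mse_loss(W, b, X, Y):
    """Mean squared error over every (example, class) entry, like reduce_mean."""
    total = 0.0
    for x, y in zip(X, Y):
        p = forward(W, b, x)
        total += sum((yj - pj) ** 2 for yj, pj in zip(y, p))
    return total / (len(X) * len(b))


def sgd_step(W, b, X, Y, lr):
    """One gradient descent update on the mini-batch (X, Y); modifies W and b in place."""
    n, k = len(X), len(b)
    gW = [[0.0] * k for _ in W]
    gb = [0.0] * k
    for x, y in zip(X, Y):
        p = forward(W, b, x)
        # dL/dp for the squared error averaged over n*k entries
        g = [2.0 * (pj - yj) / (n * k) for pj, yj in zip(p, y)]
        # Chain rule through softmax: dL/dz_j = p_j * (g_j - sum_k g_k * p_k)
        dot = sum(gj * pj for gj, pj in zip(g, p))
        dz = [pj * (gj - dot) for pj, gj in zip(p, g)]
        for i, xi in enumerate(x):
            for j in range(k):
                gW[i][j] += xi * dz[j]
        for j in range(k):
            gb[j] += dz[j]
    for i in range(len(W)):
        for j in range(k):
            W[i][j] -= lr * gW[i][j]
    for j in range(k):
        b[j] -= lr * gb[j]


def train(X, Y, num_classes, batch_size=2, lr=0.5, epochs=100):
    """Mini-batch gradient descent; returns W, b and the full-data loss after each epoch."""
    W = [[0.0] * num_classes for _ in X[0]]
    b = [0.0] * num_classes
    losses = [mse_loss(W, b, X, Y)]
    for _ in range(epochs):
        for start in range(0, len(X), batch_size):
            sgd_step(W, b, X[start:start + batch_size], Y[start:start + batch_size], lr)
        losses.append(mse_loss(W, b, X, Y))
    return W, b, losses


# Tiny hypothetical dataset: 2 features, 2 classes, one-hot labels
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.2], [0.1, 1.0]]
Y = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
W, b, losses = train(X, Y, num_classes=2)
print(round(losses[0], 4), round(losses[-1], 4))
```

With zero weights, every prediction starts at [0.5, 0.5], so the initial loss is exactly 0.25. It then falls as the updates push each class's weights toward its own feature.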

Evaluate on the test set (this uses a single batch of 64 test images, not the full test set):
batch = mnist.test.next_batch(batchSize)
testAccuracy = sess.run(accuracy, feed_dict={X: batch[0], y: batch[1]})
print("test accuracy %g"%(testAccuracy))

Output:
test accuracy 0.90625
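The test accuracy above is measured on a single batch of 64 images, so 0.90625 is exactly 58/64 and the estimate is noisy. If each image is treated as an independent trial, the binomial standard error sqrt(acc * (1 - acc) / n) gives a rough sense of that noise. The sketch below compares n = 64 with n = 10,000, the size of the MNIST test set. Feeding mnist.test.images and mnist.test.labels instead of one batch would give the lower-variance estimate.

```python
import math


def accuracy_std_error(acc, n):
    """Binomial standard error of an accuracy measured on n independent samples."""
    return math.sqrt(acc * (1.0 - acc) / n)


acc = 58 / 64  # 0.90625: 58 of the 64 test images classified correctly
for n in (64, 10000):
    print(n, round(accuracy_std_error(acc, n), 4))
```

For 64 images the standard error is about 0.036, roughly plus or minus 3.6 percentage points. For 10,000 images it drops to about 0.003.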
