소백학Tensorflow의 Logistic 컴백
4728 단어 TensorFlowTensorflow
Tensorflow를 이용하여Logistic 회귀 1위를 실현하고 우리는 먼저 두 개의 함수를 설계하여 후속 프로그램에서 같은 코드를 반복적으로 작성하지 않도록 한다.
def init_weights(shape):
return tf.Variable(tf.random_normal(shape, stddev = 0.01))
def model(X, w):
return tf.matmul(X, w)
둘째, 우리는 mnist의 데이터 집합을 가져왔고 구체적인 방법은 홈페이지를 참고할 수 있다.
#
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
셋째, 손실 함수를 구축하고softmax와 교차 엔트로피를 이용하여 모델을 훈련한다.
# , softmax
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(py_x, Y))
learning_rate = 0.01
train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
전체 코드는 다음과 같습니다.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
import input_data
def init_weights(shape):
return tf.Variable(tf.random_normal(shape, stddev = 0.01))
def model(X, w):
return tf.matmul(X, w)
#
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#
X = tf.placeholder("float", [None, 784])
Y = tf.placeholder("float", [None, 10])
# w = init_weights([784, 10])
# py_x = model(X, w)
# , softmax
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(py_x, Y))
learning_rate = 0.01
train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
predict_op = tf.argmax(py_x, 1)
with tf.Session() as sess:
init = tf.initialize_all_variables()
sess.run(init)
for i in xrange(100):
for start, end in zip(range(0, len(trX), 128), range(128, len(trX), 128)):
sess.run(train_op, feed_dict = {X: trX[start:end], Y: trY[start:end]})
print i, np.mean(np.argmax(teY, axis = 1) == sess.run(predict_op, feed_dict = {X: teX, Y: teY}))
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
EMNIST에서 알파벳 필기 인식EMNIST-letters를 배웠습니다. CODE: DEMO: — mbotsu (@mb_otsu) 은 2017년에 NIST가 공개한 데이터세트입니다. EMNIST ByClass: 814,255 characters. ...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.