Python에서 주류 프레임워크인 TensorFlow를 이용하여 BP 신경망을 구축하고 함수 피팅(회귀)을 수행한다

2171 단어
주로 다른 사람들이 작성한 몇 가지 프로그램을 참고했는데, 핵심은 데이터를 정규화(normalization)하는 것, 특히 y 값을 정규화하는 것이었다. 그렇지 않으면 TensorFlow에서 훈련하는 동안 loss가 잘 내려가지 않는 것 같았다. (아래 코드에는 이 정규화가 적용되어 있지 않다.)
# -*- coding: utf-8 -*-
"""
Created on Thu Nov  8 11:56:47 2018

@author: Administrator
"""

# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Build a 1-D regression dataset: 1000 evenly spaced points on [-2, 2],
# stored as a column vector (shape (1000, 1)), with target
# y = 5x + sin(2*pi*x) + 2.
x = np.linspace(-2, 2, 1000)[:, np.newaxis]
# Use np.pi rather than a truncated 3.14 so the sine term has period exactly 1.
y = 5 * x + np.sin(2 * np.pi * x) + 2

# Shuffle the sample indices reproducibly, then take the first n_train
# shuffled indices as the training set and the remainder as the test set.
n_train = 800
np.random.seed(2019)
n_index = np.random.permutation(len(x))

train_x_disorder = x[n_index[:n_train]]
test_x_disorder = x[n_index[n_train:]]
train_y_disorder = y[n_index[:n_train]]
test_y_disorder = y[n_index[n_train:]]

# Names used by the model and training code.
train_X = train_x_disorder
train_Y = train_y_disorder


# Fix TensorFlow's graph-level seed so weight initialisation is reproducible.
tf.set_random_seed(2019)

# Network hyper-parameters.
n_hidden = 80         # width of the single hidden layer
learning_rate = 0.01  # step size for the Adam optimiser

# Inputs: a batch of scalar features and scalar targets, shape (batch, 1).
X = tf.placeholder("float", [None, 1])
Y = tf.placeholder("float", [None, 1])

# Parameters of a 1 -> n_hidden -> 1 fully connected network. Each variable
# gets its own graph name so the layers are distinguishable in graph
# inspection and checkpoints (TF would otherwise auto-suffix duplicates).
W1 = tf.Variable(tf.random_normal([1, n_hidden]), name="weight1")
b1 = tf.Variable(tf.ones([1, n_hidden]), name="bias1")
W3 = tf.Variable(tf.random_normal([n_hidden, 1]), name="weight2")
b3 = tf.Variable(tf.ones([1]), name="bias2")

# Forward pass: affine -> ReLU -> affine (linear output for regression).
z1 = tf.matmul(X, W1) + b1
z2 = tf.nn.relu(z1)
z5 = tf.matmul(z2, W3) + b3

# Mean squared error, minimised with Adam.
cost = tf.reduce_mean(tf.square(Y - z5))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Op that initialises every variable created so far in the default graph.
init = tf.global_variables_initializer()

# Training schedule. The loop performs training_epochs + 1 optimiser steps
# (epoch 0 through training_epochs inclusive) and records the training cost
# after every display_step-th step.
training_epochs = 10000
display_step = 100
loss_dis = []  # training cost sampled on display steps

# Open a session and initialise the variables; the session is not closed
# here, so later code can keep using it.
sess = tf.Session()
sess.run(init)

# Full-batch training: every step feeds the whole training set.
train_feed = {X: train_X, Y: train_Y}
for epoch in range(training_epochs + 1):
    sess.run(optimizer, feed_dict=train_feed)

    # Skip the extra cost evaluation except on display steps.
    if epoch % display_step:
        continue
    loss = sess.run(cost, feed_dict=train_feed)
    print("Epoch:", epoch, "cost=", loss)
    loss_dis.append(loss)
print(" Finish")


# Predictions of the trained network on the training inputs.
pre = sess.run(z5, feed_dict={X: train_X})

# Fit on the training set: predictions ('o') drawn over the true targets ('*').
# train_X is in shuffled order, so markers (not connected lines) are used.
plt.figure()
plt.plot(train_X, pre, 'o', label="prediction")
plt.plot(train_X, train_Y, '*', label="target")
plt.title("Fit on training set")
plt.legend()

# Per-sample residuals on the training set, in shuffled sample order.
error = pre - train_Y
plt.figure()
plt.plot(error)
plt.title("Training residuals (prediction - target)")

# Training cost recorded during the loop (one point per display step).
plt.figure()
plt.plot(loss_dis)
plt.title("Training cost")

# Generalisation check on the held-out split.
test_cost = sess.run(cost, feed_dict={X: test_x_disorder, Y: test_y_disorder})
print("Test cost=", test_cost)

# Without this, figures are never displayed when the file runs as a plain script.
plt.show()



좋은 웹페이지 즐겨찾기