fashion_mnist 식별
26306 단어 경사도 가 떨어지다.기계 학습python 프로 그래 밍
relu 대신 relu 6 를 사용 하면 더 좋 은 식별 율(순환 데이터 세트 10 회,정확도:75%-85%)을 얻 을 수 있 습 니 다.
categorical_crossentropy(교차 엔트로피)의 사용 은 식별 의 정확성 을 크게 향상 시 켰 다.
save_weights 저장 가중치,load가중치
import tensorflow as tf
import shutil
import os
from tensorflow.keras import Sequential,layers,datasets,optimizers
import threading
import matplotlib.pyplot as plt
import random
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']
Train_times = 11
lr = 0.0001
def Data_set(x,y):
x = tf.cast(x,dtype=tf.float32)/255.
y = tf.cast(y,dtype=tf.int32)
return x,y
(x,y),(t_x,t_y)= datasets.fashion_mnist.load_data()
db = tf.data.Dataset.from_tensor_slices((x,y))
db = db.map(Data_set).batch(64)
test_db = tf.data.Dataset.from_tensor_slices((t_x,t_y))
test_db = test_db.map(Data_set)
net = Sequential([
layers.Dense(28*28,activation=tf.nn.relu6),
layers.Dense(256,activation=tf.nn.relu6),
layers.Dense(256,activation=tf.nn.relu6),
layers.Dense(10,activation=tf.nn.relu6)
])
try:
net.load_weights("cpkt")
except:pass
path=r"LOG\TEST"
try:
shutil.rmtree(path)
except:pass
summary = tf.summary.create_file_writer(path)
optimizer = optimizers.Adam(lr =lr)
step=0
threading.Thread(target=os.system,args=(r'tensorboard --logdir LOG',)).start()
for sp in range(1,Train_times):
for S,(x,y) in enumerate(db):
x = tf.reshape(x, shape=(-1, 28 * 28))
with tf.GradientTape() as Tape:
out = net(x)
y = tf.one_hot(y, depth=10)
loss2 = tf.losses.categorical_crossentropy(y,out,from_logits=True)
gradient = Tape.gradient(loss2,net.trainable_variables)
optimizer.apply_gradients(zip(gradient,net.trainable_variables))
if (S%100==0):
print(float(tf.reduce_mean(loss2)))
ok_num=0
test_db = test_db.shuffle(100000)
test = iter(test_db)
for i in range(20):
x,y = test.next()
x = tf.reshape(x,shape=(-1,28*28))
out = net(x)
out = int(tf.argmax(out,axis=-1))
y = int(y)
if(out==y):
ok_num+=1
with summary.as_default():
tf.summary.scalar(' 20 ',float(ok_num/20*100),step=step)
step += 100
print(sp,":",S," Test 20:",ok_num/20*100)
net.save_weights("cpkt")
tf.saved_model.save(net,'model')
xsize=8
ysize=5
n = xsize*ysize
test = iter(test_db)
lable=['T ',' ',' ',' ',' ',' ',' ',' ',' ',' ']
for i in range(1,n,2):
x,y = test.next()
plt.subplot(ysize,xsize,i)
plt.imshow(x,plt.cm.gray)
plt.title(lable[int(y)])
x = tf.reshape(x,shape=(-1,28*28))
result = net(x)
m_p = tf.argmax(result,axis=-1)
y = tf.one_hot([y],depth=10)
i+=1
plt.subplot(ysize, xsize, i)
plt.grid(color='b', ls = '-.', lw = 0.25)
plt.plot(['0','1','2','3','4','5','6','7','8','9'],result.numpy()[-1],color='b',linestyle='-',label='gess')
plt.plot(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], y.numpy()[-1],color='r',linestyle='--',label='real')
plt.title(lable[int(m_p.numpy()[-1])])
plt.show()
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
에러나 실행 완료를 LINE으로 통지 【Python】기계 학습 등을 하고 있으면 1개의 프로그램의 실행에 며칠 걸리는 것은 드물지 않습니다. 프로그램의 실행 상황이 걱정되어 몇 시간 간격으로 단말기를 열었다. 그런 날을 보내지 않았습니까? 그런 사람들을 위해 이번에는...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.