fashion_mnist 식별

fashion_mnist 식별
relu 대신 relu6를 사용하면 더 좋은 식별률(데이터 세트 10회 순환, 정확도: 75%–85%)을 얻을 수 있습니다.
categorical_crossentropy(교차 엔트로피)를 사용하여 식별 정확도를 크게 향상시켰다.
save_weights로 가중치를 저장하고, load_weights로 가중치를 불러온다.
import os
import random
import shutil
import subprocess
import threading

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import Sequential,layers,datasets,optimizers
# Matplotlib font setup: use the generic sans-serif family and put SimHei
# first in its fallback list (presumably so CJK plot titles render — confirm).
plt.rcParams.update({
    'font.family': ['sans-serif'],
    'font.sans-serif': ['SimHei'],
})

# Training hyper-parameters.
Train_times = 11  # the training loop iterates range(1, Train_times): 10 passes over the data
lr = 1e-4         # learning rate handed to the Adam optimizer
def Data_set(x, y):
    """Map one (image, label) example to model-ready dtypes.

    Args:
        x: pixel tensor. Assumed to hold 8-bit values in [0, 255]
            (TODO confirm), so dividing by 255 rescales it to [0, 1].
        y: class label.

    Returns:
        ``(image, label)`` with the image as float32 and the label as int32.
    """
    return tf.cast(x, tf.float32) / 255., tf.cast(y, tf.int32)
# Load Fashion-MNIST: (train images, train labels), (test images, test labels).
(x, y), (t_x, t_y) = datasets.fashion_mnist.load_data()

# Training pipeline: reshuffle the full training set every epoch (tf.data
# reshuffles on each new iteration by default), then normalise and batch.
# Without the shuffle every epoch saw the examples in the identical order.
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.shuffle(len(x)).map(Data_set).batch(64)

# Test pipeline stays unbatched: the rest of the script pulls individual
# (image, label) pairs from it.
test_db = tf.data.Dataset.from_tensor_slices((t_x, t_y))
test_db = test_db.map(Data_set)

# MLP over flattened 28*28 images with relu6 hidden activations.
net = Sequential([
    layers.Dense(28 * 28, activation=tf.nn.relu6),
    layers.Dense(256, activation=tf.nn.relu6),
    layers.Dense(256, activation=tf.nn.relu6),
    # No activation: the loss is computed with from_logits=True, so this layer
    # must emit raw logits. A relu6 here clipped them to [0, 6], which
    # prevented negative logits and capped the achievable confidence.
    # Weight shapes are unchanged, so existing checkpoints still load.
    layers.Dense(10),
])
# Resume from an earlier run's checkpoint when one exists. This is
# best-effort (there is no "cpkt" on the first run), but report the reason
# instead of silently swallowing every error with a bare except.
try:
    net.load_weights("cpkt")
except Exception as err:  # top-level best-effort boundary; exact error type varies by TF/Keras version
    print("Could not load checkpoint 'cpkt'; training from scratch:", err)

# Build the log path portably. The old r"LOG\TEST" was only a subdirectory of
# LOG on Windows; elsewhere it became a sibling directory TensorBoard never saw.
path = os.path.join("LOG", "TEST")
# Start every run with an empty log directory. ignore_errors keeps this
# best-effort (e.g. the directory does not exist yet) without a bare except.
shutil.rmtree(path, ignore_errors=True)
summary = tf.summary.create_file_writer(path)

# `learning_rate` is the supported keyword; the `lr` alias is deprecated and
# rejected by newer Keras optimizers.
optimizer = optimizers.Adam(learning_rate=lr)
step = 0  # TensorBoard x-axis value for the accuracy summaries

# Launch TensorBoard as an independent process. Unlike os.system in a
# non-daemon thread, this does not block the script from exiting, and the
# argument list avoids going through a shell.
try:
    subprocess.Popen(["tensorboard", "--logdir", "LOG"])
except OSError as err:
    print("Could not start TensorBoard:", err)
# Number of random test examples used as a cheap accuracy probe at each log step.
EVAL_SAMPLES = 20
# Built once: shuffle() reshuffles on every new iterator (tf.data default),
# so each iter(eval_db) below yields a fresh random sample. The previous code
# reassigned test_db = test_db.shuffle(...) at every log step, stacking another
# shuffle layer onto the global pipeline each time.
eval_db = test_db.shuffle(10000).batch(EVAL_SAMPLES)

for sp in range(1, Train_times):
    for S, (x, y) in enumerate(db):
        x = tf.reshape(x, shape=(-1, 28 * 28))  # flatten images for the Dense stack
        with tf.GradientTape() as Tape:
            out = net(x)
            y_onehot = tf.one_hot(y, depth=10)
            # Average over the batch so the loss (and gradient) is a scalar
            # whose scale does not depend on the batch size.
            loss = tf.reduce_mean(
                tf.losses.categorical_crossentropy(y_onehot, out, from_logits=True))
        gradient = Tape.gradient(loss, net.trainable_variables)
        optimizer.apply_gradients(zip(gradient, net.trainable_variables))

        if S % 100 == 0:
            print(float(loss))
            # Score one random batch of EVAL_SAMPLES test images in a single forward pass.
            x_eval, y_eval = next(iter(eval_db))
            logits = net(tf.reshape(x_eval, shape=(-1, 28 * 28)))
            pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
            ok_num = int(tf.reduce_sum(tf.cast(tf.equal(pred, y_eval), tf.int32)))
            accuracy = ok_num / EVAL_SAMPLES * 100
            with summary.as_default():
                tf.summary.scalar('test_accuracy_20', accuracy, step=step)
            step += 100
            print(sp, ":", S, "   Test 20:", accuracy)

# Persist the weights (reloaded at start-up) and export a SavedModel.
net.save_weights("cpkt")
tf.saved_model.save(net, 'model')
# Show 20 random test images, each followed by a plot comparing the predicted
# class probabilities (blue) with the one-hot ground truth (red, dashed).
xsize = 8
ysize = 5
n = xsize * ysize

# Fashion-MNIST class names, indexed by label. The previous list had been
# garbled into blank strings, leaving every subplot title empty.
lable = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
         'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
classes = [str(k) for k in range(10)]  # x-axis ticks '0'..'9'

# Shuffle explicitly so the figure shows random examples on every run.
samples = iter(test_db.shuffle(10000))
for i in range(1, n, 2):  # odd slots: image; the following even slot: its plot
    image, label = next(samples)
    plt.subplot(ysize, xsize, i)
    plt.imshow(image.numpy(), cmap=plt.cm.gray)
    plt.title(lable[int(label)])

    logits = net(tf.reshape(image, shape=(-1, 28 * 28)))
    # The network outputs logits (it is trained with from_logits=True), so apply
    # softmax before plotting against a 0/1 one-hot vector on the same scale.
    probs = tf.nn.softmax(logits, axis=-1).numpy()[-1]
    predicted = int(tf.argmax(logits, axis=-1).numpy()[-1])
    truth = tf.one_hot(label, depth=10).numpy()

    plt.subplot(ysize, xsize, i + 1)
    plt.grid(color='b', ls='-.', lw=0.25)
    plt.plot(classes, probs, color='b', linestyle='-', label='gess')
    plt.plot(classes, truth, color='r', linestyle='--', label='real')
    plt.title(lable[predicted])
plt.show()

좋은 웹페이지 즐겨찾기