Reproducing EfficientNet with Keras [Part 2: Training Module]
Tags: deep learning, cv, tensorflow, machine learning
from __future__ import absolute_import, division, print_function
import tensorflow as tf
import math
import EfficientNet_Model
import numpy as np
from Data_Channel import Data_Channel
import random
# Training configuration
BatchSize = 50        # images per training batch (also the Data_Channel pool size)
Resolution = 260      # input resolution (260 matches EfficientNet-B2)
EPOCHS = 3000
Save_n_Epoch = 50     # save a checkpoint every N epochs
ALL_images = 10000    # image count used to derive the number of steps per epoch
Save_DIR = "./ModelLog/"
# Despite the "_OH" suffix, these are integer class indices, not one-hot vectors.
# Each batch holds a single category, so its label array is constant.
Labels_OH = {"cloudy": np.ones(BatchSize), "haze": np.full((BatchSize,), 2),
             "rainy": np.full((BatchSize,), 3), "snow": np.full((BatchSize,), 4),
             "sunny": np.full((BatchSize,), 5), "thunder": np.zeros(BatchSize)}
# Labels for the validation pools, which hold 20 images each
Valid_OH = {"cloudy": np.ones(20), "haze": np.full((20,), 2),
            "rainy": np.full((20,), 3), "snow": np.full((20,), 4),
            "sunny": np.full((20,), 5), "thunder": np.zeros(20)}
# One data channel per weather category
DC_Dic = {"cloudy": Data_Channel(category="cloudy", pool_size=BatchSize, resolution=Resolution),
          "haze": Data_Channel(category="haze", pool_size=BatchSize, resolution=Resolution),
          "rainy": Data_Channel(category="rainy", pool_size=BatchSize, resolution=Resolution),
          "snow": Data_Channel(category="snow", pool_size=BatchSize, resolution=Resolution),
          "sunny": Data_Channel(category="sunny", pool_size=BatchSize, resolution=Resolution),
          "thunder": Data_Channel(category="thunder", pool_size=BatchSize, resolution=Resolution)}
DC_list = ["cloudy", "haze", "rainy", "snow", "sunny", "thunder"]
'''def process_features(features, data_augmentation):
    image_raw = features['image_raw'].numpy()
    image_tensor_list = []
    for image in image_raw:
        image_tensor = load_and_preprocess_image(image, data_augmentation=data_augmentation)
        image_tensor_list.append(image_tensor)
    images = tf.stack(image_tensor_list, axis=0)
    labels = features['label'].numpy()
    return images, labels'''
if __name__ == '__main__':
    # GPU settings: allocate GPU memory on demand instead of all at once
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

    # create model
    model = EfficientNet_Model.efficient_net_b2()

    # define loss and optimizer; labels must be integer class indices
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.RMSprop()
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    valid_loss = tf.keras.metrics.Mean(name='valid_loss')
    valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy')

    # @tf.function
    def train_step(image_batch, label_batch):
        # forward pass under the tape, then one optimizer update
        with tf.GradientTape() as tape:
            predictions = model(image_batch, training=True)
            loss = loss_object(y_true=label_batch, y_pred=predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(grads_and_vars=zip(gradients, model.trainable_variables))
        train_loss.update_state(values=loss)
        train_accuracy.update_state(y_true=label_batch, y_pred=predictions)

    # @tf.function
    def valid_step(image_batch, label_batch):
        predictions = model(image_batch, training=False)
        v_loss = loss_object(label_batch, predictions)
        valid_loss.update_state(values=v_loss)
        valid_accuracy.update_state(y_true=label_batch, y_pred=predictions)

    # start training
    steps_per_epoch = round(ALL_images / BatchSize)
    for epoch in range(EPOCHS):
        for step in range(steps_per_epoch):
            # each step trains on a fresh batch from one randomly chosen category
            Category = random.choice(DC_list)
            Channel_now = DC_Dic[Category]
            Channel_now.Renew_dataset()
            train_step(Channel_now.RF_pool, Labels_OH[Category])
            print("Epoch: {}/{}, step: {}/{}, loss: {:.5f}, accuracy: {:.5f}".format(
                epoch, EPOCHS, step, steps_per_epoch,
                train_loss.result().numpy(),
                train_accuracy.result().numpy()))

        # validate once per epoch, on the category used by the last training step
        Channel_now.Renew_Valid_ds()
        valid_step(Channel_now.Valid_pool, Valid_OH[Category])
        print("Epoch: {}/{}, train loss: {:.5f}, train accuracy: {:.5f}, "
              "valid loss: {:.5f}, valid accuracy: {:.5f}".format(
                  epoch, EPOCHS,
                  train_loss.result().numpy(),
                  train_accuracy.result().numpy(),
                  valid_loss.result().numpy(),
                  valid_accuracy.result().numpy()))
        train_loss.reset_states()
        train_accuracy.reset_states()
        valid_loss.reset_states()
        valid_accuracy.reset_states()

        # periodic checkpoint
        if epoch % Save_n_Epoch == 0:
            model.save_weights(filepath=Save_DIR+"epoch-{}".format(epoch), save_format='tf')

    # save final weights
    model.save_weights(filepath=Save_DIR+"model", save_format='tf')
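
The script passes plain class indices (0 to 5) as labels because it uses a sparse categorical cross-entropy loss. Below is a minimal pure-Python sketch of what that loss computes, assuming the model outputs softmax probabilities. The helper name `sparse_categorical_crossentropy`, the `CLASS_INDEX` mapping, and the probability values are illustrative only; the real Keras loss also handles numerical stability, which this sketch ignores.

```python
import math

def sparse_categorical_crossentropy(labels, probabilities):
    """Mean negative log-probability assigned to the true class of each sample."""
    losses = []
    for label, probs in zip(labels, probabilities):
        # label is an integer class index, used directly to pick a probability
        losses.append(-math.log(probs[label]))
    return sum(losses) / len(losses)

# Class indices as hard-coded in the script's label dictionaries
CLASS_INDEX = {"thunder": 0, "cloudy": 1, "haze": 2,
               "rainy": 3, "snow": 4, "sunny": 5}

# Two hypothetical "haze" samples with made-up softmax outputs over six classes
labels = [CLASS_INDEX["haze"]] * 2
probabilities = [
    [0.05, 0.05, 0.70, 0.10, 0.05, 0.05],
    [0.10, 0.10, 0.50, 0.10, 0.10, 0.10],
]

loss = sparse_categorical_crossentropy(labels, probabilities)
print("loss: {:.5f}".format(loss))
```

Since every batch in the training loop comes from a single category, all labels in a batch share one index, which is why the script can build each label array with a single `np.full` call.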