손 으로 가르쳐 줄 게.TensorFlow 2 로 RNN 구현.
10839 단어 TensorFlow2RNN
RNN(Recurrent Netural Network)은 서열 데 이 터 를 처리 하 는 신경 망 이다.서열 데이터 란 앞의 입력 과 뒤의 입력 이 일정한 관 계 를 가 진 다 는 것 이다.
가중치 공유
전통 신경 망:
RNN:
RNN 의 가중치 공 유 는 CNN 의 가중치 공유 와 유사 하 며,시시각각 하나의 가중치 공 유 를 통 해 매개 변수 수 를 크게 줄 였 다.
계산 과정:
계산 상태(State)
계산 출력:
케이스
데이터 세트
IBIM 데이터 세트 는 인터넷 에서 온 50000 건의 영화 에 대한 평론 을 포함 하고 긍정 적 인 평가 와 부정적인 평가 로 나 뉜 다.
RNN 층
class RNN(tf.keras.Model):
def __init__(self, units):
super(RNN, self).__init__()
# [b, 64] (b batch_size)
self.state0 = [tf.zeros([batch_size, units])]
self.state1 = [tf.zeros([batch_size, units])]
# [b, 80] => [b, 80, 100]
self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)
self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
# [b, 80, 100] => [b, 64] => [b, 1]
self.out_layer = tf.keras.layers.Dense(1)
def call(self, inputs, training=None):
"""
:param inputs: [b, 80]
:param training:
:return:
"""
state0 = self.state0
state1 = self.state1
x = self.embedding(inputs)
for word in tf.unstack(x, axis=1):
out0, state0 = self.rnn_cell0(word, state0, training=training)
out1, state1 = self.rnn_cell1(out0, state1, training=training)
# [b, 64] -> [b, 1]
x = self.out_layer(out1)
prob = tf.sigmoid(x)
return prob
데이터 가 져 오기
def get_data():
#
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)
#
X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)
#
print(X_train.shape, y_train.shape) # (25000, 80) (25000,)
print(X_test.shape, y_test.shape) # (25000, 80) (25000,)
#
train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)
#
test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_db = test_db.batch(batch_size, drop_remainder=True)
return train_db, test_db
전체 코드
import tensorflow as tf
class RNN(tf.keras.Model):
def __init__(self, units):
super(RNN, self).__init__()
# [b, 64]
self.state0 = [tf.zeros([batch_size, units])]
self.state1 = [tf.zeros([batch_size, units])]
# [b, 80] => [b, 80, 100]
self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)
self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
# [b, 80, 100] => [b, 64] => [b, 1]
self.out_layer = tf.keras.layers.Dense(1)
def call(self, inputs, training=None):
"""
:param inputs: [b, 80]
:param training:
:return:
"""
state0 = self.state0
state1 = self.state1
x = self.embedding(inputs)
for word in tf.unstack(x, axis=1):
out0, state0 = self.rnn_cell0(word, state0, training=training)
out1, state1 = self.rnn_cell1(out0, state1, training=training)
# [b, 64] -> [b, 1]
x = self.out_layer(out1)
prob = tf.sigmoid(x)
return prob
#
total_words = 10000 #
max_review_len = 80 #
embedding_len = 100 #
batch_size = 1024 #
learning_rate = 0.0001 #
iteration_num = 20 #
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) #
loss = tf.losses.BinaryCrossentropy(from_logits=True) #
model = RNN(64)
# summary
model.build(input_shape=[None, 64])
print(model.summary())
#
model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])
def get_data():
#
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)
#
X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)
#
print(X_train.shape, y_train.shape) # (25000, 80) (25000,)
print(X_test.shape, y_test.shape) # (25000, 80) (25000,)
#
train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)
#
test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_db = test_db.batch(batch_size, drop_remainder=True)
return train_db, test_db
if __name__ == "__main__":
#
train_db, test_db = get_data()
#
model.fit(train_db, epochs=iteration_num, validation_data=test_db, validation_freq=1)
출력 결과:Model: "rnn"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) multiple 1000000
_________________________________________________________________
simple_rnn_cell (SimpleRNNCe multiple 10560
_________________________________________________________________
simple_rnn_cell_1 (SimpleRNN multiple 8256
_________________________________________________________________
dense (Dense) multiple 65
=================================================================
Total params: 1,018,881
Trainable params: 1,018,881
Non-trainable params: 0
_________________________________________________________________
None
(25000, 80) (25000,)
(25000, 80) (25000,)
Epoch 1/20
2021-07-10 17:59:45.150639: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
24/24 [==============================] - 12s 294ms/step - loss: 0.7113 - accuracy: 0.5033 - val_loss: 0.6968 - val_accuracy: 0.4994
Epoch 2/20
24/24 [==============================] - 7s 292ms/step - loss: 0.6951 - accuracy: 0.5005 - val_loss: 0.6939 - val_accuracy: 0.4994
Epoch 3/20
24/24 [==============================] - 7s 297ms/step - loss: 0.6937 - accuracy: 0.5000 - val_loss: 0.6935 - val_accuracy: 0.4994
Epoch 4/20
24/24 [==============================] - 8s 316ms/step - loss: 0.6934 - accuracy: 0.5001 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 5/20
24/24 [==============================] - 7s 301ms/step - loss: 0.6934 - accuracy: 0.4996 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 6/20
24/24 [==============================] - 8s 334ms/step - loss: 0.6932 - accuracy: 0.5000 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 7/20
24/24 [==============================] - 10s 398ms/step - loss: 0.6931 - accuracy: 0.5006 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 8/20
24/24 [==============================] - 9s 382ms/step - loss: 0.6930 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.4994
Epoch 9/20
24/24 [==============================] - 8s 322ms/step - loss: 0.6924 - accuracy: 0.4995 - val_loss: 0.6913 - val_accuracy: 0.5240
Epoch 10/20
24/24 [==============================] - 8s 321ms/step - loss: 0.6812 - accuracy: 0.5501 - val_loss: 0.6655 - val_accuracy: 0.5767
Epoch 11/20
24/24 [==============================] - 8s 318ms/step - loss: 0.6381 - accuracy: 0.6896 - val_loss: 0.6235 - val_accuracy: 0.7399
Epoch 12/20
24/24 [==============================] - 8s 323ms/step - loss: 0.6088 - accuracy: 0.7655 - val_loss: 0.6110 - val_accuracy: 0.7533
Epoch 13/20
24/24 [==============================] - 8s 321ms/step - loss: 0.5949 - accuracy: 0.7956 - val_loss: 0.6111 - val_accuracy: 0.7878
Epoch 14/20
24/24 [==============================] - 8s 324ms/step - loss: 0.5859 - accuracy: 0.8142 - val_loss: 0.5993 - val_accuracy: 0.7904
Epoch 15/20
24/24 [==============================] - 8s 330ms/step - loss: 0.5791 - accuracy: 0.8318 - val_loss: 0.5961 - val_accuracy: 0.7907
Epoch 16/20
24/24 [==============================] - 8s 340ms/step - loss: 0.5739 - accuracy: 0.8421 - val_loss: 0.5942 - val_accuracy: 0.7961
Epoch 17/20
24/24 [==============================] - 9s 378ms/step - loss: 0.5701 - accuracy: 0.8497 - val_loss: 0.5933 - val_accuracy: 0.8014
Epoch 18/20
24/24 [==============================] - 9s 361ms/step - loss: 0.5665 - accuracy: 0.8589 - val_loss: 0.5958 - val_accuracy: 0.8082
Epoch 19/20
24/24 [==============================] - 8s 353ms/step - loss: 0.5630 - accuracy: 0.8681 - val_loss: 0.5931 - val_accuracy: 0.7966
Epoch 20/20
24/24 [==============================] - 8s 314ms/step - loss: 0.5614 - accuracy: 0.8702 - val_loss: 0.5925 - val_accuracy: 0.7959
Process finished with exit code 0
여기 서 텐 서 플 로 우 2 를 이용 해 RNN 을 실현 하 는 법 을 알려 드 리 는 이 글 을 소개 합 니 다.텐 서 플 로 우 2 를 실현 하 는 RNN 에 관 한 더 많은 내용 은 저희 의 이전 글 을 검색 하거나 아래 의 관련 글 을 계속 읽 어 주시 기 바 랍 니 다.앞으로 도 많은 사랑 부 탁 드 리 겠 습 니 다!
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
손 으로 가르쳐 줄 게.TensorFlow 2 로 RNN 구현.RNN(Recurrent Netural Network)은 서열 데 이 터 를 처리 하 는 신경 망 이다.서열 데이터 란 앞의 입력 과 뒤의 입력 이 일정한 관 계 를 가 진 다 는 것 이다. RNN: RNN 의 가중...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.