Tensorflow Hub 및 Keras를 사용하여 NLP에서 전이 학습

Tensorflow 2.0은 Keras를 모델 구축을 위한 기본 고급 API로 도입했습니다. Tensorflow Hub의 사전 훈련된 모델과 결합하여 NLP에서 전이 학습을 위한 매우 간단한 방법을 제공하여 즉시 사용할 수 있는 좋은 모델을 생성합니다.



프로세스를 설명하기 위해 기사 제목이 클릭베이트인지 여부를 분류하는 예를 들어 보겠습니다.

데이터 준비



사용 가능한 'Stop Clickbait: Detecting and Preventing Clickbaits in Online News Media' 논문here의 데이터 세트를 사용합니다.

이 기사의 목표는 전이 학습을 설명하는 것이므로 이미 사전 처리된 데이터 세트를 pandas 데이터 프레임에 직접 로드합니다.

import pandas as pd
df = pd.read_csv('http://bit.ly/clickbait-data')


데이터 세트는 페이지 제목과 레이블로 구성됩니다. 제목이 클릭베이트인 경우 레이블은 1입니다.



데이터를 70% 훈련 데이터와 30% 검증 데이터로 나누겠습니다.

from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(df['title'], 
                                                    df['label'], 
                                                    test_size=0.3, 
                                                    stratify=df['label'], 
                                                    random_state=42)


모델 아키텍처



이제 pip를 사용하여 tensorflow와 tensorflow-hub를 설치합니다.

pip install tensorflow-hub
pip install tensorflow==2.1.0


텍스트 데이터를 모델의 기능으로 사용하려면 이를 숫자 형식으로 변환해야 합니다. Tensorflow Hub는 문장을 BERT, NNLM 및 Wikiwords와 같은 임베딩으로 변환하기 위한 다양한modules을 제공합니다.

Universal Sentence Encoder는 문장 임베딩을 생성하는 데 널리 사용되는 모듈 중 하나입니다. 텍스트에 대해 512 고정 크기 벡터를 반환합니다.
아래는 텐서플로우 허브를 사용하여 "Hello World"문장의 임베딩을 캡처하는 방법의 예입니다.



import tensorflow_hub as hub

encoder = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
encoder(['Hello World'])




Tensorflow 2.0에서는 새로운 hub.KerasLayer 모듈 덕분에 모델에 이러한 임베딩을 사용하는 것이 식은 죽 먹기입니다. 클릭베이트 감지의 이진 분류 작업을 위한 tf.keras 모델을 설계해 보겠습니다.

먼저 필요한 라이브러리를 가져옵니다.

import tensorflow as tf
import tensorflow_hub as hub


그런 다음 레이어를 캡슐화할 순차 모델을 만듭니다.

model = tf.keras.models.Sequential()


첫 번째 레이어는 tfhub.dev에서 사용 가능한 모델을 로드할 수 있는 hub.KerasLayer가 될 것입니다. Universal Sentence Encoder을 로드할 예정입니다.

model.add(hub.KerasLayer('https://tfhub.dev/google/universal-sentence-encoder/4', 
                        input_shape=[], 
                        dtype=tf.string, 
                        trainable=True))


사용된 다양한 매개변수의 의미는 다음과 같습니다.
  • /4 : Hub에 있는 Universal Sentence Encoder의 변형을 나타냅니다. 우리는 Deep Averaging Network (DAN) 변형을 사용하고 있습니다. Transformer architecture 및 기타variants도 있습니다.
  • input_shape=[] : 데이터에 특징이 없고 텍스트 자체만 있으므로 특징 차원이 비어 있습니다.
  • dtype=tf.string : 원시 텍스트 자체를 모델에 전달할 것이므로
  • trainable=True : USE를 미세 조정할지 여부를 나타냅니다. True로 설정하면 USE에 있는 임베딩이 다운스트림 작업에 따라 미세 조정됩니다.

  • 다음으로 단일 노드가 있는 Dense 레이어를 추가하여 0과 1 사이의 클릭베이트 확률을 출력합니다.

    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
    


    요약하면, 우리는 텍스트 데이터를 가져와 512차원 임베딩으로 투영하고 클릭 베이트 확률을 제공하기 위해 시그모이드 활성화가 있는 피드포워드 신경망을 통해 전달하는 모델을 가지고 있습니다.



    또는 tf.keras 기능 API를 사용하여 위의 정확한 아키텍처를 구현할 수도 있습니다.

    x = tf.keras.layers.Input(shape=[], dtype=tf.string)
    y = hub.KerasLayer('https://tfhub.dev/google/universal-sentence-encoder/4', 
                        trainable=True)(x)
    z = tf.keras.layers.Dense(1, activation='sigmoid')(y)
    model = tf.keras.models.Model(x, z)
    


    모델 요약의 출력은 다음과 같습니다.

    model.summary()
    




    Universal Sentence Encoder를 미세 조정하고 있기 때문에 훈련 가능한 매개변수의 수는 256,798,337입니다.

    모델 훈련



    이진 분류 작업을 수행하고 있으므로 ADAM 옵티마이저 및 정확도와 함께 이진 교차 엔트로피 손실을 메트릭으로 사용합니다.

    model.compile(optimizer='adam', 
                  loss='binary_crossentropy', 
                  metrics=['accuracy'])
    


    이제 모델을 훈련시켜 봅시다.

    model.fit(x_train, 
              y_train, 
              epochs=2, 
              validation_data=(x_test, y_test))
    


    우리는 단 2개의 에폭으로 99.62%의 훈련 정확도와 98.46%의 검증 정확도에 도달했습니다.

    추론



    몇 가지 예에서 모델을 테스트해 보겠습니다.

    # Clickbait
    >> model.predict(["21 Pictures That Will Make You Feel Like You're 99 Years Old"])
    array([[0.9997924]], dtype=float32)
    
    # Not Clickbait
    >> model.predict(['Google announces TensorFlow 2.0'])
    array([[0.00022611]], dtype=float32)
    


    결론



    따라서 Tensorflow Hub와 tf.keras의 조합으로 전이 학습을 쉽게 활용하고 모든 다운스트림 작업을 위한 고성능 모델을 구축할 수 있습니다.

    데이터 크레딧


    Abhijnan Chakraborty, Bhargavi Paranjape, Sourya Kakarla, and Niloy Ganguly. "Stop Clickbait: Detecting and Preventing Clickbaits in Online News Media”. In Proceedings of the 2016 IEEE/ACM International Conference on Advances in Social Networks Analysis and Mining (ASONAM), San Fransisco, US, August 2016

    연결하다



    이 블로그 게시물이 마음에 드셨다면 매주 새 블로그 게시물을 공유하는 곳에서 저와 연락해 주시기 바랍니다.

    좋은 웹페이지 즐겨찾기