bert, albert의 빠른 훈련과 예측
얼마 전에 발표된 중국어 데이터 집합인chineseGLUE를 바탕으로 모든 임무를 4가지 유형으로 나눈다. 그것이 바로 텍스트 분류, 문장 대 판단, 실체 식별, 읽기와 이해이다.같은 종류는 코드를 공유할 수 있으며 위의 네 가지 작업 외에 리딩 to rank를 추가했다.pair wise 방식을 바탕으로 하는 작업이다. 코드는 다음과 같다.https://github.com/jiangxinyang227/bert-for-task.
구체적인 사용은readme 참조
모델이 모든 항목에 정의된 모델입니다.py 파일에서bert와albert의 원본 코드를 직접 호출합니다.py는 예비 훈련 모델을 도입하여 예비 훈련 모델을 encoder 부분으로 할 수도 있고 embedding 층으로만 할 수도 있다. 그리고 스스로 encoder 부분을 정의할 수도 있다. 어쨌든 하류 임무 네트워크 층에 쉽게 접속할 수 있다. 특히 예비 훈련 모델을 embedding 층으로 사용하고 싶을 때 우리는 스스로 encoder 부분을 필요로 한다.
bert_config = modeling.BertConfig.from_json_file(self.__bert_config_path)
model = modeling.BertModel(config=bert_config,
is_training=self.__is_training,
input_ids=self.input_ids,
input_mask=self.input_masks,
token_type_ids=self.segment_ids,
use_one_hot_embeddings=False)
output_layer = model.get_pooled_output()
hidden_size = output_layer.shape[-1].value
if self.__is_training:
# I.e., 0.1 dropout
output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
with tf.name_scope("output"):
output_weights = tf.get_variable(
"output_weights", [self.__num_classes, hidden_size],
initializer=tf.truncated_normal_initializer(stddev=0.02))
output_bias = tf.get_variable(
"output_bias", [self.__num_classes], initializer=tf.zeros_initializer())
logits = tf.matmul(output_layer, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
self.predictions = tf.argmax(logits, axis=-1, name="predictions")
훈련할 때 예비 훈련의 매개 변수 값을 불러와서 예비 훈련 모델의 변수를 초기화합니다. 구체적으로trainer에 있습니다.py 파일
tvars = tf.trainable_variables()
(assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
tvars, self.__bert_checkpoint_path)
print("init bert model params")
tf.train.init_from_checkpoint(self.__bert_checkpoint_path, assignment_map)
print("init bert model params done")
sess.run(tf.variables_initializer(tf.global_variables()))
예측할 때predict를 직접 실례화할 수 있다.py 파일 중의predictor 클래스는 checkpoint 모델 파일을 불러오고 클래스 중의predict 방법을 호출하면 예측할 수 있으며 모델 코드 암호화, 모델 최적화 등을 고려하지 않아도 직접 온라인으로 배치할 수 있다.
import json
from predict import Predictor
with open("config/tnews_config.json", "r") as fr:
config = json.load(fr)
predictor = Predictor(config)
text = " 20 “ ” ?"
res = predictor.predict(text)
print(res)
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
다양한 언어의 JSONJSON은 Javascript 표기법을 사용하여 데이터 구조를 레이아웃하는 데이터 형식입니다. 그러나 Javascript가 코드에서 이러한 구조를 나타낼 수 있는 유일한 언어는 아닙니다. 저는 일반적으로 '객체'{}...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.