bert, albert의 빠른 훈련과 예측

5685 단어
예훈련 모델이 점점 성숙해지면서 예훈련 모델도 업무에서 더 많이 사용될 것이다. 본고는bert와albert의 신속한 훈련과 배치를 제공했는데 실제로 현재의 예훈련 모델은 사용할 때 대체적으로 같다.
얼마 전에 발표된 중국어 데이터 집합인chineseGLUE를 바탕으로 모든 임무를 4가지 유형으로 나눈다. 그것이 바로 텍스트 분류, 문장 대 판단, 실체 식별, 읽기와 이해이다.같은 종류는 코드를 공유할 수 있으며 위의 네 가지 작업 외에 리딩 to rank를 추가했다.pair wise 방식을 바탕으로 하는 작업이다. 코드는 다음과 같다.https://github.com/jiangxinyang227/bert-for-task.
구체적인 사용은readme 참조
모델이 모든 항목에 정의된 모델입니다.py 파일에서bert와albert의 원본 코드를 직접 호출합니다.py는 예비 훈련 모델을 도입하여 예비 훈련 모델을 encoder 부분으로 할 수도 있고 embedding 층으로만 할 수도 있다. 그리고 스스로 encoder 부분을 정의할 수도 있다. 어쨌든 하류 임무 네트워크 층에 쉽게 접속할 수 있다. 특히 예비 훈련 모델을 embedding 층으로 사용하고 싶을 때 우리는 스스로 encoder 부분을 필요로 한다.
     bert_config = modeling.BertConfig.from_json_file(self.__bert_config_path)

        model = modeling.BertModel(config=bert_config,
                                   is_training=self.__is_training,
                                   input_ids=self.input_ids,
                                   input_mask=self.input_masks,
                                   token_type_ids=self.segment_ids,
                                   use_one_hot_embeddings=False)
        output_layer = model.get_pooled_output()

        hidden_size = output_layer.shape[-1].value
        if self.__is_training:
            # I.e., 0.1 dropout
            output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

        with tf.name_scope("output"):
            output_weights = tf.get_variable(
                "output_weights", [self.__num_classes, hidden_size],
                initializer=tf.truncated_normal_initializer(stddev=0.02))

            output_bias = tf.get_variable(
                "output_bias", [self.__num_classes], initializer=tf.zeros_initializer())

            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            self.predictions = tf.argmax(logits, axis=-1, name="predictions")

훈련할 때 예비 훈련의 매개 변수 값을 불러와서 예비 훈련 모델의 변수를 초기화합니다. 구체적으로trainer에 있습니다.py 파일
tvars = tf.trainable_variables()
            (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
                tvars, self.__bert_checkpoint_path)
print("init bert model params")
tf.train.init_from_checkpoint(self.
__bert_checkpoint_path, assignment_map) print("init bert model params done") sess.run(tf.variables_initializer(tf.global_variables()))

예측할 때predict를 직접 실례화할 수 있다.py 파일 중의predictor 클래스는 checkpoint 모델 파일을 불러오고 클래스 중의predict 방법을 호출하면 예측할 수 있으며 모델 코드 암호화, 모델 최적화 등을 고려하지 않아도 직접 온라인으로 배치할 수 있다.
import json

from predict import Predictor


with open("config/tnews_config.json", "r") as fr:
    config = json.load(fr)


predictor = Predictor(config)
text = " 20       “  ”   ?"
res = predictor.predict(text)
print(res)

좋은 웹페이지 즐겨찾기