Transformers 저장 및 로드 모델 | 8
3702 단어 인공지능
이 절 은 마이크로 모델 (BERT, GPT, GPT - 2, Transformer - XL) 을 저장 하고 다시 불 러 오 는 방법 을 설명 합 니 다.세 가지 파일 형식 을 저장 해 야 미조정 모델 을 다시 불 러 올 수 있 습 니 다.
pytorch_model.bin
config.json
vocab.txt
BERT 와 Transformer - XL, vocab.json
GPT / GPT - 2 (BPE 어휘), merges.txt
. 이것 은 모델, 설정, 프로필 을 저장 하 는 추천 방법 입 니 다. 단 어 를
output_dir
디 렉 터 리 에 저장 한 다음 모델 과 tokenizer 를 다시 불 러 옵 니 다.from transformers import WEIGHTS_NAME, CONFIG_NAME
output_dir = "./models/"
# 1: 、
# ,
# PyTorch DistributedDataParallel DataParallel
model_to_save = model.module if hasattr(model, 'module') else model
# , `from_pretrained`
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_dir)
# 2:
#Bert
model = BertForQuestionAnswering.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case) # Add specific options if needed
#GPT
model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)
각 유형의 파일 에 특정 경 로 를 사용 하려 면 다른 방법 으로 모델 을 저장 하고 다시 불 러 올 수 있 습 니 다.
output_model_file = "./models/my_own_model_file.bin"
output_config_file = "./models/my_own_config_file.bin"
output_vocab_file = "./models/my_own_vocab_file.bin"
# 1: 、
# ,
# PyTorch DistributedDataParallel DataParallel
model_to_save = model.module if hasattr(model, 'module') else model
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_vocab_file)
# 2:
# 、 , `from_pretrained` 。
# :
#Bert
config = BertConfig.from_json_file(output_config_file)
model = BertForQuestionAnswering(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)
#GPT
config = OpenAIGPTConfig.from_json_file(output_config_file)
model = OpenAIGPTDoubleHeadsModel(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = OpenAIGPTTokenizer(output_vocab_file)
링크: https://huggingface.co/transf...
반 창 AI 블 로그 사이트 에 오신 것 을 환영 합 니 다: http://panchuang.net/
OpenCV 중국어 공식 문서: http://woshicver.com/
반 창 블 로그 자원 집합 소 에 오신 것 을 환영 합 니 다: http://docs.panchuang.net/