pytorch 노트 02) 모형의 저장 및 불러오기

9453 단어 기계·딥러닝
전체 모델 저장 및 로드
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')

모델 매개변수만 저장 및 로드(권장, 미리 수동으로 모델 구성 필요)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

그러나 몇 가지 세부 사항에 주의해야 한다.nn을 사용하면DataParallel은 한 컴퓨터에 여러 개의 GPU를 사용하는데load모델을 사용할 때도 반드시 DataParallel을 먼저 사용해야 한다. 이것은keras와 유사하다.
2. load는 많은 무거운 짐을 싣는 기능을 제공하는데 GPU에서 훈련하는 권한을 CPU에 싣고 뛸 수 있다.컨텐트 참조:https://www.ptorch.com/news/74.html
torch.load('tensors.pt')
#          CPU 
torch.load('tensors.pt', map_location=lambda storage, loc: storage)
#          GPU 1 
torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
#     GPU 1     GPU 0
torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

cpu에 미리 훈련된 GPU 모형을 탑재하고 모든 GPU 장량을 CPU에 강제하는 방식이 있다.
torch.load('my_file.pt', map_location=lambda storage, loc: storage)

상기 코드는 모델이 하나의 GPU에서 훈련할 때만 작용한다.만약 내가 여러 GPU에서 내 모형을 훈련시키고 저장한 다음 CPU에 불러오려고 한다면, 이 오류를 얻었습니다: Key Error: 'unexpected key' module.conv1.weight 'in statedict'는 어떻게 해결합니까?모델을 사용하여 모델을 저장했을 수도 있습니다nn.DataParallel, 이 모델은 이 모델에 모듈을 저장합니다. 현재 모델 DataParallel을 불러오려고 시도하고 있습니다.당신은 nn.DataParallel은 네트워크에 불러오는 목적을 잠시 추가하거나, 무거운 파일을 불러올 수도 있으며, 모듈 접두사가 없는 새로운 질서정연한 사전을 만들고 불러올 수도 있습니다.
# original saved file with DataParallel
state_dict = torch.load('myfile.pth.tar')
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)

필자는 간단한 함수를 봉하여 다중 GPU 권한을 CPU에 직접 불러올 수 있다(일치하는 권한만 불러올 수 있다)
#     ,            ,    gpu  
def load_state_keywise(model, model_path):
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_path, map_location='cpu')
    key = list(pretrained_dict.keys())[0]
    # 1. filter out unnecessary keys
    # 1.1 multi-GPU ->CPU
    if (str(key).startswith('module.')):
        pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if
                           k[7:] in model_dict and v.size() == model_dict[k[7:]].size()}
    else:
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if
                           k in model_dict and v.size() == model_dict[k].size()}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)

'module'를 왜 빼냐고 물어보시네요.싱글 GPU에서 뛸 수 있어요. 밑에 있는 밤을 보세요.
import torch
from torch import nn
import torchvision
#  alexnet   ,    GPU CUP 
alexnet=torchvision.models.alexnet()
state_dict=alexnet.state_dict()
for k, v in state_dict.items():
    print(k)
print("-"*20)  #      
#   GPU
model = nn.DataParallel(alexnet)
state_dict=model.state_dict()
for k, v in state_dict.items():
    print(k)

결과를 보면 알 수 있듯이 접두사'module.'가 많아졌다.
features.0.weight
features.0.bias
features.3.weight
features.3.bias
features.6.weight
features.6.bias
features.8.weight
features.8.bias
features.10.weight
features.10.bias
classifier.1.weight
classifier.1.bias
classifier.4.weight
classifier.4.bias
classifier.6.weight
classifier.6.bias
--------------------
module.features.0.weight
module.features.0.bias
module.features.3.weight
module.features.3.bias
module.features.6.weight
module.features.6.bias
module.features.8.weight
module.features.8.bias
module.features.10.weight
module.features.10.bias
module.classifier.1.weight
module.classifier.1.bias
module.classifier.4.weight
module.classifier.4.bias
module.classifier.6.weight
module.classifier.6.bias

좋은 웹페이지 즐겨찾기