Pytorch nn.Dataparallel model state_dict multi-gpu to single-gpu

5615 단어 meachinelearning

Pytorch nn.Dataparallel model state_dict multi-gpu to single-gpu 다중 블록 GPU 훈련 모형을 단일 블록 또는 기타 GPU 수량 수요 모형으로 전환


태그:pytorch nn.Dataparalle model.state_dict
참조:reference link

문제 설명


우리가 Pytorch로 모형을 훈련할 때, 서버마다 그래픽 GPU 설정과 수량이 다르고, nn.Dataparallel에 저장된 모델은 그래픽 수량과 연결되어 있습니다. 실제로 우리는 모델이 서로 다른 그래픽 수량의 서버로 마음대로 옮겨 테스트를 실행할 수 있어야 합니다.아래의 코드는 바로 이 문제를 겨냥한 것이다

솔루션


다시 쓰기 nn.Module 내의 state_dict, load_state_dict 함수.즉, 우리가 저장하고 불러오는 모델은 nn을 거치지 않는다는 것이다.DataParallel이 처리했기 때문에 임의의 GPU 수량에서 로드 훈련을 할 수 있습니다.만약 당신이 이미 여러 개의 카드로 훈련을 했다면, 아래의 코드를 복사한 다음, epoch를 실행하면 된다
import torch
import torch.nn as nn
from collections import OrderedDict
from torch.nn.parameter import Parameter

def state_dict(model, destination=None, prefix='', keep_vars=False):
    own_state = model.module if isinstance(model, torch.nn.DataParallel) \
        else model
    if destination is None:
        destination = OrderedDict()
    for name, param in own_state._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.data
    for name, buf in own_state._buffers.items():
        if buf is not None:
            destination[prefix + name] = buf
    for name, module in own_state._modules.items():
        if module is not None:
            state_dict(module, destination, prefix + name + '.', keep_vars=keep_vars)
    return destination

def load_state_dict(model, state_dict, strict=True):
    own_state = model.module.state_dict() if isinstance(model, torch.nn.DataParallel) \
        else model.state_dict()
    for name, param in state_dict.items():
        if name in own_state:
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                raise RuntimeError('While copying the parameter named {}, '
                                    'whose dimensions in the model are {} and '
                                    'whose dimensions in the checkpoint are {}.'
                                    .format(name, own_state[name].size(), param.size()))
        elif strict:
            raise KeyError('unexpected key "{}" in state_dict'
                            .format(name))
    if strict:
        missing = set(own_state.keys()) - set(state_dict.keys())
        if len(missing) > 0:
            raise KeyError('missing keys in state_dict: "{}"'.format(missing))

사용 방법:


###use skill

# before
#  state_dict() , 
model.state_dict()

# now
# , 
state_dict(model)

#before
#  , 
model.load_state_dict(model.state_dict)

#now
#  , 
your_state_dict=state_dict(model)
load_state_dict(model, your_state_dict) 

여러분의 모형이 이제 소가 안 될 때까지 몰아갔으면 좋겠습니다(DZT)

좋은 웹페이지 즐겨찾기