Pytorch nn.Dataparallel model state_dict multi-gpu to single-gpu
Pytorch nn.Dataparallel model state_dict multi-gpu to single-gpu 다중 블록 GPU 훈련 모형을 단일 블록 또는 기타 GPU 수량 수요 모형으로 전환
태그:pytorch nn.Dataparalle model.state_dict
참조:reference link
문제 설명
우리가 Pytorch로 모형을 훈련할 때, 서버마다 그래픽 GPU 설정과 수량이 다르고, nn.Dataparallel에 저장된 모델은 그래픽 수량과 연결되어 있습니다. 실제로 우리는 모델이 서로 다른 그래픽 수량의 서버로 마음대로 옮겨 테스트를 실행할 수 있어야 합니다.아래의 코드는 바로 이 문제를 겨냥한 것이다
솔루션
다시 쓰기 nn.Module 내의 state_dict, load_state_dict 함수.즉, 우리가 저장하고 불러오는 모델은 nn을 거치지 않는다는 것이다.DataParallel이 처리했기 때문에 임의의 GPU 수량에서 로드 훈련을 할 수 있습니다.만약 당신이 이미 여러 개의 카드로 훈련을 했다면, 아래의 코드를 복사한 다음, epoch를 실행하면 된다
import torch
import torch.nn as nn
from collections import OrderedDict
from torch.nn.parameter import Parameter
def state_dict(model, destination=None, prefix='', keep_vars=False):
own_state = model.module if isinstance(model, torch.nn.DataParallel) \
else model
if destination is None:
destination = OrderedDict()
for name, param in own_state._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.data
for name, buf in own_state._buffers.items():
if buf is not None:
destination[prefix + name] = buf
for name, module in own_state._modules.items():
if module is not None:
state_dict(module, destination, prefix + name + '.', keep_vars=keep_vars)
return destination
def load_state_dict(model, state_dict, strict=True):
own_state = model.module.state_dict() if isinstance(model, torch.nn.DataParallel) \
else model.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, Parameter):
# backwards compatibility for serialized parameters
param = param.data
try:
own_state[name].copy_(param)
except Exception:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
if strict:
missing = set(own_state.keys()) - set(state_dict.keys())
if len(missing) > 0:
raise KeyError('missing keys in state_dict: "{}"'.format(missing))
사용 방법:
###use skill
# before
# state_dict() ,
model.state_dict()
# now
# ,
state_dict(model)
#before
# ,
model.load_state_dict(model.state_dict)
#now
# ,
your_state_dict=state_dict(model)
load_state_dict(model, your_state_dict)
여러분의 모형이 이제 소가 안 될 때까지 몰아갔으면 좋겠습니다(DZT)
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
로컬 스토리지의 기초이 데이터는 현재 작업 중인 브라우저에서만 사용할 수 있으며 페이지를 다시 로드하면 쉽게 액세스할 수 있습니다. 로컬 저장소는 사용자의 브라우저에서 사용할 수 있는 키/값 데이터베이스입니다. 이 데이터베이스는 무한정...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.