pytorch 동결 부분 매개 변수 훈련 다른 부분 실현
1) 다음 문장을 모델에 추가
for p in self.parameters():
p.requires_grad = False
예를 들어resenet 예비훈련 모형을 불러온 후resenet을 바탕으로 새로운 모형을 연결했습니다. resenet 모듈의 그 부분은 잠시 동결하고 업데이트하지 않고 다른 부분의 매개 변수만 업데이트할 수 있습니다. 그러면 아래에 위의 문장을 추가할 수 있습니다.
class RESNET_MF(nn.Module):
def __init__(self, model, pretrained):
super(RESNET_MF, self).__init__()
self.resnet = model(pretrained)
for p in self.parameters():
p.requires_grad = False # ,
self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
...
최적화기에 추가:
filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, \
betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
2) 매개 변수를 질서정연한 사전에 저장하면 매개 변수의 이름에 대응하는 id값을 찾아 동결할 수 있다
각 레이어의 코드를 보려면 다음과 같이 하십시오.
model_dict = torch.load('net.pth.tar').state_dict()
dict_name = list(model_dict)
for i, p in enumerate(dict_name):
print(i, p)
이 파일을 인쇄하면 대략 이 모양을 볼 수 있다.
0 gamma
1 resnet.conv1.weight
2 resnet.bn1.weight
3 resnet.bn1.bias
4 resnet.bn1.running_mean
5 resnet.bn1.running_var
6 resnet.layer1.0.conv1.weight
7 resnet.layer1.0.bn1.weight
8 resnet.layer1.0.bn1.bias
9 resnet.layer1.0.bn1.running_mean
....
모델에도 다음과 같은 코드를 추가합니다.
for i,p in enumerate(net.parameters()):
if i < 165:
p.requires_grad = False
최적화기에 위의 그 말을 추가하면 매개 변수의 차단을 실현할 수 있다보충:pytorch 탑재 예비훈련모형+단점회복+동결훈련(피갱버전)
1. 모형의 네트워크 구조를 미리 훈련한다 = 모형의 네트워크 구조를 불러와야 한다
그럼 그냥 써요.
path=" .pt "
model = " "
checkpoint = torch.load(path, map_location=device)
model.load_state_dict(checkpoint)
2. 예비 훈련 모델의 네트워크 구조와 당신의 네트워크 구조가 일치하지 않는다
위의 공식을 직접 사용하면 unexpected key 모듈과 유사합니다.xxx.무게 문제
이런 상황에서 인터넷 정보를 구체적으로 분석한 다음에 어떻게 불러올지 결정해야 한다.
# model_dict , ,
model_dict = model.state_dict()
print(model_dict.keys()
#
checkpoint = torch.load(path,map_location=device)
for k, v in checkpoint.items():
print("keys:".k)
# , 【 】 。
그리고 두 네트워크 구조 파라미터의 공통점과 차이점을 비교하여만약 각 층의 네트워크 명칭이 기본적으로 일치하지 않는다면, 이 예비 훈련 모델은 기본적으로 사용할 수 없으니, 직접 모델을 바꾸어라
만약 양자 네트워크 파라미터가 유사한 점이 많지만 완전히 일치하지 않는다면 다음과 같은 방식을 취할 수 있다.
(1) 일부 네트워크 키워드 - 완전히 일치하는 경우
model.load_state_dict(checkpoint, strict=True)
load_state_dict 함수에 매개 변수strict=True를 추가합니다. 없는 dict를 무시하고 같은 것이 있으면 복사하고 없으면 값을 포기합니다!그는 예비 훈련 모델의 키워드는 반드시 확실하게 네트워크의state_와 엄격해야 한다고 요구한다dict () 함수가 반환하는 키워드가 일치해야 값을 부여할 수 있습니다.strict도 지능적이지 않아서 인터넷 키워드가 기본적으로 일치할 수 있는 상황에 적용된다.그렇지 않으면 로드가 성공해도 네트워크 매개 변수가 비어 있습니다.
(2) 대부분의 네트워크 키워드 --- 일부 일치(완전히 같지는 않지만 유사), 예를 들어
네트워크 키워드:backbone.stage0.rbr_dense.conv.weight
사전 훈련 모델 키워드:stage0.rbr_dense.conv.weight
이를 통해 알 수 있듯이 네트워크 키워드는 예비 훈련 모델보다 접두사가 하나 더 많고 다른 것은 완전히 일치한다. 이런 상황에서 예비 훈련 모델의stage0을 사용할 수 있다.rbr_dense.conv.weight 네트워크의 백본을 읽습니다.stage0.rbr_dense.conv.weight 중.
# ,in not in key
model_dict = model.state_dict()
checkpoint = torch.load(path,map_location=device)
# k , ss
for k, v in checkpoint.items():
flag = False
for ss in model_dict.keys():
if k in ss: #
s = ss; flag = True; break
else:
continue
if flag:
checkpoint[k] = model_dict[s]
3, 인터럽트 복구
나는 이것이 일반적인 [모델 저장 로드] 방법과 차이가 주로 epoch의 회복이라고 생각한다
#
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
... # ,
}
torch.save(state, filepath)
# ,
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
start_epoch = checkpoint['epoch'] + 1
4. 동결 훈련
일반적으로 동결 훈련은 [backbone]을 대상으로 하는 것으로 [이전 학습]에 비교적 많이 응용된다
예를 들어 0-49 Epoch: backbone을 동결하여 훈련한다.50-99: 훈련을 동결하지 않는다.
Init_Epoch = 0
Freeze_Epoch = 50
Unfreeze_Epoch =100
#------------------------------------#
#
#------------------------------------#
for param in model.backbone.parameters():
param.requires_grad = False
for epoch in range(Init_Epoch,Freeze_Epoch):
# I`m Freeze-training !!
pass
#------------------------------------#
#
#------------------------------------#
for param in model.backbone.parameters():
param.requires_grad = True
for epoch in range(Freeze_Epoch,Unfreeze_Epoch):
# I`m unfreeze-training !!
pass
이상의 개인적인 경험으로 여러분께 참고가 되었으면 좋겠습니다. 또한 많은 응원 부탁드립니다.만약 잘못이 있거나 완전한 부분을 고려하지 않으신다면 아낌없이 가르침을 주시기 바랍니다.
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
정확도에서 스케일링의 영향데이터셋 스케일링은 데이터 전처리의 주요 단계 중 하나이며, 데이터 변수의 범위를 줄이기 위해 수행됩니다. 이미지와 관련하여 가능한 최소-최대 값 범위는 항상 0-255이며, 이는 255가 최대값임을 의미합니다. 따...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.