PyTorch 학습 시리즈(10) - 어떻게 훈련할 때 층을 고정합니까?

때때로 우리는 다른 임무(예를 들어 분류)로 네트워크를 미리 훈련한 다음에 권적층을 이미지 특징 추출기로 고정시킨 다음에 현재 임무의 데이터로 전체 연결층만 훈련한다.그렇다면 PyTorch는 어떻게 훈련할 때 밑바닥을 고정시키고 윗부분만 갱신합니까?이것은 우리가 역방향 전파로 사다리를 계산하기를 원할 때 우리는 맨 위의 권적층만 계산하기를 원할 뿐이다. 권적층에 대해 우리는 사다리를 계산하고 사다리로 파라미터를 업데이트하는 것을 원하지 않는다.우리는 네트워크의 모든 조작 대상이Variable 대상이라는 것을 알고 있다.Variable는 이 목적에 사용할 수 있는 두 가지 인자가 있다:requiresgrad와volatile.

requires_grad=False


사용자가 수동으로 Variable을 정의할 때 매개변수 requiresgrad 기본값은 False입니다.Module의 레이어를 정의할 때 관련 Variable의 Requiresgrad 매개 변수는 기본적으로 True입니다.계산도에서 입력한 Requires 가 있다면grad가 True이면 출력된 Requiresgrad도 트루야.모든 입력에 있는 Requiresgrad가 모두 False일 경우 출력된 requiresgrad는 False입니다.
>>>x = Variable(torch.randn(2, 3), requires_grad=True)
>>>y = Variable(torch.randn(2, 3), requires_grad=False)
>>>z = Variable(torch.randn(2, 3), requires_grad=False)
>>>out1 = x+y
>>>out1.requires_grad
True
>>>out2 = y+z
>>>out2.requires_grad 
False

훈련할 때 네트워크의 밑바닥을 고정시키려면 이 부분의 네트워크가 하위 그림에 대응하는 매개 변수 Requiresgrad는 False입니다.이렇게 하면 대칭 이동 중에 이러한 매개변수에 해당하는 사다리가 계산되지 않습니다.
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():#nn.Module parameters()
    param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)#resnet18 self.fc, 。

# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)#optimizer , 

volatile=True


Variable 매개변수volatile=True 및 requiresgrad=False의 기능은 많이 떨어지지 않지만volatile의 힘은 더 크다.입력한volatile=True가 있으면 출력한volatile=True입니다.volatile=True는 모델의 추리 과정(테스트)에서 사용하는 것을 추천합니다. 이때 입력한voliate=True를 최소한의 메모리로 추리를 실행하고 중간 상태를 저장하지 않도록 합니다.
>>> regular_input = Variable(torch.randn(5, 5))
>>> volatile_input = Variable(torch.randn(5, 5), volatile=True)
>>> model = torchvision.models.resnet18(pretrained=True)
>>> model(regular_input).requires_grad # requires_grad True, Variable requires_grad True
True
>>> model(volatile_input).requires_grad# requires_grad False, volatile True( requires_grad False)
False
>>> model(volatile_input).volatile
True

좋은 웹페이지 즐겨찾기