PyTorch 신경 망 구축 및 보존 추출 방법 상세 설명
1.PyTorch 빠 른 신경 망 구축 방법
먼저 실험 코드 보기:
import torch
import torch.nn.functional as F
# 1, Net
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
def forward(self, x):
x = F.relu(self.hidden(x))
x = self.predict(x)
return x
net1 = Net(2, 10, 2)
print(' 1:
', net1)
# 2 torch.nn.Sequential
net2 = torch.nn.Sequential(
torch.nn.Linear(2, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 2),
)
print(' 2:
', net2)
# , ,
'''''
1:
Net (
(hidden): Linear (2 -> 10)
(predict): Linear (10 -> 2)
)
2:
Sequential (
(0): Linear (2 -> 10)
(1): ReLU ()
(2): Linear (10 -> 2)
)
'''
이전에 Net 류 를 정의 하여 신경 망 을 구축 하 는 방법 을 배 웠 습 니 다.classNet 에 서 는 먼저 슈퍼 함 수 를 통 해 torch.nn.Module 모듈 의 구조 방법 을 계승 한 다음 에 속성 을 추가 하 는 방식 으로 신경 망 각 층 의 구조 정 보 를 구축 하고 forward 방법 에서 신경 망 각 층 간 의 연결 정 보 를 보완 합 니 다.그리고 Net 류 대상 을 정의 하 는 방식 으로 신경 망 구조 구축 을 완성 한다.신경 망 을 구축 하 는 또 다른 방법 은 빠 른 구축 방법 이 라 고 할 수 있다.바로 torch.N.Sequential 을 통 해 신경 망 구축 을 직접 완성 하 는 것 이다.
두 가지 방법 으로 구 축 된 신경 망 구 조 는 완전히 같 으 며,모두 print 함 수 를 통 해 네트워크 정 보 를 출력 할 수 있 으 나,인쇄 결 과 는 다소 다 를 수 있다.
2.PyTorch 의 신경 망 보존 과 추출
깊이 있 는 학습 과 연 구 를 할 때 우 리 는 일정한 시간의 훈련 을 통 해 비교적 좋 은 모델 을 얻 었 을 때 우 리 는 당연히 이 모델 과 모델 파 라 메 터 를 보존 하여 나중에 사용 하 기 를 원 하기 때문에 신경 망 의 보존 과 모델 파라미터 추출 과부하 가 필요 하 다.
우선,우 리 는 네트워크 구조 와 모델 매개 변 수 를 저장 해 야 하 는 신경 망 의 정의,훈련 부분 을 저장 한 후에 torch.save()를 통 해 네트워크 구조 와 모델 매개 변 수 를 저장 해 야 한다.두 가지 보존 방식 이 있다.하 나 는 년 전체 신경 망 의 구조 정보 와 모델 매개 변수 정 보 를 보존 하 는 것 이 고 save 의 대상 은 네트워크 net 이다.둘 째 는 신경 망 만 저장 하 는 훈련 모델 파라미터 이 고,save 의 대상 은 net.state 이다.dict(),저장 결 과 는.pkl 파일 로 저 장 됩 니 다.
위의 두 가지 저장 방식 에 대응 하여 과부하 방식 도 두 가지 가 있다.첫 번 째 완전한 네트워크 구조 정보 에 대응 하여 다시 불 러 올 때 torch.load('pkl')를 통 해 새로운 신경 망 대상 을 직접 초기 화하 면 됩 니 다.두 번 째 모델 매개 변수 정보 만 저장 하려 면 먼저 같은 신경 망 구 조 를 구축 하고 net.load 를 통 해state_dict(torch.load('.pkl')는 모델 매개 변수의 재 부팅 을 완료 합 니 다.인터넷 이 비교적 클 때 첫 번 째 방법 은 비교적 많은 시간 을 소비 할 것 이다.
코드 구현:
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
torch.manual_seed(1) #
#
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
#
def save():
#
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1),
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_function = torch.nn.MSELoss()
#
for i in range(300):
prediction = net1(x)
loss = loss_function(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
#
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
#
torch.save(net1, '7-net.pkl') #
torch.save(net1.state_dict(), '7-net_params.pkl') #
#
def reload_net():
net2 = torch.load('7-net.pkl')
prediction = net2(x)
plt.subplot(132)
plt.title('net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
# ,
def reload_params():
#
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1),
)
#
net3.load_state_dict(torch.load('7-net_params.pkl'))
prediction = net3(x)
plt.subplot(133)
plt.title('net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
#
save()
reload_net()
reload_params()
실험 결과:이상 이 바로 본 고의 모든 내용 입 니 다.여러분 의 학습 에 도움 이 되 고 저 희 를 많이 응원 해 주 셨 으 면 좋 겠 습 니 다.
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
IceVision에서 형식별 데이터를 읽는 방법2021년에 가장 멋있는 물체 검출 프레임워크라고 해도 과언이 아닌 IceVision을 사용해, VOC format과 COCO format의 데이터 세트에 대해 Object Detection을 간단하게 실시하기 위한...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.