PyTorch 신경 망 구축 및 보존 추출 방법 상세 설명

5099 단어 PyTorch신경 망
때때로 우 리 는 모델 을 훈련 시 켰 다.다음 에 직접 사용 하 기 를 바란다.다음 에 훈련 하 는 데 시간 을 들 이지 않 아 도 된다.이번 절 에 우 리 는 PyTorch 가 신경 망 을 신속하게 구축 하고 그 보존 추출 방법 에 대해 상세 하 게 설명 한다.
1.PyTorch 빠 른 신경 망 구축 방법
먼저 실험 코드 보기:

import torch 
import torch.nn.functional as F 
 
#   1,      Net         
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
net1 = Net(2, 10, 2) 
print('  1:
', net1) # 2 torch.nn.Sequential net2 = torch.nn.Sequential( torch.nn.Linear(2, 10), torch.nn.ReLU(), torch.nn.Linear(10, 2), ) print(' 2:
', net2) # , , ''''' 1: Net ( (hidden): Linear (2 -> 10) (predict): Linear (10 -> 2) ) 2: Sequential ( (0): Linear (2 -> 10) (1): ReLU () (2): Linear (10 -> 2) ) '''
이전에 Net 류 를 정의 하여 신경 망 을 구축 하 는 방법 을 배 웠 습 니 다.classNet 에 서 는 먼저 슈퍼 함 수 를 통 해 torch.nn.Module 모듈 의 구조 방법 을 계승 한 다음 에 속성 을 추가 하 는 방식 으로 신경 망 각 층 의 구조 정 보 를 구축 하고 forward 방법 에서 신경 망 각 층 간 의 연결 정 보 를 보완 합 니 다.그리고 Net 류 대상 을 정의 하 는 방식 으로 신경 망 구조 구축 을 완성 한다.
신경 망 을 구축 하 는 또 다른 방법 은 빠 른 구축 방법 이 라 고 할 수 있다.바로 torch.N.Sequential 을 통 해 신경 망 구축 을 직접 완성 하 는 것 이다.
두 가지 방법 으로 구 축 된 신경 망 구 조 는 완전히 같 으 며,모두 print 함 수 를 통 해 네트워크 정 보 를 출력 할 수 있 으 나,인쇄 결 과 는 다소 다 를 수 있다.
2.PyTorch 의 신경 망 보존 과 추출
깊이 있 는 학습 과 연 구 를 할 때 우 리 는 일정한 시간의 훈련 을 통 해 비교적 좋 은 모델 을 얻 었 을 때 우 리 는 당연히 이 모델 과 모델 파 라 메 터 를 보존 하여 나중에 사용 하 기 를 원 하기 때문에 신경 망 의 보존 과 모델 파라미터 추출 과부하 가 필요 하 다.
우선,우 리 는 네트워크 구조 와 모델 매개 변 수 를 저장 해 야 하 는 신경 망 의 정의,훈련 부분 을 저장 한 후에 torch.save()를 통 해 네트워크 구조 와 모델 매개 변 수 를 저장 해 야 한다.두 가지 보존 방식 이 있다.하 나 는 년 전체 신경 망 의 구조 정보 와 모델 매개 변수 정 보 를 보존 하 는 것 이 고 save 의 대상 은 네트워크 net 이다.둘 째 는 신경 망 만 저장 하 는 훈련 모델 파라미터 이 고,save 의 대상 은 net.state 이다.dict(),저장 결 과 는.pkl 파일 로 저 장 됩 니 다.
위의 두 가지 저장 방식 에 대응 하여 과부하 방식 도 두 가지 가 있다.첫 번 째 완전한 네트워크 구조 정보 에 대응 하여 다시 불 러 올 때 torch.load('pkl')를 통 해 새로운 신경 망 대상 을 직접 초기 화하 면 됩 니 다.두 번 째 모델 매개 변수 정보 만 저장 하려 면 먼저 같은 신경 망 구 조 를 구축 하고 net.load 를 통 해state_dict(torch.load('.pkl')는 모델 매개 변수의 재 부팅 을 완료 합 니 다.인터넷 이 비교적 클 때 첫 번 째 방법 은 비교적 많은 시간 을 소비 할 것 이다.
코드 구현:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) #         
 
#      
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
#                   
def save(): 
  #        
  net1 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
  loss_function = torch.nn.MSELoss() 
 
  #      
  for i in range(300): 
    prediction = net1(x) 
    loss = loss_function(prediction, y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
  #      
  plt.figure(1, figsize=(10, 3)) 
  plt.subplot(131) 
  plt.title('net1') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
  #        
  torch.save(net1, '7-net.pkl')           #                  
  torch.save(net1.state_dict(), '7-net_params.pkl') #              
 
#                   
def reload_net(): 
  net2 = torch.load('7-net.pkl') 
  prediction = net2(x) 
 
  plt.subplot(132) 
  plt.title('net2') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
#             ,                       
def reload_params(): 
  #               
  net3 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
 
  #             
  net3.load_state_dict(torch.load('7-net_params.pkl')) 
  prediction = net3(x) 
 
  plt.subplot(133) 
  plt.title('net3') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
#      
save() 
reload_net() 
reload_params() 
실험 결과:

이상 이 바로 본 고의 모든 내용 입 니 다.여러분 의 학습 에 도움 이 되 고 저 희 를 많이 응원 해 주 셨 으 면 좋 겠 습 니 다.

좋은 웹페이지 즐겨찾기