Pytorch 모델 pth 파일에서 매개 변수를numpy 행렬로 읽는 작업

1341 단어 Pytorchpthnumpy행렬

목적:


훈련된 pth모델 파라미터를 추출하여 다른 방식으로 가장자리 장치에 배치합니다.

Pytorch는 읽기 쉬운 매개변수 인터페이스를 제공합니다.


nn.Module.parameters()

데모를 직접 보기:


from torchvision.models.alexnet import alexnet 
model = alexnet(pretrained=True).eval().cuda()
parameters = model.parameters()
for p in parameters:
  numpy_para = p.detach().cpu().numpy()
  print(type(numpy_para))
  print(numpy_para.shape)
위에서 얻은 numpy_파라가 바로 numpy 파라미터입니다.
Note:
model.parameters () 는 생성기 형식으로 모든 층을 반복해서 되돌려주는 매개 변수입니다.그래서 for 순환으로 각 층의 매개 변수를 읽으면 순환 횟수는 층수를 나타낸다.
모든 층의 매개 변수는torch이다.nn.parameter.Parameter 형식은 Tensor의 하위 클래스이기 때문에 직접tensor로numpy(즉 p.detach()를 전환합니다.cpu().numpy () 방법은 직접numpy 행렬로 전환할 수 있습니다.
편하고 좋아요, 대박~
보충:pytorch 훈련된.pth 모델을 로 변환합니다.pt
파이썬을 훈련시켰어요.pth 파일이 로 전환됩니다.pt

import torch
import torchvision
from unet import UNet
model = UNet(3, 2)# 
model.load_state_dict(torch.load("best_weights.pth"))# 
model.eval()# eval()
example = torch.rand(1, 3, 320, 480)# 
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
이상의 개인적인 경험으로 여러분께 참고가 되었으면 좋겠습니다. 또한 많은 응원 부탁드립니다.만약 잘못이 있거나 완전한 부분을 고려하지 않으신다면 아낌없이 가르침을 주시기 바랍니다.

좋은 웹페이지 즐겨찾기