PyTorch에서 모델의 input/output shape 가져오기

6516 단어 · 딥러닝
PyTorch는 현재 공식적으로 TensorFlow나 Caffe처럼 각 레이어의 shape 정보를 직접 제공하지 않습니다.
https://github.com/pytorch/pytorch/pull/3043
다음 코드는 일종의 workaround입니다. CNN과 RNN 등은 모듈 구조가 서로 다르기 때문에, 다른 모듈을 지원하려면 코드를 수정해야 할 수도 있습니다.
예를 들어 RNN에서 bias 속성은 bool 타입이고 가중치도 weight 속성에 저장되어 있지 않지만, 여기서는 shape만 확인하면 되므로 이 정도로 충분합니다.
이 방법은 입력을 만들어 forward를 한 번 호출(model(x) 호출)한 뒤에야 shape를 얻을 수 있습니다.
#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json


def get_output_size(summary_dict, output):
  """Record the shape(s) of a module's forward output into *summary_dict*.

  A single tensor-like ``output`` (anything with a ``size()`` method) is stored
  as ``summary_dict['output_shape'] = list(output.size())``.  A tuple or list
  of outputs (for example ``(output, (h_n, c_n))`` returned by an LSTM) is
  handled recursively: element ``i`` gets its own ``OrderedDict`` stored under
  the integer key ``i``.

  Args:
    summary_dict: mapping to fill in; it is mutated in place.
    output: a forward output, or an arbitrarily nested tuple/list of them.

  Returns:
    ``summary_dict`` itself, after filling.
  """
  if isinstance(output, (tuple, list)):
    # enumerate works on both Python 2 and 3, unlike ``xrange``; each element
    # gets a fresh dict that the recursive call fills and returns.
    for i, item in enumerate(output):
      summary_dict[i] = get_output_size(OrderedDict(), item)
  else:
    summary_dict['output_shape'] = list(output.size())
  return summary_dict

def _direct_param_stats(module):
  """Return ``(count, trainable)`` for parameters registered directly on *module*.

  Only parameters owned by *module* itself are considered (names without a
  ``.``), so a parent module does not re-count its children's parameters.
  Unlike reading ``module.weight`` / ``module.bias``, this also covers modules
  such as ``nn.LSTM`` whose weights live under other attribute names and whose
  ``bias`` attribute is a bool flag, and it tolerates ``weight=None``.

  Returns:
    ``count`` as a plain Python int (so it is JSON-serializable), and
    ``trainable``: ``None`` when the module owns no parameters, otherwise
    ``True`` if any of its own parameters has ``requires_grad`` set.
  """
  count = 0
  trainable = None
  for name, param in module.named_parameters():
    if '.' in name:  # owned by a submodule, reported on that submodule's entry
      continue
    numel = 1
    for dim in param.size():
      numel *= dim
    count += numel
    trainable = bool(trainable) or param.requires_grad
  return count, trainable


def summary(input_size, model):
  """Run one dummy forward pass through *model* and collect per-layer shapes.

  A forward hook is attached to every submodule except ``nn.Sequential`` /
  ``nn.ModuleList`` containers and *model* itself.  Each hook call adds an
  entry keyed ``'<ClassName>-<n>'`` (n = 1-based call order) containing:

  * ``input_shape`` -- ``list(...)`` of the first positional input's size
    (omitted when the module received no positional input with ``size()``);
  * ``output_shape`` -- or, for tuple/list outputs, nested per-element entries
    as built by ``get_output_size``;
  * ``trainable`` -- only present when the module owns parameters;
  * ``nb_params`` -- number of parameters registered directly on the module
    (weights and biases), as a Python int.

  Args:
    input_size: shape of one input *without* the batch dimension (a batch of
      size 1 is prepended), or a list/tuple of such shapes for a multi-input
      model, in which case each dummy tensor is passed as its own positional
      argument.
    model: the ``nn.Module`` to inspect.

  Returns:
    An ``OrderedDict`` of the entries in forward-call order.
  """
  def register_hook(module):
    def hook(module, inputs, output):
      class_name = type(module).__name__
      m_key = '%s-%i' % (class_name, len(result) + 1)
      entry = result[m_key] = OrderedDict()
      # Guard against kwargs-only calls or non-tensor inputs (e.g. PackedSequence).
      if inputs and hasattr(inputs[0], 'size'):
        entry['input_shape'] = list(inputs[0].size())
      get_output_size(entry, output)

      nb_params, trainable = _direct_param_stats(module)
      if trainable is not None:
        entry['trainable'] = trainable
      entry['nb_params'] = nb_params

    # Containers and the root model only forward to children that are hooked
    # themselves, so hooking them would just duplicate entries.
    if not isinstance(module, (nn.Sequential, nn.ModuleList)) and module is not model:
      hooks.append(module.register_forward_hook(hook))

  # One dummy tensor (batch of 1) per network input.
  if isinstance(input_size[0], (list, tuple)):
    x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
  else:
    x = [Variable(torch.rand(1, *input_size))]

  result = OrderedDict()
  hooks = []
  model.apply(register_hook)
  try:
    model(*x)
  finally:
    # Always detach the hooks, even if the forward pass raised, so the model
    # is left unmodified.
    for h in hooks:
      h.remove()

  return result

# Build the CRNN from models.crnn and dump its per-layer summary as JSON.
# The instance is bound to `model` so the imported `crnn` module is not shadowed.
model = crnn.CRNN(32, 1, 3755, 256, 1)
# input_size excludes the batch dimension: summary() prepends a batch of 1,
# so the network sees a (1, 1, 32, 128) tensor.
x = summary([1, 32, 128], model)
# print(...) with a single argument prints the same on Python 2 and Python 3.
print(json.dumps(x))

pytorch 버전 CRNN의 경우 출력 shape는 다음과 같습니다.
{"Conv2d-1": {"input_shape": [1, 1, 32, 128],"output_shape": [1, 64, 32, 128],"trainable": true,"nb_params": 576},"ReLU-2": {"input_shape": [1, 64, 32, 128],"output_shape": [1, 64, 32, 128],"nb_params": 0},"MaxPool2d-3": {"input_shape": [1, 64, 32, 128],"output_shape": [1, 64, 16, 64],"nb_params": 0},"Conv2d-4": {"input_shape": [1, 64, 16, 64],"output_shape": [1, 128, 16, 64],"trainable": true,"nb_params": 73728},"ReLU-5": {"input_shape": [1, 128, 16, 64],"output_shape": [1, 128, 16, 64],"nb_params": 0},"MaxPool2d-6": {"input_shape": [1, 128, 16, 64],"output_shape": [1, 128, 8, 32],"nb_params": 0},"Conv2d-7": {"input_shape": [1, 128, 8, 32],"output_shape": [1, 256, 8, 32],"trainable": true,"nb_params": 294912},"BatchNorm2d-8": {"input_shape": [1, 256, 8, 32],"output_shape": [1, 256, 8, 32],"trainable": true,"nb_params": 256},"ReLU-9": {"input_shape": [1, 256, 8, 32],"output_shape": [1, 256, 8, 32],"nb_params": 0},"Conv2d-10": {"input_shape": [1, 256, 8, 32],"output_shape": [1, 256, 8, 32],"trainable": true,"nb_params": 589824},"ReLU-11": {"input_shape": [1, 256, 8, 32],"output_shape": [1, 256, 8, 32],"nb_params": 0},"MaxPool2d-12": {"input_shape": [1, 256, 8, 32],"output_shape": [1, 256, 4, 33],"nb_params": 0},"Conv2d-13": {"input_shape": [1, 256, 4, 33],"output_shape": [1, 512, 4, 33],"trainable": true,"nb_params": 1179648},"BatchNorm2d-14": {"input_shape": [1, 512, 4, 33],"output_shape": [1, 512, 4, 33],"trainable": true,"nb_params": 512},"ReLU-15": {"input_shape": [1, 512, 4, 33],"output_shape": [1, 512, 4, 33],"nb_params": 0},"Conv2d-16": {"input_shape": [1, 512, 4, 33],"output_shape": [1, 512, 4, 33],"trainable": true,"nb_params": 2359296},"ReLU-17": {"input_shape": [1, 512, 4, 33],"output_shape": [1, 512, 4, 33],"nb_params": 0},"MaxPool2d-18": {"input_shape": [1, 512, 4, 33],"output_shape": [1, 512, 2, 34],"nb_params": 0},"Conv2d-19": {"input_shape": [1, 512, 2, 34],"output_shape": [1, 512, 1, 33],"trainable": true,"nb_params": 1048576},"BatchNorm2d-20": 
{"input_shape": [1, 512, 1, 33],"output_shape": [1, 512, 1, 33],"trainable": true,"nb_params": 512},"ReLU-21": {"input_shape": [1, 512, 1, 33],"output_shape": [1, 512, 1, 33],"nb_params": 0},"LSTM-22": {"input_shape": [33, 1, 512],"0": {"output_shape": [33, 1, 512]},"1": {"0": {"output_shape": [2, 1, 256]},"1": {"output_shape": [2, 1, 256]}},"nb_params": 0},"Linear-23": {"input_shape": [33, 512],"output_shape": [33, 256],"trainable": true,"nb_params": 131072},"BidirectionalLSTM-24": {"input_shape": [33, 1, 512],"output_shape": [33, 1, 256],"nb_params": 0},"LSTM-25": {"input_shape": [33, 1, 256],"0": {"output_shape": [33, 1, 512]},"1": {"0": {"output_shape": [2, 1, 256]},"1": {"output_shape": [2, 1, 256]}},"nb_params": 0},"Linear-26": {"input_shape": [33, 512],"output_shape": [33, 3755],"trainable": true,"nb_params": 1922560},"BidirectionalLSTM-27": {"input_shape": [33, 1, 256],"output_shape": [33, 1, 3755],"nb_params": 0} }

좋은 웹페이지 즐겨찾기