WideResNet 작성시에 걸린 점

소개



본 기사의 대상자
・WideResNet의 정밀도 재현에 고생하고 있는 분
· ResNet의 기본 구조를 이해하는 분

대상이 아닌 사람
· WideResNet의 논문의 요약을 읽고 싶은 분
· WideResNet의 구조의 개요를 알고 싶은 분

version 등은 이하의 github의 README에 명기하고 있는 바와 같이 python3.7와 자신의 Cuda의 version에 있던 pytorch이다.

Pytorch로 WideResNet 구현
WideResNet의 전 논문

코드 설명 및 주의점



최고 정밀도



이번에는 비교적 계산 시간이 짧고 정밀도가 좋은 WRN-28-10의 모델에서 CIFAR100의 식별 정밀도의 검증을 실시했다.
4 회 실행에서 최고의 테스트 정확도는 81.6 %였다.

데이터 전처리



논문에서 볼 수 있듯이 정규화, Random Crop 및 Horizontal Flip이 수행됩니다. 또한 가장자리에 대해서는 원래 코드에서 Reflect를 수행하기 때문에이를 재현합니다.
def get_data(batch_size):

    normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467],
                                     std=[0.2471, 0.2435, 0.2616])

    transform_train = transforms.Compose([transforms.Pad(4, padding_mode = 'reflect'),
                                          transforms.RandomCrop(32),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(), 
                                          normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    train_dataset = datasets.CIFAR100(root="cifar",
                                      train=True, 
                                      download=True,
                                      transform=transform_train)
    test_dataset = datasets.CIFAR100(root="cifar",
                                     train=False, 
                                     download=False,
                                     transform=transform_test)

    train_data = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True)
    test_data = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)


기본 블록



이 실험에서는 (d) wide-dropout를 사용했다. 또한 conv-BN-ReLU보다 BN-ReLU-conv는 원문에 더 빠르고 정확도가 좋다고보고되었으므로 구조를 사용했습니다. 컨벌루션에서 Bias 항은 넣지 않습니다.

또한 pytorch에서주의가 필요한 것은 Dropout입니다. Dropout2dDropout은 완전히 별개의 함수입니다. 전자는 확률 p로 선택된 커널의 요소가 모두 0이되는 반면, 후자는 입력 텐서의 요소가 확률 p에서 0이된다.

또한 BN과 Dropout의 순서에주의가 필요하며이 순서를 반대로해도 정밀도가 떨어집니다.

위의주의 사항을 잘못하면 각각 1 % 정도 정도 정밀도가 악화된다.


class BasicBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1, drop_rate=0.3, kernel_size=3):
        super(BasicBlock, self).__init__()

        self.in_is_out = (in_ch == out_ch and stride == 1)
        self.drop_rate = drop_rate
        self.shortcut = nn.Sequential() if self.in_is_out else nn.Conv2d(in_ch, out_ch, 1, padding=0, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.c1 = nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.c2 = nn.Conv2d(out_ch, out_ch, kernel_size, padding=1, bias=False)

      def forward(self, x): 
          h = F.relu(self.bn1(x), inplace=True) 
          h = self.c1(h) 
          h = F.relu(self.bn2(h), inplace=True)
          h = F.dropout(h, p=self.drop_rate, training=self.training)
          h = self.c2(h)

          return h + self.shortcut(x)

가중치 초기화



컨벌루션 레이어의 가중치 초기화와 관련하여. Default의 초기화 함수는 mode = 'fan_in', 즉 입력의 크기에 따라 초기화가 수행됩니다. 반면에 WideResNet은 출력 크기를 참조하여 초기화를 수행하므로 'fan_out'에 의한 초기화가 바람직합니다.

이를 위해 kaiming_normal 참조를 둡니다. He의 정규 분포 참조.
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out') 
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.bias, 0.0) 
        nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Linear): 
        nn.init.constant_(m.bias, 0.0)

좋은 웹페이지 즐겨찾기