pytorch 트레이닝 데이터 및 모든 코드 테스트(7) 네트워크

1433 단어 pytorch
ASPP 네트워크
class ASPP_module(nn.Module):
    def __init__(self, inplanes, planes, rate):
        pass
    def forward(self, x):
        pass
    def _init_weight(self):
        pass

첫 번째 함수
    def __init__(self, inplanes, planes, rate):
        super(ASPP_module, self).__init__()
        if rate == 1:
            kernel_size = 1
            padding = 0
        else:
            kernel_size = 3
            padding = rate
        self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=rate, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

안의init_weight()는 다음과 같습니다. 초기화
    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.data.normal_(0, math.sqrt(2. / n))
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

다음은 forward 함수입니다.
    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)

        return self.relu(x)

위에서 설명한 ASPP 네트워크

좋은 웹페이지 즐겨찾기