# U-Net segmentation model - PyTorch code

import torch
import torch.nn as nn


class pub(nn.Module):
    """Double 3x3 convolution block shared by every stage of the network.

    The block is ``Conv2d -> [BatchNorm2d] -> ReLU`` applied twice. The first
    convolution maps ``in_channel`` to ``out_channel``; the second keeps
    ``out_channel``.

    Args:
        in_channel: Number of channels of the input tensor.
        out_channel: Number of channels produced by both convolutions.
        batch_norm: When true, a ``BatchNorm2d`` is placed between each
            convolution and its ReLU.
        keep_size: When true, the convolutions use ``padding=1`` so a 3x3
            kernel preserves height/width; otherwise ``padding=0`` and each
            convolution trims 2 pixels from height and width.
    """

    def __init__(self, in_channel, out_channel, batch_norm=True, keep_size=False):
        super(pub, self).__init__()
        padding = int(bool(keep_size))
        stages = []
        for src, dst in ((in_channel, out_channel), (out_channel, out_channel)):
            stages.append(nn.Conv2d(src, dst, 3, padding=padding))
            if batch_norm:
                stages.append(nn.BatchNorm2d(dst))
            # In-place ReLU, same as the positional ``True`` argument.
            stages.append(nn.ReLU(True))
        self.pub_con = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the two conv stages to ``x`` and return the result."""
        return self.pub_con(x)


class unet_down(nn.Module):
    """Encoder stage: 2x2 max pooling (stride 2) followed by a ``pub`` block.

    Args:
        in_channel: Number of channels of the incoming feature map.
        out_channel: Number of channels produced by the ``pub`` block.
        batch_norm: Forwarded to ``pub``.
        keep_size: Forwarded to ``pub``.
    """

    def __init__(self, in_channel, out_channel, batch_norm=True, keep_size=False):
        super(unet_down, self).__init__()
        self.pub = pub(in_channel, out_channel, batch_norm, keep_size)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Pool ``x`` first, then run the double convolution on the result."""
        return self.pub(self.pool(x))


class unet_up(nn.Module):
    """Decoder stage: upsample, concatenate with a skip tensor, then ``pub``.

    Args:
        in_channel: Channel count of the concatenated tensor fed to ``pub``
            (skip channels plus upsampled channels).
        out_channel: Channel count produced by the upsampling branch and by
            the ``pub`` block. The upsampling branch expects its input to
            have ``out_channel * 2`` channels.
        batch_norm: Forwarded to ``pub``.
        upsample: When true, upsampling is a 1x1 convolution followed by
            nearest-neighbour ``nn.Upsample(scale_factor=2)``; otherwise a
            2x2 stride-2 ``nn.ConvTranspose2d`` is used.
        keep_size: Forwarded to ``pub``; also stored as ``orignal_size``.
    """

    def __init__(self, in_channel, out_channel, batch_norm=True, upsample=True, keep_size=False):
        super(unet_up, self).__init__()
        layers = []
        if upsample:
            layers += [nn.Conv2d(out_channel*2, out_channel, 1)]
            layers += [nn.Upsample(scale_factor=2, mode='nearest')]
        else:
            layers += [nn.ConvTranspose2d(out_channel*2, out_channel, 2, stride=2)]
        self.upsample = nn.Sequential(*layers)
        self.pub = pub(in_channel, out_channel, batch_norm, keep_size)
        # Attribute name (including its spelling) kept for backward compatibility.
        self.orignal_size = keep_size

    def forward(self, x1, x2):
        """Fuse skip tensor ``x1`` with the upsampled ``x2``.

        ``x2`` is upsampled, ``x1`` is center-cropped to the spatial size of
        the upsampled ``x2``, both are concatenated along dim 1 (``x1``
        first) and passed through the ``pub`` block.
        """
        x2 = self.upsample(x2)
        target_h, target_w = x2.size(2), x2.size(3)
        # Per-axis center crop with explicit end indices. The previous
        # ``c:-c`` slice produced an empty tensor when c == 0 (equal sizes or
        # a 1-pixel difference) and reused the height offset for the width.
        top = max((x1.size(2) - target_h) // 2, 0)
        left = max((x1.size(3) - target_w) // 2, 0)
        x1 = x1[:, :, top:top + target_h, left:left + target_w]
        x = torch.cat((x1, x2), 1)
        x = self.pub(x)
        return x


class Unet(nn.Module):
    """U-Net style encoder/decoder built from ``pub``, ``unet_down`` and ``unet_up``.

    The encoder starts with a ``pub`` block producing 64 channels and doubles
    the channel count at each of the ``layers - 1`` ``unet_down`` stages. The
    decoder mirrors it with ``layers - 1`` ``unet_up`` stages and ends with a
    1x1 convolution to ``class_nums`` channels followed by a sigmoid.

    Args:
        channels: Number of channels of the input tensor.
        class_nums: Number of output channels of the final 1x1 convolution.
        layers: Number of resolution levels (encoder blocks). Must be >= 1.
        upsample: Forwarded to every ``unet_up`` stage.
        batch_norm: Forwarded to every ``pub``/``unet_down``/``unet_up`` stage.
        keep_size: Forwarded to every ``pub``/``unet_down``/``unet_up`` stage.

    Raises:
        ValueError: If ``layers`` is smaller than 1.
    """

    def __init__(self, channels, class_nums, layers=5, upsample=True, batch_norm=True, keep_size=False):
        super(Unet, self).__init__()
        if layers < 1:
            raise ValueError("layers must be >= 1, got {}".format(layers))
        self.layers = layers
        down = [pub(channels, 64, batch_norm, keep_size)]
        up = []
        for layer in range(layers - 1):
            down.append(unet_down(64 * (2 ** layer), 128 * (2 ** layer), batch_norm, keep_size))
            # Decoder widths mirror the encoder, deepest stage first. The
            # exponent is derived from ``layers`` instead of the constant 3,
            # which was only correct for layers == 5.
            width = 64 * (2 ** (layers - 2 - layer))
            # Keyword arguments: unet_up takes ``batch_norm`` before
            # ``upsample``, so positional passing swapped the two flags.
            up.append(unet_up(2 * width, width, batch_norm=batch_norm,
                              upsample=upsample, keep_size=keep_size))
        up.append(nn.Conv2d(64, class_nums, 1))
        self.down = nn.ModuleList(down)
        self.up = nn.ModuleList(up)
        self._initialize_weights()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Returns the sigmoid of the final 1x1 convolution output.
        """
        skips = []
        for i in range(self.layers):
            x = self.down[i](x)
            skips.append(x)
        # ``x`` is now the deepest feature map; walk back up, pairing each
        # decoder stage with the encoder output one level shallower.
        for j in range(self.layers - 1):
            x = self.up[j](skips[self.layers - j - 2], x)
        # The classification head is always the last entry of ``self.up``.
        x = self.up[self.layers - 1](x)
        return self.sigmoid(x)

    def _initialize_weights(self):
        """Kaiming-uniform init for Conv2d weights; unit/zero init for BatchNorm2d.

        Conv2d biases are zeroed. Modules of other types are left untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # In-place variant; the name without the trailing underscore
                # is deprecated.
                nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

# Site footer text from the source page: "Good web page bookmarks"