PyTorch에 SSIM Looss 설치

  • 이미지 X 및 이미지 Y에서 부분 영역 x 및 y 잘라내기
  • 로컬 영역 x 및 y 내 픽셀 값에 따라 평균\mux와\muy, 표준 편차\sigmax 및\sigmay, 협방차\sigma계산
  • 식(1) 컴퓨팅 스테이션 영역의 SSIM
  • SSIM =\frac{(2\mu_x\mu_y + c_1)(2\mu_{xy} + c_2)}{(\mu_x^2 +\mu_y^2 + c_1)(\sigma_x^2\sigma_y^2 + c_2)}\quad\quad (1)
  • 픽셀별로 xy 방향으로 부분 영역을 슬라이드하여 SSIM을 다시 계산합니다.이미지 크기 256×256, 로컬 영역 크기 64×64의 경우 (256-64+1)\times(256-64+1)=37249회 SSIM 계산(padding 0)이 필요합니다.
  • PyTorch를 사용한 Conv2d 설치

  • 로컬 영역 크기와 같은kerner를 준비하고 PyTorch Conv2d를 사용하여 평균, 분산, 합방차를 계산합니다.
  • 실제 응용에서 이미지 X와 이미지 Y가 매끄러워진 후에 SSIM을 계산하기 때문에 uniform kernel이 아닌gaussian kerel을 사용합니다.
  • 표준 편차는\sigma^2=\overline{x^2}-(\overlinex)^2를 통해 계산한다.다음 장에서 공식을 내보냅니다.
  • 상수는 c1 = (k_1 L)^2、c_2=(k 2L)^2처럼 정의하면 L은 동적 범위(8비트의 경우 255), k1과 k2는 하이퍼매개변수이며 기본값은 0.01과 0.03입니다.
  • import torch
    import torch.nn.functional as F
    from torch import Tensor
    from torch.nn import Module
    
    
    class SSIMLoss(Module):
        def __init__(self, kernel_size: int = 11, sigma: float = 1.5) -> None:
    
            """Computes the structural similarity (SSIM) index map between two images.
    
            Args:
                kernel_size (int): Height and width of the gaussian kernel.
                sigma (float): Gaussian standard deviation in the x and y direction.
            """
    
            super().__init__()
            self.kernel_size = kernel_size
            self.sigma = sigma
            self.gaussian_kernel = self._create_gaussian_kernel(self.kernel_size, self.sigma)
    
        def forward(self, x: Tensor, y: Tensor, as_loss: bool = True) -> Tensor:
    
            if not self.gaussian_kernel.is_cuda:
                self.gaussian_kernel = self.gaussian_kernel.to(x.device)
    
            ssim_map = self._ssim(x, y)
    
            if as_loss:
                return 1 - ssim_map.mean()
            else:
                return ssim_map
    
        def _ssim(self, x: Tensor, y: Tensor) -> Tensor:
    
            # Compute means
            ux = F.conv2d(x, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
            uy = F.conv2d(y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
    
            # Compute variances
            uxx = F.conv2d(x * x, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
            uyy = F.conv2d(y * y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
            uxy = F.conv2d(x * y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
            vx = uxx - ux * ux
            vy = uyy - uy * uy
            vxy = uxy - ux * uy
    
            c1 = 0.01 ** 2
            c2 = 0.03 ** 2
            numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
            denominator = (ux ** 2 + uy ** 2 + c1) * (vx + vy + c2)
            return numerator / (denominator + 1e-12)
    
        def _create_gaussian_kernel(self, kernel_size: int, sigma: float) -> Tensor:
    
            start = (1 - kernel_size) / 2
            end = (1 + kernel_size) / 2
            kernel_1d = torch.arange(start, end, step=1, dtype=torch.float)
            kernel_1d = torch.exp(-torch.pow(kernel_1d / sigma, 2) / 2)
            kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)
    
            kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
            kernel_2d = kernel_2d.expand(3, 1, kernel_size, kernel_size).contiguous()
            return kernel_2d
    

    분산 공식의 변형


    \begin{aligned}
    \sigma^2 &=\frac{(x_1 -\overline x)^2 + (x_2 -\overline x)^2 +\dots + (x_n -\overline x)^2}{n}\\\\
    &=\frac{x_1^2 -2x_1\overline x + (\overline x) ^2 + x_2^2 -2x_2\overline x + (\overline x) ^2 +\dots + x_n^2 -2x_n\overline x + (\overline x) ^2}{n}\\\\
    &=\frac{(x_1^2 + x_2^2 +\dots + x_n^2) - 2\overline x (x_1 + x_2 +\dots + x_n) + n (\overline x)^2}{n}\\\\
    &=\overline {x^2} - 2 (\overline x)^2 + (\overline x)^2\\\\
    &=\overline {x^2} - (\overline x)^2
    \end{aligned}

    References

  • https://github.com/kornia/kornia/blob/master/kornia/losses/ssim.py
  • https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/
  • https://github.com/scikit-image/scikit-image/blob/master/skimage/metrics/_structural_similarity.py
  • https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/metrics/functional/ssim.py
  • https://github.com/pytorch/ignite/blob/master/ignite/metrics/ssim.py
  • 좋은 웹페이지 즐겨찾기