기계 학습에서 자주 나오는 함수 오버플로 방지

소개



시그모이드나 소프트 플러스 등의 지수함수를 포함한 함수를 스스로 구현하면 지수함수부에서 오버플로우가 되어 올바르게 계산하지 못할 수 있습니다.

함수에 따라서는 간단한 궁리로 막을 수 있으므로 소개합니다.

시그모이드 함수



시그모이드 함수는
f(x) = \frac{1}{1+\exp(-x)}

로 표시되고 $(-1,1)$의 값 영역을 취합니다.

(이미지는 Wikipedia 보다 인용)

$x$가 마이너스 방향으로 너무 가면 이론적으로는 $f(x)\simeq 0$이지만 $\exp(-x)$가 오버플로우합니다.

$ x < 0 $ 일 때 분모 분자에 $\exp (x) $를 곱하여 변형합니다.
f(x)=
\begin{cases}
\displaystyle\frac{1}{1+\exp(-x)} & \text{if $x\geq 0$} \\
\displaystyle\frac{\exp(x)}{\exp(x)+1} & \text{if $x< 0$}
\end{cases}

그리고 경우 나누면 항상 올바르게 계산할 수 있습니다.
또한 사례를 정리하면
f(x) = \frac{\exp(\min(0,x))}{1+\exp(-|x|)}

수 있습니다.

소프트 플러스 함수



소프트 플러스 함수는
f(x) = \log \left( 1+\exp(x) \right)

로 표시됩니다.

(이미지는 Wikipedia 보다 인용)

이 함수도 $x$가 너무 커지면 $\exp(x)$가 오버플로됩니다.

$x\geq 0$일 때는 $\log$의 내용을 $\exp(x)$로 묶고 $\log\exp(x)=x$인 것을 이용하여 $\exp(x) $를 지웁니다.
\begin{align}
f(x) &= \log\left\{ \exp(x)(\exp(-x)+1) \right\} \\
&= \log\exp(x) + \log(\exp(-x)+1) \\
&= x+\log(\exp(-x)+1)
\end{align}

즉,
f(x) = 
\begin{cases}
\log(1+\exp(x)) & \text{if $x<0$} \\
x+\log(\exp(-x)+1) & \text{if $x\geq 0$}
\end{cases}

을 계산하면 좋을 것입니다.

또, 경우 분할을 잘 정리하면
f(x) = \max (0,x) + \log(1+\exp(-|x|))

라고 쓸 수 있습니다.

소프트 맥스 함수



소프트 맥스 함수는 위와 달리 벡터를 받고 벡터를 반환합니다.
입력을 ${\bf x} = [x_1,\ldots, x_N]$
{\bf f}({\bf x}) = \left[ \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right]_{i=1}^N

로 표시됩니다.

입력 ${\bf x}$ 중 가장 큰 값을 가진 요소
x^{\rm m}=\max(x_1,\ldots,x_N)

그런 다음 분모 분자를 $\exp(x^{\rm m})$로 나눈
{\bf f}({\bf x}) = \left[ \frac{\exp(x_i-x^{\rm m})}{\sum_j \exp(x_j - x^{\rm m})} \right]_{i=1}^N

를 계산하여 오버플로를 피할 수 있습니다.

결론



모두 지수 함수의 인수가 0 이하가 되는 것이 포인트입니다.

좋은 웹페이지 즐겨찾기