PRML4장 해설과 구현

PRML 학습기



이번 「패턴 인식과 기계 학습」4장의 윤강 발표 담당이 되었으므로, 공부한 일이나 약간의 해설 등을 써 가고 싶다. 자신도 이 책에 고전한 사람 중 하나이므로 앞으로 비슷한 처지 사람이 있을 때 도움이 되면 매우 기쁩니다. 만약 수리적인 오류 등을 찾아내거나, 좀더 이러한 쪽이 좋다고 하는 지적이 있으면 삼가해 주시면 도움이 됩니다.

피셔의 선형 판별



2 클래스



식별 함수의 항은 최소 제곱부터 시작되고 있지만 원래 최소 제곱은 "잘 사용할 수 없는 것은 당연하다"는 결론이므로 할애. 그래서 2 클래스 피셔에서. 여기서는 선형 식별을 차원 삭감의 관점에서 본다.

입력으로서 D 차원 벡터를 얻고, 다음 식으로 1 차원으로 투영
y = \boldsymbol{w}^T\boldsymbol{x}

$y$에 임계값을 설정하여 $y\ge -w_0$ 일 때 클래스 $C_1$로 분류하지 않을 때는 $C_2$로 분류한다. 차원을 떨어뜨린 분 정보의 손실이 발생하기 때문에 $\boldsymbol{w}$를 조정해 클래스의 분리를 최대로 해 나가고 싶다.

여기서 클래스 $C_1$의 점이 $N_1$개이고 $C_2$의 점이 $N_2$개 있다면 각 클래스의 평균 벡터는
\boldsymbol{m}_1 = \frac{1}{N_1}\sum_{n \in C_1}\boldsymbol{x}_n, \quad
\boldsymbol{m}_2 = \frac{1}{N_2}\sum_{n \in C_2}\boldsymbol{x}_n

이때, 「클래스의 평균끼리가 가장 멀어지는 곳에 투영하자」라고 하는 생각에 근거해, 이하의 식을 최대로 하는 $\boldsymbol{w}$를 선택
m_2 - m_1 = \boldsymbol{w}^T(\boldsymbol{m}_2 - \boldsymbol{m}_1)


여기서, $ m_k $는 $ C_k $로부터 투영 된 데이터의 평균을 나타낸다. $\boldsymbol{w}$를 얼마든지 크게 할 수 있다면 의미가 없기 때문에 놈=1이라는 제약을 더한다. 이른바 라그랑주의 미정승수법의 차례. 벡터의 미분의 기초를 알고 있으면 아무 문제도 없다.
L = \boldsymbol{w}^T(\boldsymbol{m}_2 - \boldsymbol{m}_1) + \lambda(\boldsymbol{w}^T\boldsymbol{w}-1)\\
\\
\nabla L=(\boldsymbol{m}_2 - \boldsymbol{m}_1)+2\lambda\boldsymbol{w}\\
\\
\boldsymbol{w}=-\frac{1}{2\lambda}(\boldsymbol{m}_2 - \boldsymbol{m}_1)\propto(\boldsymbol{m}_2 - \boldsymbol{m}_1)

다만 실제 이것으로는 아직 클래스끼리가 겹쳐 버리는 경우가 있다. 그래서 "사영 후에 같은 클래스는 정리되어 있고, 클래스는 서로 떨어져있다"같은 방법을 취하고 싶다. 그래서 피셔의 판별 기준을 도입. 각 클래스의 클래스 내 분산은
s_k^2 = \sum_{n \in C_k}(y_k - m_k)^2

따라서 판별 기준은 다음과 같습니다.
J(\boldsymbol{w}) = \frac{(m_2-m_1)^2}{s_1^2 + s_2^2}

분모는 총 클래스내 분산으로, 각 클래스의 분산의 합으로 정의. 분자는 클래스간 분산. 본 절에서는 이것을 다음과 같이 재작성하고 있다.
J(\boldsymbol{w}) = \frac{\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w}}{\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w}}

여기에서
\boldsymbol{S}_\boldsymbol{B} = (\boldsymbol{m}_2 - \boldsymbol{m}_1)(\boldsymbol{m}_2 - \boldsymbol{m}_1)^T\\
\\
\boldsymbol{S}_\boldsymbol{W} =\sum_{k}\sum_{n\in C_k}(\boldsymbol{x}_n-m_k)(\boldsymbol{x}_n-m_k)
^T

전자는 클래스 간 공분산 행렬, 후자는 총 클래스 내 공분산 행렬이라고 불린다. 자신에게는 조금 설치하기 어려운 외형을 하고 있어서 당황했지만, 분모도 분자도 $y=\boldsymbol{w}^T\boldsymbol{x}$인 것을 이용해 전개해 보면 원래의 수식과 같음을 알 수 있습니다.

따라서, J(w)를 w에 관해서 미분하여 제로로 하는 것으로, J가 최대가 되는 w를 구할 수 있다.

\frac{\partial J}{\partial w}=\frac{(2(\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w})(\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w})-2(\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w})(\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w}))}{(\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w})^2}=0\\
\\\\
(\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w})(\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w}) = (\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w})(\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w})

$\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w}$가 스칼라이고 2차 형식을 미분할 때 공분산 행렬이 대칭 행렬임 이용하고 있는 것이 포인트입니다. 이에 대해서는 이제 다른 기사에 씁니다.

앞에서와 같이 이번에도 중요한 것은 $\boldsymbol{w}$의 방향이므로 크기가 아니므로 $\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w}$
\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w} = (\boldsymbol{m}_2 - \boldsymbol{m}_1)(\boldsymbol{m}_2 - \boldsymbol{m}_1)^T\boldsymbol{w}

보다 $(\boldsymbol{m}_2 -\boldsymbol{m}_1)$와 같은 방향의 벡터인 것을 이용해
\boldsymbol{w} \propto \boldsymbol{S}_\boldsymbol{W}^-1(\boldsymbol{m}_2 - \boldsymbol{m}_1)

이것으로 w의 방향이 정해졌으므로 버려!

번외편:코드로 해 보았습니다



fisher_2d.py
# Class 1
mu1 = [5, 5]
sigma = np.eye(2, 2)
c_1 = np.random.multivariate_normal(mu1, sigma, 100).T

# Class 2
mu2 = [0, 0]
c_2 = np.random.multivariate_normal(mu2, sigma, 100).T

# Average vectors
m_1 = np.sum(c_1, axis=1, keepdims=True) / 100.
m_2 = np.sum(c_2, axis=1, keepdims=True) / 100.

# within-class covariance matrix 
S_W = np.dot((c_1 - m_1), (c_1 - m_1).T) + np.dot((c_2 - m_2), (c_2 - m_2).T)

w = np.dot(np.linalg.inv(S_W), (m_2 - m_1))
w = w/np.linalg.norm(w)

plt.quiver(4, 2, w[1, 0], -w[0, 0], angles="xy", units="xy", color="black", scale=0.5)
plt.scatter(c_1[0, :], c_1[1, :])
plt.scatter(c_2[0, :], c_2[1, :])

quiver를 사용하여 구한 벡터를 플롯한 결과가 이쪽



방향은 좋은 느낌이군요. 그렇다고 해서 다음번은 다클래스판을 기사로 하려고 합니다.

좋은 웹페이지 즐겨찾기