심층 학습/행렬 곱의 오차 역전파

1. 소개



행렬적의 오차 역전파가 알기 어려웠으므로, 정리해 둔다.

2. 스칼라 제품의 오차 역전파



스칼라 곱의 오차 역전파로부터 복습하면,

경사를 실시하는 대상을 L로 하고, 미리 $\frac{\partial L}{\partial y}$를 알고 있으면 연쇄율로부터

이건 문제 없네요.

3. 행렬 곱의 오차 역 전파



그런데, 행렬적이 되면 직감과 바뀌어 옵니다.


왠지, 핀과 오지 않지요. 그래서 구체적으로 확인합니다.
설정은 2개의 뉴런 X와 4개의 가중치 W의 내적을 거쳐 뉴런 Y에 접속되어 있다고 생각합니다.


1) 먼저 $\frac{\partial L}{\partial X}$를 찾습니다. 우선, 이것들을 미리 계산해 둡니다.

이 계산을 도중에 이용하면서,


2) 그런 다음 $\frac{\partial L}{\partial y}$를 찾습니다. 우선, 이것들을 미리 계산해 둡니다.

이 계산을 도중에 이용하면서,


4. 행렬 곱 순서 전파, 역 전파 코드



x1=X, x2=Y, grad=$\frac{\partial L}{\partial y}$
class MatMul(object):
    def __init__(self, x1, x2):
        self.x1 = x1
        self.x2 = x2

    def forward(self):
        y = np.dot(self.x1, self.x2)
        self.y = y
        return y

    def backward(self, grad):
        grad_x1 = np.dot(grad, self.x2.T)
        grad_x2 = np.dot(self.x1.T, grad)
        return (grad_x1, grad_x2)

좋은 웹페이지 즐겨찾기