chainer의 connection을 괴롭히고 새로운 층을 만든다 (1)

환경



GPU GTX1070
우분투 14.04
chainer 1.14.0


소개



chainer에서 최신 모델을 구현할 때는 links/connection이나 functions/connection을 괴롭힐 필요가 있다.

그래서 가장 단순한 linear.py를 만나 새로운 레이어를 만들어 보자. 지난번은 linear.py의 내용을 확인했다.
ぃ tp // 이 m / 설마 46 / ms / d66997 아 c94에 c7 아 3bcb4

이번은 chainer/functions/connection/linear.py 의 forward 함수를 만져 순전파를 개량한다.

개선 모델 개요



원래의 전체 결합 3 층은 아래 그림과 같이 개선됩니다.

2층째만을 개량한다. 이 두 번째 레이어는 구체적으로 다음과 같이 입력 측에 대해 가중치를 공유합니다.

이 연산 처리는 이하의 도면과 같이된다.

W는 입력측 n개로 가중치를 공유하므로, W(out_size, in_size/n)가 된다. 이 가중치를 한 행렬 곱으로 계산할 수 있도록 in_size/n 측을 n 배하여 in_size로한다.

이 가중치와 입력측으로부터의 데이터 x의 행렬 곱을 구하면, y (batch_size, out_size)가 출력된다. 이렇게하면 가중치의 매개 변수 수가 줄어들어 계산이 빨라집니다. 그리고 성능이 약간 떨어질 것입니다. 또 이번, 계산을 간략화하기 위해 bias는 사용하지 말아 둔다.

tain_mnist.py 수정



train_mnist.py도 약간 바뀌므로 수정한다.
common_num = 10
out_units = 900
    #chnged model
    def __init__(self, n_in, n_units, n_out):
        super(MLP, self).__init__(
            l1=L.Linear(n_in, n_units),  # first layer
            l2=linear_link.Linear(n_units / common_num, out_units, nobias=True),  # second layer
            l3=L.Linear(out_units, n_out),  # output layer
        )

전역 변수로 정의한 common_num이 공유하는 수. 또 특히 의미는 없지만 3층째의 unit수를 900으로 바꾸었다.

function 아래 linear.py 수정


chainer/functions/connection/linear.py의 LinearFunction 클래스 내 forward 함수를 다음과 같이 수정한다.
import cupy
    #modified forward function
    def forward(self, inputs):
        x = _as_mat(inputs[0])
        W = inputs[1]

        #modify to original model
        W_tile = cupy.tile(W.T, (common_num, 1)).astype(W.dtype, copy=False)
        y = x.dot(W_tile).astype(x.dtype, copy=False)

        if len(inputs) == 3:
            b = inputs[2]
            y += b

        return y,


W를 common_num 배배할 때 cupy(numpy)의 tile()을 사용했다. GPU 사용을 상정하고 cupy를 import하고 있지만, 사용하지 않으면 numpy로 바꿀 필요가 있다.

또 check_type_forward 함수가 있으면 에러가 나오므로, 코멘트 아웃한다.
    '''
    def check_type_forward(self, in_types):
        n_in = in_types.size()
        type_check.expect(2 <= n_in, n_in <= 3)
        x_type, w_type = in_types[:2]

        type_check.expect(
            x_type.dtype.kind == 'f',
            w_type.dtype.kind == 'f',
            x_type.ndim >= 2,
            w_type.ndim == 2,
            type_check.prod(x_type.shape[1:]) == w_type.shape[1],
        )
        if n_in.eval() == 3:
            b_type = in_types[2]
            type_check.expect(
                b_type.dtype == x_type.dtype,
                b_type.ndim == 1,
                b_type.shape[0] == w_type.shape[0],
            )
    '''

이 함수가 절개에도 x와 W의 크기가 대응하고 있는지 조사하고 있는 것 같다. 이번, 분명히 W만 작게 하고 있으므로, 이것이 기능하면 에러가 된다.

좋은 웹페이지 즐겨찾기