chainer의 connection을 괴롭히고 새로운 층을 만든다 (1)
환경
GPU GTX1070
우분투 14.04
chainer 1.14.0
등
소개
chainer에서 최신 모델을 구현할 때는 links/connection이나 functions/connection을 괴롭힐 필요가 있다.
그래서 가장 단순한 linear.py를 만나 새로운 레이어를 만들어 보자. 지난번은 linear.py의 내용을 확인했다.
ぃ tp // 이 m / 설마 46 / ms / d66997 아 c94에 c7 아 3bcb4
이번은
chainer/functions/connection/linear.py
의 forward 함수를 만져 순전파를 개량한다.개선 모델 개요
원래의 전체 결합 3 층은 아래 그림과 같이 개선됩니다.
2층째만을 개량한다. 이 두 번째 레이어는 구체적으로 다음과 같이 입력 측에 대해 가중치를 공유합니다.
이 연산 처리는 이하의 도면과 같이된다.
W는 입력측 n개로 가중치를 공유하므로, W(out_size, in_size/n)가 된다. 이 가중치를 한 행렬 곱으로 계산할 수 있도록 in_size/n 측을 n 배하여 in_size로한다.
이 가중치와 입력측으로부터의 데이터 x의 행렬 곱을 구하면, y (batch_size, out_size)가 출력된다. 이렇게하면 가중치의 매개 변수 수가 줄어들어 계산이 빨라집니다. 그리고 성능이 약간 떨어질 것입니다. 또 이번, 계산을 간략화하기 위해 bias는 사용하지 말아 둔다.
tain_mnist.py 수정
train_mnist.py도 약간 바뀌므로 수정한다.
common_num = 10
out_units = 900
#chnged model
def __init__(self, n_in, n_units, n_out):
super(MLP, self).__init__(
l1=L.Linear(n_in, n_units), # first layer
l2=linear_link.Linear(n_units / common_num, out_units, nobias=True), # second layer
l3=L.Linear(out_units, n_out), # output layer
)
전역 변수로 정의한 common_num이 공유하는 수. 또 특히 의미는 없지만 3층째의 unit수를 900으로 바꾸었다.
function 아래 linear.py 수정
chainer/functions/connection/linear.py
의 LinearFunction 클래스 내 forward 함수를 다음과 같이 수정한다.import cupy
#modified forward function
def forward(self, inputs):
x = _as_mat(inputs[0])
W = inputs[1]
#modify to original model
W_tile = cupy.tile(W.T, (common_num, 1)).astype(W.dtype, copy=False)
y = x.dot(W_tile).astype(x.dtype, copy=False)
if len(inputs) == 3:
b = inputs[2]
y += b
return y,
W를 common_num 배배할 때 cupy(numpy)의 tile()을 사용했다. GPU 사용을 상정하고 cupy를 import하고 있지만, 사용하지 않으면 numpy로 바꿀 필요가 있다.
또 check_type_forward 함수가 있으면 에러가 나오므로, 코멘트 아웃한다.
'''
def check_type_forward(self, in_types):
n_in = in_types.size()
type_check.expect(2 <= n_in, n_in <= 3)
x_type, w_type = in_types[:2]
type_check.expect(
x_type.dtype.kind == 'f',
w_type.dtype.kind == 'f',
x_type.ndim >= 2,
w_type.ndim == 2,
type_check.prod(x_type.shape[1:]) == w_type.shape[1],
)
if n_in.eval() == 3:
b_type = in_types[2]
type_check.expect(
b_type.dtype == x_type.dtype,
b_type.ndim == 1,
b_type.shape[0] == w_type.shape[0],
)
'''
이 함수가 절개에도 x와 W의 크기가 대응하고 있는지 조사하고 있는 것 같다. 이번, 분명히 W만 작게 하고 있으므로, 이것이 기능하면 에러가 된다.
Reference
이 문제에 관하여(chainer의 connection을 괴롭히고 새로운 층을 만든다 (1)), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://qiita.com/masataka46/items/1a5d6cbd49279aeaf734텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)