[Pytorch]: 네트워크 레이어 사용자 정의

6006 단어 Caffe

사용자 정의 후방 전파


forward () 네트워크 층의 계산 조작은 당신의 필요에 따라 파라미터를 설정할 수 있습니다.backward () 계단식 계산 작업.

Function 상속

class LinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_variables
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias

몇 개의 입력이 있으면 해당 입력의 사다리 출력, 즉grad_input. 권한 값의 사다리는 반드시 되돌아와야 한다. 되돌아오는 순서는 입력과 일치해야 한다. 즉grad_input_n、···、grad_input_n,grad_weight 및grad_bias.

계단식 테스트

from torch.autograd import gradcheck

input = (Variable(torch.randn(20,20).double(), requires_grad=True), Variable(torch.randn(30,20).double(), requires_grad=True),)
test = gradcheck(Linear.apply, input, eps=1e-6, atol=1e-4)
print(test)

모듈 추가


토치를 넓히다.nn에서 새 Module을 생성해야 합니다.
class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameter  Module , .register_buffer() , 。
        self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(output_features))
        else:
            #  , 。
            self.register_parameter('bias', None)

        #  
        self.weight.data.uniform_(-0.1, 0.1)
        if bias is not None:
            self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        #  
        return LinearFunction.apply(input, self.weight, self.bias)

좋은 웹페이지 즐겨찾기