BNN Pytorch 코드 읽기 노트

BNN Pytorch 코드 읽기 노트


이 블로그는 제가 BNN(이치화신경망)pytorch 코드에 대한 이해를 쓰도록 하겠습니다. 저는 프로젝트 코드를 처음 읽기 때문에 꼼꼼하게 직접 써서 디테일을 철저하게 이해하고 여러분께 도움이 되었으면 좋겠습니다.
논문 링크:https://papers.nips.cc/paper/6573-binarized-neural-networks
코드 링크:https://github.com/itayhubara/BinaryNet.pytorch

1. 프로젝트 구조:


models 네트워크 구조 구축 스크립트 집합
init.py 초기화 스크립트
alexnet.py alexnet pytorch 버전 구현
alexnet_binary.py가alexnet에 대한 이치화 코드 구현
binarized_modules.py 양적 함수 실현
resnet.py resnet pytorch 버전 구현
resnet_binary.py가resnet에 대한 이치화 코드 실현
vgg_cifar10.py vggnet pytorch 버전 구현
vgg_cifar10_binary.py vggnet에 대한 이치화 코드 구현
data.py 데이터 읽기 스크립트
main_binary.py 트레이닝 + 테스트 스크립트
main_binary_hinge.py훈련+테스트+hingeloss 스크립트
main_mnist.py MNIST 데이터 세트 트레이닝 + 테스트 스크립트
preprocess.py 데이터 사전 처리 관련 스크립트
utils.py 매개 변수 기록 로그 스크립트

2. 코드 판독


2.1 main_binary.py & data.py & utils.py & preprocess.py


먼저 도입 모듈:
import argparse
import os
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import models
from torch.autograd import Variable
from data import get_dataset
from preprocess import get_transform
from utils import *
from datetime import datetime
from ast import literal_eval
from torchvision.utils import save_image

이 ARgparse 패키지, logging 패키지,ast 패키지는 이전에 접촉한 적이 없습니다. Google은 다음 코드와 직접 결합해서 말합니다.
Argparse는python 표준 라이브러리에서 명령행 파라미터를 처리하는 라이브러리입니다. 말하자면 명령행을 쓰는 라이브러리입니다. 프로젝트 공학 코드에서 이전에 코드를 쓴 것처럼 버튼 하나로 코드를 완성할 수 없습니다. 필요한 명령행 코드를 배우는 것은 필요합니다. 코드를 결합시켜 분석합시다.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

이 단계는 모델 폴더에 있는 스크립트 파일을 정렬해서 모델에 봉인하는 것입니다names에서 뒤에 있는 매개 변수를 추가하는 데 사용합니다.
parser = argparse.ArgumentParser(description='PyTorch ConvNet Training')

분명히 이 구절은argparse를 초기화하는 대상이다.
parser.add_argument
parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results',
                    help='results dir')
...
parser.add_argument('-e', '--evaluate', type=str, metavar='FILE',
                    help='evaluate model FILE on validation set')

이 단락은 대량의 중복 형식 코드,parser.add_argument:
ArgumentParser.add_argument(name or flags…[, action][, nargs][, const][, default][, type][, choices][, required][, help][, metavar][, dest]) 

name of flags는 필수 매개 변수입니다. 이 매개 변수는 옵션 매개 변수나 위치 매개 변수를 수락합니다.예를 들어 위의 "–results"dir", 시작 프로그램mainbinary.py시,./main_binary.py --results_dir xxx, xxx를resultsdir, 다음에 중복된 것은 군더더기 없이 매개 변수가 없을 때default에서 값을 가져옵니다.
def main():
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()if args.evaluate:
        args.results_dir = '/tmp'
    if args.save is '':
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

만약 매개 변수에evaluate 항목이 있다면 결과 디렉터리 앞에/tmp를 붙여서 임시 저장을 표시합니다. 이 코드는 훈련 결과를 저장할 수 있는 경로입니다.
logging 패키지
logging 모듈은 Python에 내장된 표준 모듈로 주로 출력 실행 로그에 사용되며 출력 로그의 등급, 로그 저장 경로, 로그 파일 스크롤 등을 설정할 수 있다.print에 비해 다음과 같은 이점이 있습니다.
1). 서로 다른 로그 등급을 설정하여release 버전에서 중요한 정보만 출력할 수 있고 대량의 디버깅 정보를 표시할 필요가 없다.
2). print는 모든 정보를 표준 출력에 출력하여 개발자가 표준 출력에서 다른 데이터를 보는 데 심각한 영향을 미친다.logging은 개발자가 정보를 어디로 출력하고 어떻게 출력하는지 결정할 수 있다.
setup_logging(os.path.join(save_path, 'log.txt'))
results_file = os.path.join(save_path, 'results.%s')
results = ResultsLog(results_file % 'csv', results_file % 'html')
​
logging.info("saving to %s", save_path)
logging.debug("run arguments: %s", args)
setup_logging      utils.py   :

def setup_logging(log_file='log.txt'):
    """Setup logging configuration
    """
    logging.basicConfig(level=logging.DEBUG,
                        format="%(asctime)s - %(levelname)s - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename=log_file,
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

ResultsLog 클래스도 utils에 있습니다.py에서 구현:
class ResultsLog(object):def __init__(self, path='results.csv', plot_path=None):
        self.path = path
        self.plot_path = plot_path or (self.path + '.html')
        self.figures = []
        self.results = Nonedef add(self, **kwargs):
        df = pd.DataFrame([kwargs.values()], columns=kwargs.keys())
        if self.results is None:
            self.results = df
        else:
            self.results = self.results.append(df, ignore_index=True)def save(self, title='Training Results'):
        if len(self.figures) > 0:
            if os.path.isfile(self.plot_path):
                os.remove(self.plot_path)
            output_file(self.plot_path, title=title)
            plot = column(*self.figures)
            save(plot)
            self.figures = []
        self.results.to_csv(self.path, index=False, index_label=False)def load(self, path=None):
        path = path or self.path
        if os.path.isfile(path):
            self.results.read_csv(path)def show(self):
        if len(self.figures) > 0:
            plot = column(*self.figures)
            show(plot)#def plot(self, *kargs, **kwargs):
    #    line = Line(data=self.results, *kargs, **kwargs)
    #    self.figures.append(line)def image(self, *kargs, **kwargs):
        fig = figure()
        fig.image(*kargs, **kwargs)
        self.figures.append(fig)

로그에 저장된 경로와 모든 입력 매개 변수를 출력합니다.
if 'cuda' in args.type:
    args.gpus = [int(i) for i in args.gpus.split(',')]
    torch.cuda.set_device(args.gpus[0])
    cudnn.benchmark = True
else:
    args.gpus = None

여기는 cuda gpu 같은 것 같으니 잠시 건너가세요.
다음은 모델 생성을 시작합니다.
    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}if args.model_config is not '':
        model_config = dict(model_config, **literal_eval(args.model_config))
​
    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

ast 패키지
위의literaleval 함수는ast 패키지의 함수입니다. 간단하게 말하면ast 모듈은Python 응용을 도와 추상적인 문법 해석을 처리하는 것입니다.이 모듈의 literaleval () 함수: 계산이 필요한 내용을 계산한 후 합법적인python 형식인지 판단하고, 만약 그렇다면 연산을 하고, 그렇지 않으면 연산을 하지 않습니다.
이전 코드의 끝에서 모델이 호출되었고 앞에 있는 코드에서 모델은 다음과 같이 정의됩니다.
 model = models.__dict__[args.model]

모델을 호출하면 이 모델의 모든 속성을 되돌려줍니다.
    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            args.start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)
num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)

이 부분은 이미 있는 인자를 불러옵니다. (있으면)
다음은 데이터 로드 섹션입니다.
    # Data loading code
    default_transform = {
        'train': get_transform(args.dataset,
                               input_size=args.input_size, augment=True),
        'eval': get_transform(args.dataset,
                              input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(model, 'regime', {0: {'optimizer': args.optimizer,
                                           'lr': args.lr,
                                           'momentum': args.momentum,
                                           'weight_decay': args.weight_decay}})

이 부분의 함수는preprocess에 있습니다.py에서 실현되면 구체적인 것은 쓰지 않는다. 왜냐하면 이것은 데이터 처리 블로그가 아니기 때문이다.
    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)
    model.type(args.type)

이 부분은 손실 함수의 정의다.
val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return
​
    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
​
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    logging.info('training regime: %s', regime)

데이터 세트 로드 및 최적화 방법 섹션에서 다음 훈련을 시작합니다.
    for epoch in range(args.start_epoch, args.epochs):
        optimizer = adjust_optimizer(optimizer, epoch, regime)# train for one epoch
        train_loss, train_prec1, train_prec5 = train(
            train_loader, model, criterion, epoch, optimizer)# evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(
            val_loader, model, criterion, epoch)# remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
​
        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'regime': regime
        }, is_best, path=save_path)
        logging.info('
Epoch: {0}\t'
'Training Loss {train_loss:.4f} \t' 'Training Prec@1 {train_prec1:.3f} \t' 'Training Prec@5 {train_prec5:.3f} \t' 'Validation Loss {val_loss:.4f} \t' 'Validation Prec@1 {val_prec1:.3f} \t' 'Validation Prec@5 {val_prec5:.3f}
'
.format(epoch + 1, train_loss=train_loss, val_loss=val_loss, train_prec1=train_prec1, val_prec1=val_prec1, train_prec5=train_prec5, val_prec5=val_prec5)) ​ results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss, train_error1=100 - train_prec1, val_error1=100 - val_prec1, train_error5=100 - train_prec5, val_error5=100 - val_prec5) #results.plot(x='epoch', y=['train_loss', 'val_loss'], # title='Loss', ylabel='loss') #results.plot(x='epoch', y=['train_error1', 'val_error1'], # title='Error@1', ylabel='error %') #results.plot(x='epoch', y=['train_error5', 'val_error5'], # title='Error@5', ylabel='error %') results.save()

코드가 길지 않아 모두 이해할 수 있는데 주로 아래의 두 문장과 그들이 가지고 나온 함수를 해석한 것이다.
 # train for one epoch
        train_loss, train_prec1, train_prec5 = train(
            train_loader, model, criterion, epoch, optimizer)# evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(
            val_loader, model, criterion, epoch)
def train(data_loader, model, criterion, epoch, optimizer):
    # switch to train mode
    model.train()
    return forward(data_loader, model, criterion, epoch,
                   training=True, optimizer=optimizer)
​
​
def validate(data_loader, model, criterion, epoch):
    # switch to evaluate mode
    model.eval()
    return forward(data_loader, model, criterion, epoch,
                   training=False, optimizer=None)

다음은 forward 함수에 대한 해독입니다.
def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None):
    if args.gpus and len(args.gpus) > 1:
        model = torch.nn.DataParallel(model, args.gpus)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

이 Average Meter ()는 무엇입니까,utils에 있습니다.py에서 찾았습니다:
class AverageMeter(object):
    """Computes and stores the average and current value"""def __init__(self):
        self.reset()def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
​
__optimizers = {
    'SGD': torch.optim.SGD,
    'ASGD': torch.optim.ASGD,
    'Adam': torch.optim.Adam,
    'Adamax': torch.optim.Adamax,
    'Adagrad': torch.optim.Adagrad,
    'Adadelta': torch.optim.Adadelta,
    'Rprop': torch.optim.Rprop,
    'RMSprop': torch.optim.RMSprop
}

저자의 주석은 이미 매우 분명하게 말했다: Computes and stores the average and current value.
end = time.time()
    for i, (inputs, target) in enumerate(data_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        if args.gpus is not None:
            target = target.cuda(async=True)
        input_var = Variable(inputs.type(args.type), volatile=not training)
        target_var = Variable(target)
        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)
        if type(output) is list:
            output = output[0]
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], inputs.size(0))
        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))

이 코드는 데이터로부터loader에서 데이터를 읽고 모델에 입력하여 정확도를 계산하고loss값을 업데이트합니다.
optimizer.step()
        if training:
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            for p in list(model.parameters()):
                if hasattr(p,'org'):
                    p.data.copy_(p.org)
            optimizer.step()
            for p in list(model.parameters()):
                if hasattr(p,'org'):
                    p.org.copy_(p.data.clamp_(-1,1))

이 단계에서 먼저 역방향 전파를 한 다음에 마침내 BNN과 관련된 곳에 도착했다. (이 블로그를 쓴 취지를 잊어버린 것은 말할 것도 없다) 여기서 역방향 전파가 사다리를 계산한 후에 사다리를 갱신하기 전에 모델 파라미터를 원래의 정밀도로 복원하고 새로워진 후에 파라미터를 (-1,1) 구간에 제한한 것을 발견했다.
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()if i % args.print_freq == 0:
            logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                         'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                             epoch, i, len(data_loader),
                             phase='TRAINING' if training else 'EVALUATING',
                             batch_time=batch_time,
                             data_time=data_time, loss=losses, top1=top1, top5=top5))return losses.avg, top1.avg, top5.avg

이 부분은 훈련 메시지를 저장하고 로그에 저장한 코드로 더 이상 상세하게 설명하지 않는다.
먼저 여기까지 쓰고 나머지는 내일 쓰세요~

2.2 models


2.2.1 binarized_modules.py

def Binarize(tensor,quant_mode='det'):
    if quant_mode=='det':
        return tensor.sign()
    else:
        return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)

이치화 함수, 입력 벡터, 반환 이치화 벡터는 논문에서 기술한 바와 같이 무작위 이치화와 확정 이치화를 포함한다.
def Quantize(tensor,quant_mode='det',  params=None, numBits=8):
    tensor.clamp_(-2**(numBits-1),2**(numBits-1))
    if quant_mode=='det':
        tensor=tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))
    else:
        tensor=tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))
        quant_fixed(tensor, params)
    return tensor

import torch.nn._functions as tnnf

양적 함수, 벡터를 지정한 정밀도로 양적화합니다.
class BinarizeLinear(nn.Linear):

    def __init__(self, *kargs, **kwargs):
        super(BinarizeLinear, self).__init__(*kargs, **kwargs)

    def forward(self, input):

        if input.size(1) != 784:
            input.data=Binarize(input.data)
        if not hasattr(self.weight,'org'):
            self.weight.org=self.weight.data.clone()
        self.weight.data=Binarize(self.weight.org)
        out = nn.functional.linear(input, self.weight)
        if not self.bias is None:
            self.bias.org=self.bias.data.clone()
            out += self.bias.view(1, -1).expand_as(out)

        return out

2치화 전체 연결층, 1층 입력은 양적화하지 않고 권중을 양적화하지만 초기 권중을 보존한다.
class BinarizeConv2d(nn.Conv2d):

    def __init__(self, *kargs, **kwargs):
        super(BinarizeConv2d, self).__init__(*kargs, **kwargs)


    def forward(self, input):
        if input.size(1) != 3:
            input.data = Binarize(input.data)
        if not hasattr(self.weight,'org'):
            self.weight.org=self.weight.data.clone()
        self.weight.data=Binarize(self.weight.org)

        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)

        if not self.bias is None:
            self.bias.org=self.bias.data.clone()
            out += self.bias.view(1, -1, 1, 1).expand_as(out)

        return out

이치화 권적층은 1층의 입력을 양화하지 않고 권중을 양화하지만 초기 권중을 보존한다.

2.2.2 alexnet.py & alexnet_binary.py


2.2.3 resnet.py & resnet_binary.py


2.2.4 vgg_cifar10.py & vgg_cifar10_binary.py


위의 세 부분은 쓸 필요가 없다. 바로 일반적인 네트워크와 조금 다르다. 함수를 활성화하는 부분에 Hardtanh를 사용한다. 이 부분의 작용은sign 함수를 이완시키는 것이다. 그렇지 않으면 사다리가 모두 0이어서 역방향으로 전파할 수 없다.

3.Conclusion


BNN은 인터넷 압축 양적 측면에서 아주 고전적인 창작품입니다. 논문을 읽고 나서 한 편의 코드를 자세히 연구한 결과 많은 것을 얻었습니다. 하지만 이 논문의 재현 난이도가 매우 낮습니다. 제가 코드를 읽는 것도 프로젝트 구조를 배우는 것일 뿐입니다. 그래서 양적 분야의 논문 코드는 많이 봐야 합니다. 이것은 시작일 뿐입니다. 계속 힘내세요.

좋은 웹페이지 즐겨찾기