Pruning a PyTorch ResNet50 Classification Model

ResNet50


Network structure: https://www.jianshu.com/p/993c03c22d52

Pruning


1. Method based on the network-slimming paper. PyTorch implementation: https://github.com/Eric-mingjie/network-slimming
Idea: remove the BN layers inside the downsample branches and, for convenience, adopt the ResNet v2 ordering BN-Conv-ReLU; a custom channel selection layer is added after the first BN of each bottleneck. So that training is unaffected, pruning first generates a channel mask for each BN, assigns the mask to the channel selection layer, takes the surviving channels of that BN as the input channels of the following Conv, and uses the mask of the next BN to select that Conv's output channels. Looping layer by layer in this way yields the pruned network, which is then fine-tuned or trained from scratch. (A sketch of the channel selection layer follows below.)

2. Method based on the recently released Torch-Pruning tool: https://github.com/VainF/Torch-Pruning
Idea: before pruning, build the dependency relations between all layers of the network. Using torch's hooks mechanism, the grad_fn of each module is captured during forward propagation and a node is built per module; each node holds the module, its grad_fn, inputs, outputs, and dependencies, i.e. the layers (operations) that the node's inputs and outputs depend on. When pruning executes, the channels of all dependent layers are adjusted automatically according to these relations. (A minimal hook illustration also follows below.)
Hooks mechanism: https://cloud.tencent.com/developer/article/1122582, https://zhuanlan.zhihu.com/p/75054200
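For reference, a minimal sketch of the channel selection layer, modeled on the channel_selection module shipped with the network-slimming repository (an illustration under that assumption, not necessarily the repo's exact code). Its indexes vector starts as all ones, and pruning zeroes out the entries of dropped channels so the forward pass slices them away:

import numpy as np
import torch
import torch.nn as nn

class channel_selection(nn.Module):
    def __init__(self, num_channels):
        super(channel_selection, self).__init__()
        # all ones initially, so the layer is an identity during training
        self.indexes = nn.Parameter(torch.ones(num_channels))

    def forward(self, input_tensor):
        # keep only the channels whose index entry is still non-zero
        selected_index = np.squeeze(np.argwhere(self.indexes.data.cpu().numpy()))
        if selected_index.size == 1:
            selected_index = np.resize(selected_index, (1,))
        return input_tensor[:, selected_index, :, :]

And a standalone toy showing the hooks mechanism Torch-Pruning builds on (a sketch, not Torch-Pruning's actual code): a forward hook receives each module together with its inputs and outputs, and the output's grad_fn is what links the module into the autograd graph that the dependency analysis walks.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

def record_io(module, inputs, outputs):
    # inputs is a tuple of tensors; outputs.grad_fn identifies the autograd node
    print(type(module).__name__, '->', outputs.grad_fn)

handles = [m.register_forward_hook(record_io) for m in model]
model(torch.randn(1, 3, 32, 32))
for h in handles:
    h.remove()  # always remove hooks when done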

Core pruning code


1. Method based on the network-slimming paper:
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from Resnet import *
import os
import torchvision
from tqdm import tqdm
from channel_selection import *


# Prune settings
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='cat_dog',
                    help='training dataset (default: cat_dog)')
parser.add_argument('--test-batch-size', type=int, default=8, metavar='N',
                    help='input batch size for testing (default: 8)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--depth', type=int, default=164,
                    help='depth of the resnet')
parser.add_argument('--percent', type=float, default=0.3,
                    help='fraction of channels to prune (default: 0.3)')
parser.add_argument('--model', default='logs/model_pruning_final.pth', type=str, metavar='PATH',
                    help='path to the trained model (default: logs/model_pruning_final.pth)')
parser.add_argument('--save', default='logs', type=str, metavar='PATH',
                    help='path to save the pruned model (default: logs)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if not os.path.exists(args.save):
    os.makedirs(args.save)

DEVICE = torch.device('cuda:1')
LR = 0.0001
EPOCH = 50
BATCH_SIZE = 100
train_root = './train'
vaild_root = './test'

# data augmentation for training
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224,scale=(0.6,1.0),ratio=(0.8,1.0)),
    transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224,scale=(1.0,1.0),ratio=(1.0,1.0)),
    # transforms.RandomHorizontalFlip(),
    # torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    # torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

vaild_data = torchvision.datasets.ImageFolder(
        root=vaild_root,
        transform=test_transform
    )

test_set = torch.utils.data.DataLoader(
    vaild_data,
    batch_size=BATCH_SIZE,
    shuffle=False
)
criteration = nn.CrossEntropyLoss()

model = resnet(depth=args.depth, dataset=args.dataset).to(DEVICE)
model.load_state_dict(torch.load(args.model))

def vaild(model,device,dataset):
    model.eval()
    correct = 0
    with torch.no_grad():
        for i,(x,y) in tqdm(enumerate(dataset)):
            x,y = x.to(device), y.to(device)
            output = model(x)
            loss = criteration(output,y)
            pred = output.max(1,keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
    # note: the accuracy denominator assumes every batch is full
    print("Test Loss {:.4f} Accuracy {}/{} ({:.3f}%)".format(loss,correct,len(dataset)*BATCH_SIZE,100*correct/(len(dataset)*BATCH_SIZE)))
    return 100*correct/(len(dataset)*BATCH_SIZE)

acc = vaild(model,DEVICE,test_set)
print("preprune acc:",acc)
# model = resnet(164, dataset="cat_dog").to(DEVICE)
# model.load_state_dict(torch.load('model_pruning_test.pth'))
# count the total number of BN channels in the network
total = 0

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size

# sort all BN scale factors globally; the bottom `percent` fraction sets the pruning threshold
y, i = torch.sort(bn)
thre_index = int(total * args.percent)
thre = y[thre_index]

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.gt(thre.to(DEVICE)).float().to(DEVICE)
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))

pruned_ratio = pruned/total

print('Pre-processing Successful!',"pruned_ratio:",pruned_ratio)

# simple test of the model after the pre-processing prune (BN scales of pruned channels set to zero)


print("Cfg:")
print(cfg,len(cfg))

newmodel = resnet(depth=args.depth, dataset=args.dataset, cfg=cfg)
if args.cuda:
    newmodel.to(DEVICE)

num_parameters = sum([param.nelement() for param in newmodel.parameters()])
savepath = os.path.join(args.save, "prune_0.3.txt")
with open(savepath, "w") as fp:
    fp.write("Configuration: 
"
+str(cfg)+"
"
) fp.write("Number of parameters:
"
+str(num_parameters)+"
"
) #fp.write("Test accuracy:
"+str(acc))

old_modules = list(model.modules())
new_modules = list(newmodel.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0

for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        if isinstance(old_modules[layer_id + 1], channel_selection):
            # If the next layer is the channel selection layer, then the current batchnorm 2d layer won't be pruned.
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()
            # We need to set the channel selection layer.
            m2 = new_modules[layer_id + 1]
            m2.indexes.data.zero_()
            m2.indexes.data[idx1.tolist()] = 1.0
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        else:
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        if conv_count == 0:
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
            continue
        if isinstance(old_modules[layer_id - 1], channel_selection) or isinstance(old_modules[layer_id - 1], nn.BatchNorm2d):
            # This covers the convolutions in the residual block.
            # The convolutions are either after the channel selection layer or after the batch normalization layer.
            conv_count += 1
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            # If the current convolution is not the last convolution in the residual block, then we can change the
            # number of output channels. Currently we use `conv_count` to detect whether it is such convolution.
            if conv_count % 3 != 1:
                w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            continue
        # We need to consider the case where there are downsampling convolutions.
        # For these convolutions, we just copy the weights.
        m1.weight.data = m0.weight.data.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned_0.3.pth.tar'))
print(newmodel)
model = newmodel
acc = vaild(model, DEVICE, test_set)
print("pruned acc:", acc)
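
The write-up above says the pruned network is fine-tuned or retrained from scratch. Below is a minimal fine-tuning sketch under this script's settings; the train loader is an assumption, built from train_root with train_transform in the same way test_set is built above:

train_data = torchvision.datasets.ImageFolder(
        root=train_root,
        transform=train_transform
    )
train_set = torch.utils.data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True
)

optimizer = torch.optim.Adam(newmodel.parameters(), lr=LR)
for epoch in range(1, EPOCH + 1):
    newmodel.train()
    for x, y in tqdm(train_set):
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        loss = criteration(newmodel(x), y)
        loss.backward()
        optimizer.step()
    print("epoch {} finetune acc: {:.3f}%".format(epoch, vaild(newmodel, DEVICE, test_set)))
torch.save(newmodel.state_dict(), os.path.join(args.save, 'pruned_0.3_finetuned.pth'))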

2. Method based on the Torch-Pruning tool:
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse
import torch_pruning as tp


parser = argparse.ArgumentParser()
parser.add_argument('--train_root', type=str, default='/data/xywang/dataset/catdog_classification/train',
                    help='path to the training dataset')
parser.add_argument('--vaild_root', type=str, default='/data/xywang/dataset/catdog_classification/test',
                    help='path to the validation dataset')
parser.add_argument('--sr', default=True, type=bool,
                    help='train with channel sparsity regularization')
parser.add_argument('--s', default=0.0001, type=float,
                    help='scale sparse rate (default: 0.0001)')
parser.add_argument('--batch_size', type=int, default=100, metavar='N',
                    help='input batch size for training (default: 100)')
parser.add_argument('--epochs', type=int, default=50, metavar='N',
                    help='number of epochs to train (default: 50)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--save', default='./models', type=str, metavar='PATH',
                    help='path to save the pruned model (default: ./models)')
parser.add_argument('--percent', default=0.9, type=float,
                    help='fraction of BN channels to prune (default: 0.9)')

args = parser.parse_args()
device = torch.device('cuda:1')

if not os.path.exists(args.save):
    os.makedirs(args.save)

# data augmentation for training
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224,scale=(0.6,1.0),ratio=(0.8,1.0)),
    transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224,scale=(1.0,1.0),ratio=(1.0,1.0)),
    # transforms.RandomHorizontalFlip(),
    # torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    # torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

train_data =  torchvision.datasets.ImageFolder(
        root=args.train_root,
        transform=train_transform
    )

vaild_data = torchvision.datasets.ImageFolder(
        root=args.vaild_root,
        transform=test_transform  # evaluate with the deterministic transform
    )

train_set = torch.utils.data.DataLoader(
    train_data,
    batch_size=args.batch_size,
    shuffle=True
)

test_set = torch.utils.data.DataLoader(
    vaild_data,
    batch_size=args.batch_size,
    shuffle=False
)

# L1 sparsity penalty on the BN scale factors (network-slimming): add the
# subgradient of s * |gamma| to the gradient of each prunable BN weight
def updateBN(model, s, pruning_modules):
    for module in pruning_modules:
        module.weight.grad.data.add_(s * torch.sign(module.weight.data))

# loss function
criteration = nn.CrossEntropyLoss()
def train(model,device,dataset,optimizer,epoch,pruning_modules):
    model.train().to(device)
    correct = 0
    for i,(x,y) in tqdm(enumerate(dataset)):
        x , y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        pred = output.max(1,keepdim=True)[1]
        correct += pred.eq(y.view_as(pred)).sum().item()
        loss =  criteration(output,y)     
        loss.backward()
        optimizer.step()

        if args.sr:
            updateBN(model,args.s,pruning_modules)
        
    print("Epoch {} Loss {:.4f} Accuracy {}/{} ({:.3f}%)".format(epoch,loss,correct,len(dataset)*args.batch_size,100*correct/(len(dataset)*args.batch_size)))
    

def vaild(model,device,dataset):
    model.eval().to(device)
    correct = 0
    with torch.no_grad():
        for i,(x,y) in tqdm(enumerate(dataset)):
            x,y = x.to(device) ,y.to(device)
            output = model(x)
            loss = criteration(output,y)
            pred = output.max(1,keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
    print("Test Loss {:.4f} Accuracy {}/{} ({:.3f}%)".format(loss,correct,len(dataset)*args.batch_size,100*correct/(len(dataset)*args.batch_size)))
    return 100*correct/(len(dataset)*args.batch_size)

def get_pruning_modules(model):
    # collect only bn1 and bn2 of each Bottleneck: bn3 feeds the residual
    # addition, so its channels are coupled with the shortcut and left unpruned
    module_list = []
    for module in model.modules():
        if isinstance(module,torchvision.models.resnet.Bottleneck):
            module_list.append(module.bn1)
            module_list.append(module.bn2)
    return module_list

def gather_bn_weights(model,pruning_modules):
    size_list = [module.weight.data.shape[0] for module in pruning_modules]
    bn_weights = torch.zeros(sum(size_list))
    index = 0
    for module, size in zip(pruning_modules, size_list):
        bn_weights[index:(index + size)] = module.weight.data.abs().clone()
        index += size

    return bn_weights

def compute_eachlayer_pruned_number(bn_modules, thresh):
    # for each prunable BN layer, count how many scale factors fall below the global threshold
    num_list = []
    for module in bn_modules:
        num = int((module.weight.data.abs() < thresh).sum().item())
        num_list.append(num)
    print(thresh)
    return num_list

def prune_model(model,num_list):
    model.to(device)
    # build the layer dependency graph from an example forward pass
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 224, 224))
    def prune_bn(bn, num):
        L1_norm = bn.weight.detach().cpu().numpy()
        prune_index = np.argsort(L1_norm)[:num].tolist() # remove filters with small L1-Norm
        plan = DG.get_pruning_plan(bn, tp.prune_batchnorm, prune_index)
        plan.exec()
    
    blk_id = 0
    for m in model.modules():
        if isinstance( m, torchvision.models.resnet.Bottleneck ):
            prune_bn( m.bn1, num_list[blk_id] )
            prune_bn( m.bn2, num_list[blk_id+1] )
            blk_id+=2
    return model  


model = torchvision.models.resnet50(pretrained=True)
model.fc = nn.Sequential(
        nn.Linear(2048,2)
    )
model.to(device)
model.load_state_dict(torch.load("models/model_pruning.pth"))
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

bn_modules = get_pruning_modules(model)

bn_weights = gather_bn_weights(model,bn_modules)
sorted_bn, sorted_index = torch.sort(bn_weights)
thresh_index = int(len(bn_weights) * args.percent)
thresh = sorted_bn[thresh_index].to(device)

num_list = compute_eachlayer_pruned_number(bn_modules, thresh)

prune_model(model,num_list)
print(model)

prec = vaild(model,device,test_set)
for epoch in range(1,args.epochs + 1):
    train(model,device,train_set,optimizer,epoch,bn_modules)
    vaild(model,device,test_set)
    # torch.save(model.state_dict(), 'model_pruned.pth')
    torch.save(model, 'models/model_pruned_{}.pth'.format(args.percent))  # save the full module since its structure changed
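
A quick sanity check that can be appended after pruning (an addition, not part of the original script): compare the parameter count with the unpruned baseline and run one forward pass, which confirms the dependency graph resized all coupled layers consistently. For reference, a stock ResNet50 with a 2-class head has roughly 23.5M parameters.

model.eval()
pruned_params = sum(p.numel() for p in model.parameters())
print("parameters after pruning: {:.2f}M".format(pruned_params / 1e6))
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224).to(device))
print("output shape:", out.shape)  # expect torch.Size([1, 2])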
