PyTorch ResNet50 분류 모델 가지치기(pruning)
Resnet50
네트워크 구조:https://www.jianshu.com/p/993c03c22d52
가지치기
1. network-slimming 논문 기반 방법: PyTorch 버전 코드: https://github.com/Eric-mingjie/network-slimming 아이디어: downsample 안의 BN층을 제거하고, ResNet v2 구조(BN-Conv-ReLU)를 편리하게 사용하기 위해 각 bottleneck의 첫 번째 BN 바로 뒤에 사용자 정의 채널 선택층(모든 값이 1인 층)을 둔다. 이 층은 훈련 과정에는 영향을 주지 않는다. 가지치기를 할 때는 먼저 각 BN의 채널 mask를 생성하고, mask에 따라 채널 선택층에 값을 부여한 뒤, 해당 BN층에서 보존된 채널을 Conv의 입력으로 선택하고 다음 BN의 mask에 따라 선택한 채널을 Conv의 출력 채널로 삼는다. 이렇게 층별로 가지치기를 마친 네트워크를 미세 조정(finetune)하거나 처음부터 다시 훈련한다.
2. VainF가 최근 개발한 torch 가지치기 도구 기반 방법: https://github.com/VainF/Torch-Pruning 아이디어: 가지치기 전에 전체 네트워크 각 층 사이의 의존 관계를 구축한다. torch의 hooks 메커니즘을 이용해 순전파(forward) 중 각 모듈의 grad_fn을 얻고, 모듈마다 대응하는 노드(node)를 만든다. 각 노드는 module, grad_fn, inputs, outputs, dependencies, 그리고 node_module의 inputs/outputs가 의존하는 층(연산)을 담는다. 가지치기를 실행할 때는 이 의존 관계에 따라 관련 층의 채널을 자동으로 맞춘다. hooks 메커니즘: https://cloud.tencent.com/developer/article/1122582, https://zhuanlan.zhihu.com/p/75054200
가지치기 핵심 코드
1. network-slimming 논문을 바탕으로 하는 방법:
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from Resnet import *
import os
import torchvision
from tqdm import tqdm
from channel_selection import *
# Prune settings
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='cat_dog',
                    help='training dataset (default: cat_dog)')
parser.add_argument('--test-batch-size', type=int, default=8, metavar='N',
                    help='input batch size for testing (default: 8)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--depth', type=int, default=164,
                    help='depth of the resnet')
parser.add_argument('--percent', type=float, default=0.3,
                    help='fraction of BN channels to prune (default: 0.3)')
parser.add_argument('--model', default='logs/model_pruning_final.pth', type=str, metavar='PATH',
                    help='path to the model (default: logs/model_pruning_final.pth)')
parser.add_argument('--save', default='logs', type=str, metavar='PATH',
                    help='path to save pruned model (default: logs)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
os.makedirs(args.save, exist_ok=True)

# Honour --no-cuda (and machines without a GPU) instead of always using cuda:1.
DEVICE = torch.device('cuda:1' if args.cuda else 'cpu')
LR = 0.0001
EPOCH = 50
BTACH_SIZE = 100
train_root = './train'
vaild_root = './test'

train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
# Evaluation pipeline: no flips / colour jitter. The RandomResizedCrop with
# scale=(1.0, 1.0) and ratio=(1.0, 1.0) is kept exactly as the author wrote it.
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
vaild_data = torchvision.datasets.ImageFolder(
    root=vaild_root,
    transform=test_transform
)
test_set = torch.utils.data.DataLoader(
    vaild_data,
    batch_size=BTACH_SIZE,
    shuffle=False
)
criteration = nn.CrossEntropyLoss()
model = resnet(depth=args.depth, dataset=args.dataset).to(DEVICE)
# map_location lets a checkpoint saved on another GPU (or on CPU) load onto DEVICE.
model.load_state_dict(torch.load(args.model, map_location=DEVICE))
def vaild(model, device, dataset):
    """Evaluate ``model`` on ``dataset`` and return top-1 accuracy in percent.

    Loss is computed with the module-level ``criteration`` and averaged over
    every sample seen. Accuracy is counted against the number of samples
    actually yielded, so a partial last batch is handled correctly.

    Args:
        model: classifier producing per-class scores along dim 1
            (assumed from ``output.max(1)`` below).
        device: device each batch is moved to before the forward pass.
        dataset: iterable (e.g. a DataLoader) of ``(x, y)`` batches.

    Returns:
        Accuracy in percent; 0.0 when ``dataset`` yields no samples.
    """
    model.eval()
    correct = 0
    seen = 0
    loss_sum = 0.0
    with torch.no_grad():
        for x, y in tqdm(dataset):
            x, y = x.to(device), y.to(device)
            output = model(x)
            batch = y.size(0)
            # criteration averages over the batch; re-weight to get a per-sample mean overall.
            loss_sum += criteration(output, y).item() * batch
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
            seen += batch
    acc = 100 * correct / seen if seen else 0.0
    mean_loss = loss_sum / seen if seen else 0.0
    # Printed before returning (previously this line came after `return` and never ran).
    print("Test Loss {:.4f} Accuracy {}/{} ({:.3f}%)".format(mean_loss, correct, seen, acc))
    return acc
acc = vaild(model, DEVICE, test_set)
print("preprune acc:", acc)

# Global threshold: pool |gamma| of every BatchNorm2d layer, sort it, and take
# the value found at the --percent position.
bn_layers = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
bn = torch.cat([m.weight.data.abs().cpu() for m in bn_layers])
total = bn.numel()
sorted_bn, _ = torch.sort(bn)
# Clamp so that --percent 1.0 does not index one past the end.
thre_index = min(int(total * args.percent), total - 1)
thre = sorted_bn[thre_index]

pruned = 0
cfg = []        # surviving channel count per BN layer, in model.modules() order
cfg_mask = []   # 0/1 keep-mask per BN layer, same order as cfg
for k, m in enumerate(model.modules()):
    if not isinstance(m, nn.BatchNorm2d):
        continue
    weight_copy = m.weight.data.abs().clone()
    mask = weight_copy.gt(thre.to(DEVICE)).float().to(DEVICE)
    if mask.sum() == 0:
        # Keep the strongest channel so no layer is pruned to zero width
        # (which would yield a 0-channel layer in the rebuilt model).
        mask[torch.argmax(weight_copy)] = 1.0
    remaining = int(torch.sum(mask))
    pruned = pruned + mask.shape[0] - remaining
    # Zero the pruned channels' scale and shift in the original model.
    m.weight.data.mul_(mask)
    m.bias.data.mul_(mask)
    cfg.append(remaining)
    cfg_mask.append(mask.clone())
    print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
          format(k, mask.shape[0], remaining))
pruned_ratio = pruned / total
print('Pre-processing Successful!', "pruned_ratio:", pruned_ratio)
# simple test model after Pre-processing prune (simple set BN scales to zeros)
print("Cfg:")
print(cfg, len(cfg))
# Rebuild the network with the per-layer channel counts recorded in cfg and
# place it on the same device as `model`, so the weight copy below and the
# later evaluation do not mix devices.
newmodel = resnet(depth=args.depth, dataset=args.dataset, cfg=cfg)
newmodel.to(DEVICE)
num_parameters = sum(param.nelement() for param in newmodel.parameters())
# File name follows --percent (identical to "prune_0.3.txt" for the default).
savepath = os.path.join(args.save, "prune_{}.txt".format(args.percent))
with open(savepath, "w") as fp:
    fp.write("Configuration:\n" + str(cfg) + "\n")
    fp.write("Number of parameters:\n" + str(num_parameters) + "\n")
    # fp.write("Test accuracy:\n" + str(acc))
# Copy weights from the mask-pruned `model` into the narrower `newmodel`.
# Both module lists are walked in lock-step by index, which assumes
# resnet(cfg=...) yields modules in the same order as the original resnet.
# start_mask: keep-mask of the channels feeding the current layer
#             (starts as all-ones for the 3 image channels).
# end_mask:   keep-mask of the channels the current layer produces.
old_modules = list(model.modules())
new_modules = list(newmodel.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0
for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
    if isinstance(m0, nn.BatchNorm2d):
        # Indices of the channels this BN keeps according to its mask.
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            # np.squeeze turns a single kept index into a 0-d array; make it 1-D again.
            idx1 = np.resize(idx1,(1,))
        if isinstance(old_modules[layer_id + 1], channel_selection):
            # If the next layer is the channel selection layer, then the current batchnorm 2d layer won't be pruned.
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()
            # We need to set the channel selection layer: 1 for kept channels, 0 otherwise.
            m2 = new_modules[layer_id + 1]
            m2.indexes.data.zero_()
            m2.indexes.data[idx1.tolist()] = 1.0
            # Advance to the next BN's mask; this BN's mask becomes the input mask.
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        else:
            # Ordinary BN: keep only the surviving channels' affine params and running stats.
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        if conv_count == 0:
            # The very first convolution is copied unchanged.
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
            continue
        if isinstance(old_modules[layer_id-1], channel_selection) or isinstance(old_modules[layer_id-1], nn.BatchNorm2d):
            # This covers the convolutions in the residual block.
            # The convolutions are either after the channel selection layer or after the batch normalization layer.
            conv_count += 1
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            # Keep only the input channels that survived the preceding mask.
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            # If the current convolution is not the last convolution in the residual block, then we can change the
            # number of output channels. Currently we use `conv_count` to detect whether it is such convolution:
            # conv_count is 1 after the first conv, so (assuming three convs per block) the last conv of each
            # block has conv_count % 3 == 1.
            if conv_count % 3 != 1:
                w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            continue
        # We need to consider the case where there are downsampling convolutions.
        # For these convolutions, we just copy the weights.
        m1.weight.data = m0.weight.data.clone()
    elif isinstance(m0, nn.Linear):
        # Final FC: keep only the input columns for channels that survived the last mask.
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()
# Save cfg together with the weights so the pruned architecture can be rebuilt.
# File name follows --percent (identical to "pruned_0.3.pth.tar" for the default).
checkpoint_path = os.path.join(args.save, 'pruned_{}.pth.tar'.format(args.percent))
torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, checkpoint_path)
print(newmodel)
model = newmodel
acc = vaild(model, DEVICE, test_set)
print("pruned acc:", acc)
2. Torch-Pruning 가지치기 도구를 바탕으로 하는 방법:
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse
import torch_pruning as tp
def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` would turn any non-empty string (including "False") into True.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(value))


parser = argparse.ArgumentParser()
parser.add_argument('--train_root', type=str, default='/data/xywang/dataset/catdog_classification/train',
                    help='training dataset (default: train)')
parser.add_argument('--vaild_root', type=str, default='/data/xywang/dataset/catdog_classification/test',
                    help='validation dataset (default: test)')
parser.add_argument('--sr', default=True, type=_str2bool,
                    help='train with channel sparsity regularization (default: True)')
parser.add_argument('--s', default=0.0001, type=float,
                    help='scale sparse rate (default: 0.0001)')
parser.add_argument('--batch_size', type=int, default=100, metavar='N',
                    help='input batch size for training (default: 100)')
parser.add_argument('--epochs', type=int, default=50, metavar='N',
                    help='number of epochs to train (default: 50)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--save', default='./models', type=str, metavar='PATH',
                    help='path to save prune model (default: ./models)')
parser.add_argument('--percent', default=0.9, type=float,
                    help='fraction of BN channels to prune (default: 0.9)')
args = parser.parse_args()
# Fall back to CPU instead of failing on machines without a second GPU.
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
os.makedirs(args.save, exist_ok=True)

train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
train_data = torchvision.datasets.ImageFolder(
    root=args.train_root,
    transform=train_transform
)
# Validation must use the deterministic test pipeline, not the random
# training augmentation, or reported accuracy changes from run to run.
vaild_data = torchvision.datasets.ImageFolder(
    root=args.vaild_root,
    transform=test_transform
)
train_set = torch.utils.data.DataLoader(
    train_data,
    batch_size=args.batch_size,
    shuffle=True
)
test_set = torch.utils.data.DataLoader(
    vaild_data,
    batch_size=args.batch_size,
    shuffle=False
)
def updateBN(model, s, pruning_modules):
    """Add the L1 sparsity sub-gradient ``s * sign(gamma)`` to each BN scale.

    This must run after ``loss.backward()`` and before ``optimizer.step()``.
    Otherwise the next ``zero_grad()`` discards it.

    Args:
        model: unused; kept so existing call sites keep working.
        s: sparsity strength.
        pruning_modules: modules whose ``weight.grad`` receives the term.
    """
    for module in pruning_modules:
        grad = module.weight.grad
        if grad is None:
            # No gradient was produced for this layer (e.g. frozen); nothing to add.
            continue
        grad.add_(s * torch.sign(module.weight.detach()))


criteration = nn.CrossEntropyLoss()
def train(model, device, dataset, optimizer, epoch, pruning_modules):
    """Run one training epoch, optionally adding BN-scale sparsity regularization.

    Uses the module-level ``criteration`` and ``args`` (``args.sr`` toggles
    the regularizer, ``args.s`` is its strength).

    Args:
        model: network to train; moved to ``device``.
        device: device each batch is moved to.
        dataset: iterable of ``(x, y)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only for logging.
        pruning_modules: BN layers passed to ``updateBN``.
    """
    model.train().to(device)
    correct = 0
    seen = 0
    loss_sum = 0.0
    for x, y in tqdm(dataset):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(y.view_as(pred)).sum().item()
        loss = criteration(output, y)
        loss.backward()
        if args.sr:
            # Must precede optimizer.step(): a sub-gradient added after the step
            # is cleared by the next zero_grad() and never reaches the weights.
            updateBN(model, args.s, pruning_modules)
        optimizer.step()
        batch = y.size(0)
        loss_sum += loss.item() * batch
        seen += batch
    acc = 100 * correct / seen if seen else 0.0
    mean_loss = loss_sum / seen if seen else 0.0
    print("Epoch {} Loss {:.4f} Accuracy {}/{} ({:.3f}%)".format(epoch, mean_loss, correct, seen, acc))
def vaild(model, device, dataset):
    """Evaluate ``model`` on ``dataset`` and return top-1 accuracy in percent.

    Loss is computed with the module-level ``criteration`` and averaged over
    every sample. Accuracy is counted against the samples actually yielded,
    so a partial last batch no longer deflates the result.

    Args:
        model: classifier producing per-class scores along dim 1
            (assumed from ``output.max(1)`` below); moved to ``device``.
        device: device each batch is moved to.
        dataset: iterable of ``(x, y)`` batches.

    Returns:
        Accuracy in percent; 0.0 when ``dataset`` yields no samples.
    """
    model.eval().to(device)
    correct = 0
    seen = 0
    loss_sum = 0.0
    with torch.no_grad():
        for x, y in tqdm(dataset):
            x, y = x.to(device), y.to(device)
            output = model(x)
            batch = y.size(0)
            loss_sum += criteration(output, y).item() * batch
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
            seen += batch
    acc = 100 * correct / seen if seen else 0.0
    mean_loss = loss_sum / seen if seen else 0.0
    print("Test Loss {:.4f} Accuracy {}/{} ({:.3f}%)".format(mean_loss, correct, seen, acc))
    return acc
def get_pruning_modules(model):
    """Return the prunable BatchNorm layers of every torchvision Bottleneck.

    For each ``torchvision.models.resnet.Bottleneck`` met while walking
    ``model.modules()``, its ``bn1`` then ``bn2`` are appended. ``bn3`` is not
    collected (presumably because its channels feed the residual addition).
    """
    bottleneck_cls = torchvision.models.resnet.Bottleneck
    blocks = (m for m in model.modules() if isinstance(m, bottleneck_cls))
    return [bn for block in blocks for bn in (block.bn1, block.bn2)]
def gather_bn_weights(model, pruning_modules):
    """Concatenate ``|weight|`` of every module in ``pruning_modules``.

    Values are laid out in the order of ``pruning_modules``. Sizes and values
    now come from the same sequence; the old code took sizes in
    ``model.modules()`` order, which broke when the two orders differed.

    Args:
        model: unused; kept so existing call sites keep working.
        pruning_modules: sequence of modules with a 1-D ``weight``.

    Returns:
        1-D tensor of the default dtype on CPU; empty if ``pruning_modules`` is empty.
    """
    sizes = [module.weight.data.shape[0] for module in pruning_modules]
    bn_weights = torch.zeros(sum(sizes))
    offset = 0
    for module, size in zip(pruning_modules, sizes):
        bn_weights[offset:offset + size] = module.weight.data.abs()
        offset += size
    return bn_weights
def computer_eachlayer_pruned_number(bn_weights, thresh, modules=None):
    """Count, per BN layer, how many channels have ``|weight| < thresh``.

    Args:
        bn_weights: flat weight vector; not used by the count itself, kept so
            existing call sites keep working.
        thresh: scalar threshold (Python number or 0-d tensor, any device).
        modules: layers to inspect. When omitted, the module-level
            ``bn_modules`` is used, as the original implementation did implicitly.

    Returns:
        list of ints, one count per module, in the order of ``modules``.
    """
    if modules is None:
        modules = bn_modules  # module-level global set by the driver script
    num_list = []
    for module in modules:
        weights = module.weight.data.abs().float()
        # Vectorised strict comparison, equivalent to the old per-element `thresh > w`.
        limit = torch.as_tensor(thresh, device=weights.device)
        num_list.append(int((weights < limit).sum().item()))
    print(thresh)
    return num_list
def prune_model(model, num_list):
    """Prune ``bn1``/``bn2`` of every torchvision Bottleneck in ``model`` in place.

    A torch_pruning DependencyGraph is built first, so that layers coupled to
    each BN are pruned consistently.

    Args:
        model: torchvision ResNet; moved to the module-level ``device``.
        num_list: channels to remove per BN, ordered bn1, bn2 for each
            Bottleneck in ``model.modules()`` order.

    Returns:
        The same model object, pruned.
    """
    model.to(device)
    # The dependency graph is traced with a forward pass, so the example input
    # must live on the same device as the model.
    example_input = torch.randn(1, 3, 224, 224, device=device)
    DG = tp.DependencyGraph().build_dependency(model, example_input)

    def prune_bn(bn, num):
        # Never remove every channel of a layer.
        num = min(num, bn.weight.shape[0] - 1)
        if num <= 0:
            return
        # Rank by |gamma|, matching how `num` was counted (|weight| < thresh).
        scores = np.abs(bn.weight.detach().cpu().numpy())
        prune_index = np.argsort(scores)[:num].tolist()
        plan = DG.get_pruning_plan(bn, tp.prune_batchnorm, prune_index)
        plan.exec()

    blk_id = 0
    for m in model.modules():
        if isinstance(m, torchvision.models.resnet.Bottleneck):
            prune_bn(m.bn1, num_list[blk_id])
            prune_bn(m.bn2, num_list[blk_id + 1])
            blk_id += 2
    return model
# Load the sparsity-trained ResNet50, prune it, then fine-tune the result.
# No ImageNet download: every weight is overwritten by load_state_dict below.
model = torchvision.models.resnet50()
model.fc = nn.Sequential(
    nn.Linear(2048, 2)
)
model.to(device)
checkpoint_path = args.resume or "models/model_pruning.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

bn_modules = get_pruning_modules(model)
bn_weights = gather_bn_weights(model, bn_modules)
sorted_bn, _ = torch.sort(bn_weights)
# Clamp so that --percent 1.0 does not index one past the end.
thresh_index = min(int(len(bn_weights) * args.percent), len(bn_weights) - 1)
thresh = sorted_bn[thresh_index].to(device)
num_list = computer_eachlayer_pruned_number(bn_weights, thresh)
prune_model(model, num_list)
print(model)

# Build the optimizer only after pruning. NOTE(review): torch_pruning appears
# to replace pruned weights with new Parameter objects (confirm for the
# installed version). An optimizer created earlier would then keep stepping the
# old, detached tensors, and fine-tuning would not touch `model`.
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
prec = vaild(model, device, test_set)
for epoch in range(1, args.epochs + 1):
    train(model, device, train_set, optimizer, epoch, bn_modules)
    vaild(model, device, test_set)
# Output name follows --save and --percent instead of a hard-coded "0.8".
torch.save(model, os.path.join(args.save, 'model_pruned_{}.pth'.format(args.percent)))
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
PyTorch 딥러닝 프레임워크 YOLOv3 객체 탐지 학습 노트(2) - 네트워크를 구성하는 층 만들기: 이 파일은 YOLO 네트워크를 만드는 코드를 포함하며, util.py 파일에 포함된 여러 유용한 함수 코드는 darknet.py를 지원합니다. 이 두 파일을 디렉터리에 넣으면 공식 코드(C 언어...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다. 하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.