OpenCV에서 CascadeClassifier 트레이닝
30415 단어 · OpenCV3 노트
Cascade Classifier Training · [OpenCV3] 캐스케이드 분류기 트레이닝 – opencv_traincascade 빠른 사용 설명
캐스케이드 분류기 소개
데이터 준비
이 글은 더 이상 업데이트하지 않는다. CPU로 훈련하면 속도가 너무 느릴 수밖에 없고, 인터넷에서도 이 방법의 정확도가 좋지 않다는 의견이 많기 때문이다.
포지티브 및 네거티브 샘플 생성
포지티브/네거티브 샘플 생성 스크립트
import sys
import numpy as np
import xml.etree.ElementTree as ET
import cv2
import os
import numpy.random as npr
from utils import IoU
from utils import ensure_directory_exists
# Generate fixed-size positive and negative training crops for a cascade
# classifier from VOC-style XML annotations whose objects are named "fire".
# Positives: jittered square crops around each fire box.
# Negatives: random square crops whose IoU with every fire box is below
# NEG_IOU_THRESH.
# Written so that it runs under both Python 2 and Python 3.

save_dir = "/home/rui"
anno_path = "./firepos/annotation"
im_dir = "./firepos/images"
pos_save_dir = os.path.join(save_dir, "./res/positive")
neg_save_dir = os.path.join(save_dir, './res/negative')
ensure_directory_exists(pos_save_dir)
ensure_directory_exists(neg_save_dir)

# Every saved sample is resized to img_rule_w x img_rule_h pixels.
img_rule_h = 45
img_rule_w = 45

NEG_PER_IMAGE = 700            # negative crops wanted per image
NEG_IOU_THRESH = 0.1           # a crop counts as negative only below this IoU
MAX_NEG_ATTEMPTS = NEG_PER_IMAGE * 20  # bound so heavily-covered images cannot loop forever
POS_PER_BOX = 300              # jittered positive crops generated per box


def _parse_fire_annotation(xml_path):
    """Parse one annotation file.

    Returns (width, height, boxes). width and height come from <size>, as ints.
    boxes is a float32 array of shape (N, 4) holding (xmin, ymin, xmax, ymax)
    for every <object> whose <name> is "fire". N may be 0.
    """
    root = ET.parse(xml_path).getroot()
    size_node = root.find("size")
    width = int(size_node.find("width").text)
    height = int(size_node.find("height").text)
    coords = []
    for node in root.findall("object"):
        if node.find("name").text != "fire":
            continue
        bndbox = node.find("bndbox")
        coords.extend(bndbox.find(tag).text for tag in ("xmin", "ymin", "xmax", "ymax"))
    boxes = np.array(coords, dtype=np.float32).reshape(-1, 4)
    return width, height, boxes


def _resolve_image_path(image_dir, stem):
    """Return <image_dir>/<stem>.jpg if it exists, otherwise the .JPG variant."""
    path = os.path.join(image_dir, stem + ".jpg")
    if os.path.exists(path):
        return path
    return os.path.join(image_dir, stem + ".JPG")


def _save_negatives(img, boxes, out_dir, start_idx):
    """Save up to NEG_PER_IMAGE random square crops that avoid every box.

    Crop sides are drawn from [rule + 1, short_side // 2 - 1), where rule is
    the sample size. Files are named "<index>.jpg", starting at start_idx.
    Returns the next unused index.
    """
    height, width = img.shape[:2]
    rule = img_rule_h if width > height else img_rule_w
    max_size = min(height, width) // 2 - 1
    # randint needs high > low; the original max(rule, ...) guard still crashed
    # on images whose short side is below about 2 * (rule + 2) pixels.
    if max_size <= rule + 1:
        print("image too small for negative crops, skipped")
        return start_idx

    n_idx = start_idx
    saved = 0
    attempts = 0
    while saved < NEG_PER_IMAGE and attempts < MAX_NEG_ATTEMPTS:
        attempts += 1
        crop_size = int(npr.randint(rule + 1, max_size))
        nx = npr.randint(0, width - crop_size)
        ny = npr.randint(0, height - crop_size)
        crop_box = np.array([nx, ny, nx + crop_size, ny + crop_size])
        iou = IoU(crop_box, boxes)
        # Reject crops overlapping any fire box. With no boxes, every crop is kept.
        if len(iou) != 0 and np.max(iou) >= NEG_IOU_THRESH:
            continue
        cropped_im = img[ny:ny + crop_size, nx:nx + crop_size, :]
        resized_im = cv2.resize(cropped_im, (img_rule_w, img_rule_h),
                                interpolation=cv2.INTER_LINEAR)
        cv2.imwrite(os.path.join(out_dir, "%s.jpg" % n_idx), resized_im)
        n_idx += 1
        saved += 1
    if saved < NEG_PER_IMAGE:
        print("only %d negative crops found after %d attempts" % (saved, attempts))
    return n_idx


def _save_positives(img, boxes, out_dir, start_idx):
    """Save POS_PER_BOX jittered square crops around each usable box.

    Boxes smaller than the sample size, or with a negative top-left corner,
    are skipped. Files are named "<index>.jpg", starting at start_idx.
    Returns the next unused index.
    """
    height, width = img.shape[:2]
    p_idx = start_idx
    for box in boxes:
        x1, y1, x2, y2 = box
        box_w = x2 - x1 + 1
        box_h = y2 - y1 + 1
        # Small ground-truth boxes may be inaccurate, so they are ignored.
        if box_w < img_rule_w or box_h < img_rule_h or x1 < 0 or y1 < 0:
            continue
        for _ in range(POS_PER_BOX):
            # Side length lies between the geometric-mean side and the longer side.
            crop_size = npr.randint(int((box_w * box_h) ** 0.5 - 1), int(max(box_w, box_h)))
            # Shift the crop centre by up to +-10% of the crop size.
            delta_x = npr.randint(int(-crop_size * 0.1), int(crop_size * 0.1))
            delta_y = npr.randint(int(-crop_size * 0.1), int(crop_size * 0.1))
            # // matches the Python 2 integer division of the original script.
            nx1 = max(x1 + box_w / 2 + delta_x - crop_size // 2, 0)
            ny1 = max(y1 + box_h / 2 + delta_y - crop_size // 2, 0)
            # Clamp to the image. Crops cut by the border are kept (as before) and
            # may be non-square before resizing. The old out-of-bounds check after
            # this clamp could never fire and has been removed.
            nx2 = min(width, nx1 + crop_size)
            ny2 = min(height, ny1 + crop_size)
            cropped_im = img[int(ny1):int(ny2), int(nx1):int(nx2), :]
            if cropped_im.size == 0:
                # Box lies outside the image: nothing to crop.
                continue
            resized_im = cv2.resize(cropped_im, (img_rule_w, img_rule_h))
            cv2.imwrite(os.path.join(out_dir, "%s.jpg" % p_idx), resized_im)
            p_idx += 1
    return p_idx


# Only .xml files are annotations. Sorting makes the processing order reproducible.
names_xml = sorted(n for n in os.listdir(anno_path) if n.lower().endswith(".xml"))
num = len(names_xml)
print("%d pics in total" % num)
p_idx = 0  # next positive sample index
n_idx = 0  # next negative sample index
idx = 0    # images processed so far
for ne_xml in names_xml:
    width_xml, height_xml, boxes = _parse_fire_annotation(os.path.join(anno_path, ne_xml))
    # splitext keeps stems that themselves contain dots intact.
    stem = os.path.splitext(ne_xml)[0]
    im_path = _resolve_image_path(im_dir, stem)
    img = cv2.imread(im_path)
    if img is None:
        print("cannot read %s, skipped" % im_path)
        continue
    height, width = img.shape[:2]
    # Skip images whose real size disagrees with the annotation.
    if height != height_xml or width != width_xml:
        print("{} {} {} {}".format(height, height_xml, width, width_xml))
        continue
    idx += 1
    if idx % 100 == 0:
        print("{} images done".format(idx))
    n_idx = _save_negatives(img, boxes, neg_save_dir, n_idx)
    p_idx = _save_positives(img, boxes, pos_save_dir, p_idx)
print("%s images done, pos: %s, neg: %s" % (idx, p_idx, n_idx))
샘플 목록 텍스트 파일 작성
import os

# Write the positive-sample description file: one line per image in the
# "<path> <count> <x> <y> <w> <h>" format, here a single 45x45 object at (0, 0).
pos_dir = "/home/rui/res/positive"
# NOTE(review): entries are written as relative "positive/<file>" paths. They
# presumably must resolve against the directory holding this list file, so
# temp.txt is expected to be moved next to the "positive" folder
# (e.g. /home/rui/res) before use -- confirm.
pos_list = sorted(os.listdir(pos_dir))  # sorted for a reproducible file
with open("/home/rui/temp.txt", "w") as f:  # closed even if a write fails
    for im in pos_list:
        # The newline must be an escape inside the literal; the scraped
        # version had it split across two source lines (a SyntaxError).
        name = "positive/{} 1 0 0 45 45\n".format(im)
        print(name.rstrip("\n"))
        f.write(name)
# List every negative image for the opencv_traincascade -bg file.
# Run inside the negative-sample directory (e.g. /home/rui/res/negative) so
# that positive crops are not picked up. Quote the pattern so the shell does
# not expand *.jpg itself. ">>" appends, so delete an old neg.txt before re-running.
find . -name '*.jpg' >> neg.txt
업데이트 대기 중