[주방장해우] Implementing FCOS from Scratch (Part 3): Converting Regression Predictions and Decoding
Converting regression predictions
The FCOS regression head predicts the log-smoothed values of l, t, r, b. At test time, we first apply exp to these values, then combine each point's coordinates with the recovered l, t, r, b values to obtain the true predicted box coordinates.
The regression prediction conversion code is as follows (this function is a method of the FCOSDecoder class shown in the next section):

def snap_ltrb_reg_heads_to_x1_y1_x2_y2_bboxes(self, reg_preds,
                                              points_position):
    """
    snap reg preds to pred bboxes
    reg_preds:[points_num,4],4:[l,t,r,b]
    points_position:[points_num,2],2:[point_ctr_x,point_ctr_y]
    """
    # x1,y1 = point - (l,t); x2,y2 = point + (r,b)
    pred_bboxes_xy_min = points_position - reg_preds[:, 0:2]
    pred_bboxes_xy_max = points_position + reg_preds[:, 2:4]
    pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                            dim=1)
    pred_bboxes = pred_bboxes.int()
    # clamp predicted boxes to the image boundary
    pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
    pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
    pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],
                                    max=self.image_w - 1)
    pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],
                                    max=self.image_h - 1)
    # pred bboxes shape:[points_num,4]
    return pred_bboxes
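As a quick sanity check, here is a minimal standalone sketch of the same conversion, assuming a 600x600 input image; the point and regression tensors below are made up for illustration and are not values from this series:

import torch

image_w, image_h = 600, 600

# the raw regression head output is in log space; exp recovers l,t,r,b
raw_reg_preds = torch.tensor([[3.0, 2.0, 3.0, 2.0]])
reg_preds = torch.exp(raw_reg_preds)            # ~[[20.09, 7.39, 20.09, 7.39]]

points_position = torch.tensor([[100., 100.]])  # [points_num,2]

pred_xy_min = points_position - reg_preds[:, 0:2]   # x1 = x - l, y1 = y - t
pred_xy_max = points_position + reg_preds[:, 2:4]   # x2 = x + r, y2 = y + b
pred_bboxes = torch.cat([pred_xy_min, pred_xy_max], dim=1).int()

pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2], max=image_w - 1)
pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3], max=image_h - 1)

print(pred_bboxes)  # tensor([[ 79,  92, 120, 107]], dtype=torch.int32)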
Decoding
The FCOS decode process is not much different from RetinaNet's. After converting the regression predictions into predicted box coordinates as described above, NMS is likewise used to filter the predicted boxes. Before running NMS, the classification scores must first be multiplied by the centerness predictions so that low-quality predicted boxes can be filtered out. Since multiplying scores by centerness shrinks the scores, we take the square root of the product afterwards to scale the scores back up; a short numeric sketch after the full listing below illustrates this rescaling. The decode code is as follows:

import torch
import torch.nn as nn
# nms is used but never imported in the original snippet; torchvision's
# nms has the matching (boxes, scores, iou_threshold) signature
from torchvision.ops import nms
class FCOSDecoder(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 strides=[8, 16, 32, 64, 128],
                 top_n=1000,
                 min_score_threshold=0.01,
                 nms_threshold=0.6,
                 max_detection_num=100):
        super(FCOSDecoder, self).__init__()
        self.image_w = image_w
        self.image_h = image_h
        self.strides = strides
        self.top_n = top_n
        self.min_score_threshold = min_score_threshold
        self.nms_threshold = nms_threshold
        self.max_detection_num = max_detection_num

    def forward(self, cls_heads, reg_heads, center_heads, batch_positions):
        with torch.no_grad():
            device = cls_heads[0].device
            filter_scores, filter_score_classes = [], []
            filter_reg_heads, filter_batch_positions = [], []
            for per_level_cls_head, per_level_reg_head, per_level_center_head, per_level_position in zip(
                    cls_heads, reg_heads, center_heads, batch_positions):
                per_level_cls_head = torch.sigmoid(per_level_cls_head)
                # the regression head predicts log-space values; exp recovers l,t,r,b
                per_level_reg_head = torch.exp(per_level_reg_head)
                per_level_center_head = torch.sigmoid(per_level_center_head)
                # flatten each level's predictions to [batch_size,H*W,C]
                per_level_cls_head = per_level_cls_head.view(
                    per_level_cls_head.shape[0], -1,
                    per_level_cls_head.shape[-1])
                per_level_reg_head = per_level_reg_head.view(
                    per_level_reg_head.shape[0], -1,
                    per_level_reg_head.shape[-1])
                per_level_center_head = per_level_center_head.view(
                    per_level_center_head.shape[0], -1,
                    per_level_center_head.shape[-1])
                per_level_position = per_level_position.view(
                    per_level_position.shape[0], -1,
                    per_level_position.shape[-1])

                scores, score_classes = torch.max(per_level_cls_head, dim=2)
                # fuse classification score and centerness, then take the
                # square root to scale the fused score back up
                scores = torch.sqrt(scores *
                                    per_level_center_head.squeeze(-1))
                if scores.shape[1] >= self.top_n:
                    # keep only the top_n highest-scoring points per level
                    scores, indexes = torch.topk(scores,
                                                 self.top_n,
                                                 dim=1,
                                                 largest=True,
                                                 sorted=True)
                    score_classes = torch.gather(score_classes, 1, indexes)
                    per_level_reg_head = torch.gather(
                        per_level_reg_head, 1,
                        indexes.unsqueeze(-1).repeat(1, 1, 4))
                    per_level_center_head = torch.gather(
                        per_level_center_head, 1,
                        indexes.unsqueeze(-1).repeat(1, 1, 1))
                    per_level_position = torch.gather(
                        per_level_position, 1,
                        indexes.unsqueeze(-1).repeat(1, 1, 2))

                filter_scores.append(scores)
                filter_score_classes.append(score_classes)
                filter_reg_heads.append(per_level_reg_head)
                filter_batch_positions.append(per_level_position)

            # concatenate the filtered predictions of all FPN levels
            filter_scores = torch.cat(filter_scores, dim=1)
            filter_score_classes = torch.cat(filter_score_classes, dim=1)
            filter_reg_heads = torch.cat(filter_reg_heads, dim=1)
            filter_batch_positions = torch.cat(filter_batch_positions, dim=1)

            batch_scores, batch_classes, batch_pred_bboxes = [], [], []
            for scores, score_classes, per_image_reg_preds, per_image_points_position in zip(
                    filter_scores, filter_score_classes, filter_reg_heads,
                    filter_batch_positions):
                pred_bboxes = self.snap_ltrb_reg_heads_to_x1_y1_x2_y2_bboxes(
                    per_image_reg_preds, per_image_points_position)
                # drop predictions whose fused score is below the threshold
                score_classes = score_classes[
                    scores > self.min_score_threshold].float()
                pred_bboxes = pred_bboxes[
                    scores > self.min_score_threshold].float()
                scores = scores[scores > self.min_score_threshold].float()

                # pad each image's outputs to max_detection_num with -1
                one_image_scores = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_classes = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_pred_bboxes = (-1) * torch.ones(
                    (self.max_detection_num, 4), device=device)

                if scores.shape[0] != 0:
                    # sort boxes by score, then run NMS
                    sorted_scores, sorted_indexes = torch.sort(
                        scores, descending=True)
                    sorted_score_classes = score_classes[sorted_indexes]
                    sorted_pred_bboxes = pred_bboxes[sorted_indexes]
                    keep = nms(sorted_pred_bboxes, sorted_scores,
                               self.nms_threshold)
                    keep_scores = sorted_scores[keep]
                    keep_classes = sorted_score_classes[keep]
                    keep_pred_bboxes = sorted_pred_bboxes[keep]
                    final_detection_num = min(self.max_detection_num,
                                              keep_scores.shape[0])
                    one_image_scores[0:final_detection_num] = keep_scores[
                        0:final_detection_num]
                    one_image_classes[0:final_detection_num] = keep_classes[
                        0:final_detection_num]
                    one_image_pred_bboxes[
                        0:final_detection_num, :] = keep_pred_bboxes[
                            0:final_detection_num, :]

                one_image_scores = one_image_scores.unsqueeze(0)
                one_image_classes = one_image_classes.unsqueeze(0)
                one_image_pred_bboxes = one_image_pred_bboxes.unsqueeze(0)
                batch_scores.append(one_image_scores)
                batch_classes.append(one_image_classes)
                batch_pred_bboxes.append(one_image_pred_bboxes)

            batch_scores = torch.cat(batch_scores, dim=0)
            batch_classes = torch.cat(batch_classes, dim=0)
            batch_pred_bboxes = torch.cat(batch_pred_bboxes, dim=0)

            # batch_scores shape:[batch_size,max_detection_num]
            # batch_classes shape:[batch_size,max_detection_num]
            # batch_pred_bboxes shape:[batch_size,max_detection_num,4]
            return batch_scores, batch_classes, batch_pred_bboxes

    def snap_ltrb_reg_heads_to_x1_y1_x2_y2_bboxes(self, reg_preds,
                                                  points_position):
        """
        snap reg preds to pred bboxes
        reg_preds:[points_num,4],4:[l,t,r,b]
        points_position:[points_num,2],2:[point_ctr_x,point_ctr_y]
        """
        pred_bboxes_xy_min = points_position - reg_preds[:, 0:2]
        pred_bboxes_xy_max = points_position + reg_preds[:, 2:4]
        pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                dim=1)
        pred_bboxes = pred_bboxes.int()
        pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
        pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
        pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],
                                        max=self.image_w - 1)
        pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],
                                        max=self.image_h - 1)
        # pred bboxes shape:[points_num,4]
        return pred_bboxes


if __name__ == '__main__':
    from fcos import FCOS
    net = FCOS(resnet_type="resnet50")
    image_h, image_w = 600, 600
    # torch.autograd.Variable is deprecated; a plain tensor works the same
    cls_heads, reg_heads, center_heads, batch_positions = net(
        torch.randn(3, 3, image_h, image_w))
    # annotations are not used by the decoder; kept from the original snippet
    annotations = torch.FloatTensor([[[113, 120, 183, 255, 5],
                                      [13, 45, 175, 210, 2]],
                                     [[11, 18, 223, 225, 1],
                                      [-1, -1, -1, -1, -1]],
                                     [[-1, -1, -1, -1, -1],
                                      [-1, -1, -1, -1, -1]]])
    decode = FCOSDecoder(image_w, image_h)
    batch_scores2, batch_classes2, batch_pred_bboxes2 = decode(
        cls_heads, reg_heads, center_heads, batch_positions)
    print("2222", batch_scores2.shape, batch_classes2.shape,
          batch_pred_bboxes2.shape)
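To make the square-root rescaling described above concrete, here is a minimal numeric sketch; the score and centerness values are made up for illustration:

import torch

cls_scores = torch.tensor([0.9, 0.8, 0.6])
centerness = torch.tensor([0.9, 0.3, 0.1])

# the product strongly suppresses off-center points ...
fused = cls_scores * centerness   # tensor([0.8100, 0.2400, 0.0600])
# ... and sqrt lifts the fused score back toward the original scale,
# so a fixed threshold such as min_score_threshold=0.01 still behaves sensibly
rescaled = torch.sqrt(fused)      # tensor([0.9000, 0.4899, 0.2449])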