[주방장해우] Implementing FCOS from Scratch (Part 3): Converting Regression Predictions and decode


Contents

  • Converting regression predictions
  • decode
  • All of the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training. If it helps you, please give it a star! The code below was tested on PyTorch 1.4 and confirmed to run correctly.

    Converting regression predictions


    The FCOS regression head predicts the natural logarithm of the l, t, r, b values. At inference time, these predictions are first passed through exp, and the true predicted box coordinates are then computed from each point's coordinates together with its l, t, r, b values: x1 = x - l, y1 = y - t, x2 = x + r, y2 = y + b.
    The regression prediction conversion code is as follows:
        def snap_ltrb_reg_heads_to_x1_y1_x2_y2_bboxes(self, reg_preds,
                                                      points_position):
            """
            snap reg preds to pred bboxes
            reg_preds:[points_num,4],4:[l,t,r,b]
            points_position:[points_num,2],2:[point_ctr_x,point_ctr_y]
            """
            pred_bboxes_xy_min = points_position - reg_preds[:, 0:2]
            pred_bboxes_xy_max = points_position + reg_preds[:, 2:4]
            pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                    axis=1)
            pred_bboxes = pred_bboxes.int()
    
            # clamp the boxes to the image boundary
            pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
            pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
            pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],
                                            max=self.image_w - 1)
            pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],
                                            max=self.image_h - 1)
    
            # pred bboxes shape:[points_num,4]
            return pred_bboxes
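
    For intuition, here is a minimal standalone sketch of the same conversion using made-up numbers (one point at (100, 150) on a hypothetical 640x480 image); none of these values come from the original article:
    import torch

    # one point at (100, 150); the network predicts log(l, t, r, b), so exp first
    points_position = torch.tensor([[100.0, 150.0]])
    reg_preds = torch.exp(torch.tensor([[3.0, 2.5, 3.5, 3.0]]))

    x1y1 = points_position - reg_preds[:, 0:2]  # x1 = x - l, y1 = y - t
    x2y2 = points_position + reg_preds[:, 2:4]  # x2 = x + r, y2 = y + b
    boxes = torch.cat([x1y1, x2y2], dim=1).int()

    # clamp to the 640x480 image boundary
    boxes[:, 0].clamp_(min=0)
    boxes[:, 1].clamp_(min=0)
    boxes[:, 2].clamp_(max=640 - 1)
    boxes[:, 3].clamp_(max=480 - 1)
    print(boxes)  # tensor([[ 79, 137, 133, 170]], dtype=torch.int32)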
    

    decode


    The FCOS decode process is not much different from RetinaNet's. After converting the regression predictions to predicted box coordinates as above, NMS is likewise used to filter the predicted boxes. Before NMS, the classification scores must first be multiplied by the centerness predictions, which suppresses low-quality predicted boxes. Because multiplying scores by centerness makes the scores smaller, the square root of the product is taken afterwards to bring the scores back up to a usable range.
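    As a quick numeric sketch (illustrative values only, not from the original article) of why the square root helps:
    import torch

    cls_score = torch.tensor(0.8)   # classification confidence
    centerness = torch.tensor(0.5)  # centerness prediction

    print(cls_score * centerness)              # tensor(0.4000): the plain product shrinks the score
    print(torch.sqrt(cls_score * centerness))  # tensor(0.6325): the geometric mean stays in a usable range

    The full decode code is as follows: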
    import torch
    import torch.nn as nn
    from torchvision.ops import nms
    
    
    class FCOSDecoder(nn.Module):
        def __init__(self,
                     image_w,
                     image_h,
                     strides=[8, 16, 32, 64, 128],
                     top_n=1000,
                     min_score_threshold=0.01,
                     nms_threshold=0.6,
                     max_detection_num=100):
            super(FCOSDecoder, self).__init__()
            self.image_w = image_w
            self.image_h = image_h
            self.strides = strides
            self.top_n = top_n
            self.min_score_threshold = min_score_threshold
            self.nms_threshold = nms_threshold
            self.max_detection_num = max_detection_num
    
        def forward(self, cls_heads, reg_heads, center_heads, batch_positions):
            with torch.no_grad():
                device = cls_heads[0].device
    
                filter_scores, filter_score_classes = [], []
                filter_reg_heads, filter_batch_positions = [], []
                for per_level_cls_head, per_level_reg_head, per_level_center_head, per_level_position in zip(
                        cls_heads, reg_heads, center_heads, batch_positions):
                    per_level_cls_head = torch.sigmoid(per_level_cls_head)
                    per_level_reg_head = torch.exp(per_level_reg_head)
                    per_level_center_head = torch.sigmoid(per_level_center_head)
    
                    per_level_cls_head = per_level_cls_head.view(
                        per_level_cls_head.shape[0], -1,
                        per_level_cls_head.shape[-1])
                    per_level_reg_head = per_level_reg_head.view(
                        per_level_reg_head.shape[0], -1,
                        per_level_reg_head.shape[-1])
                    per_level_center_head = per_level_center_head.view(
                        per_level_center_head.shape[0], -1,
                        per_level_center_head.shape[-1])
                    per_level_position = per_level_position.view(
                        per_level_position.shape[0], -1,
                        per_level_position.shape[-1])
    
                    scores, score_classes = torch.max(per_level_cls_head, dim=2)
                    # geometric mean of classification score and centerness
                    scores = torch.sqrt(scores * per_level_center_head.squeeze(-1))
                    if scores.shape[1] >= self.top_n:
                        scores, indexes = torch.topk(scores,
                                                     self.top_n,
                                                     dim=1,
                                                     largest=True,
                                                     sorted=True)
                        score_classes = torch.gather(score_classes, 1, indexes)
                        per_level_reg_head = torch.gather(
                            per_level_reg_head, 1,
                            indexes.unsqueeze(-1).repeat(1, 1, 4))
                        per_level_center_head = torch.gather(
                            per_level_center_head, 1,
                            indexes.unsqueeze(-1).repeat(1, 1, 1))
                        per_level_position = torch.gather(
                            per_level_position, 1,
                            indexes.unsqueeze(-1).repeat(1, 1, 2))
                    filter_scores.append(scores)
                    filter_score_classes.append(score_classes)
                    filter_reg_heads.append(per_level_reg_head)
                    filter_batch_positions.append(per_level_position)
    
                filter_scores = torch.cat(filter_scores, axis=1)
                filter_score_classes = torch.cat(filter_score_classes, axis=1)
                filter_reg_heads = torch.cat(filter_reg_heads, axis=1)
                filter_batch_positions = torch.cat(filter_batch_positions, axis=1)
    
                batch_scores, batch_classes, batch_pred_bboxes = [], [], []
                for scores, score_classes, per_image_reg_preds, per_image_points_position in zip(
                        filter_scores, filter_score_classes, filter_reg_heads,
                        filter_batch_positions):
                    pred_bboxes = self.snap_ltrb_reg_heads_to_x1_y1_x2_y2_bboxes(
                        per_image_reg_preds, per_image_points_position)
    
                    score_classes = score_classes[
                        scores > self.min_score_threshold].float()
                    pred_bboxes = pred_bboxes[
                        scores > self.min_score_threshold].float()
                    scores = scores[scores > self.min_score_threshold].float()
    
                    one_image_scores = (-1) * torch.ones(
                        (self.max_detection_num, ), device=device)
                    one_image_classes = (-1) * torch.ones(
                        (self.max_detection_num, ), device=device)
                    one_image_pred_bboxes = (-1) * torch.ones(
                        (self.max_detection_num, 4), device=device)
    
                    if scores.shape[0] != 0:
                        # Sort boxes
                        sorted_scores, sorted_indexes = torch.sort(scores,
                                                                   descending=True)
                        sorted_score_classes = score_classes[sorted_indexes]
                        sorted_pred_bboxes = pred_bboxes[sorted_indexes]
    
                        keep = nms(sorted_pred_bboxes, sorted_scores,
                                   self.nms_threshold)
                        keep_scores = sorted_scores[keep]
                        keep_classes = sorted_score_classes[keep]
                        keep_pred_bboxes = sorted_pred_bboxes[keep]
    
                        final_detection_num = min(self.max_detection_num,
                                                  keep_scores.shape[0])
    
                        one_image_scores[0:final_detection_num] = keep_scores[
                            0:final_detection_num]
                        one_image_classes[0:final_detection_num] = keep_classes[
                            0:final_detection_num]
                        one_image_pred_bboxes[
                            0:final_detection_num, :] = keep_pred_bboxes[
                                0:final_detection_num, :]
    
                    one_image_scores = one_image_scores.unsqueeze(0)
                    one_image_classes = one_image_classes.unsqueeze(0)
                    one_image_pred_bboxes = one_image_pred_bboxes.unsqueeze(0)
    
                    batch_scores.append(one_image_scores)
                    batch_classes.append(one_image_classes)
                    batch_pred_bboxes.append(one_image_pred_bboxes)
    
                batch_scores = torch.cat(batch_scores, axis=0)
                batch_classes = torch.cat(batch_classes, axis=0)
                batch_pred_bboxes = torch.cat(batch_pred_bboxes, axis=0)
    
                # batch_scores shape:[batch_size,max_detection_num]
                # batch_classes shape:[batch_size,max_detection_num]
                # batch_pred_bboxes shape[batch_size,max_detection_num,4]
                return batch_scores, batch_classes, batch_pred_bboxes
    
        def snap_ltrb_reg_heads_to_x1_y1_x2_y2_bboxes(self, reg_preds,
                                                      points_position):
            """
            snap reg preds to pred bboxes
            reg_preds:[points_num,4],4:[l,t,r,b]
            points_position:[points_num,2],2:[point_ctr_x,point_ctr_y]
            """
            pred_bboxes_xy_min = points_position - reg_preds[:, 0:2]
            pred_bboxes_xy_max = points_position + reg_preds[:, 2:4]
            pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                    axis=1)
            pred_bboxes = pred_bboxes.int()
    
            pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
            pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
            pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],
                                            max=self.image_w - 1)
            pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],
                                            max=self.image_h - 1)
    
            # pred bboxes shape:[points_num,4]
            return pred_bboxes
    
    
    if __name__ == '__main__':
        from fcos import FCOS
        net = FCOS(resnet_type="resnet50")
        image_h, image_w = 600, 600
        cls_heads, reg_heads, center_heads, batch_positions = net(
            torch.randn(3, 3, image_h, image_w))
        # annotations are not consumed by FCOSDecoder (decoding needs no ground truth)
        annotations = torch.FloatTensor([[[113, 120, 183, 255, 5],
                                          [13, 45, 175, 210, 2]],
                                         [[11, 18, 223, 225, 1],
                                          [-1, -1, -1, -1, -1]],
                                         [[-1, -1, -1, -1, -1],
                                          [-1, -1, -1, -1, -1]]])
        decode = FCOSDecoder(image_w, image_h)
        batch_scores2, batch_classes2, batch_pred_bboxes2 = decode(
            cls_heads, reg_heads, center_heads, batch_positions)
        print("2222", batch_scores2.shape, batch_classes2.shape,
              batch_pred_bboxes2.shape)
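
    Note that the decoder pads each image's results out to max_detection_num with -1, so downstream code should mask out the padded slots before using them. A minimal sketch, reusing the outputs of the test run above:
    # padded slots have score == -1; real detections passed the score threshold
    valid_mask = batch_scores2[0] > -1
    valid_scores = batch_scores2[0][valid_mask]
    valid_classes = batch_classes2[0][valid_mask]
    valid_bboxes = batch_pred_bboxes2[0][valid_mask]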
    
