[주방장해우] Implementing FCOS from Scratch (Part 2): Ground Truth Assignment and Loss Computation


Contents

  • Anchor free? Anchor based?
  • Ground truth assignment in FCOS
  • Loss computation
  • Full loss code
  • All of the code has been uploaded to my GitHub repository (https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training). If it helps you, please give it a star! The code below was tested on PyTorch 1.4 and verified to run correctly.

    Anchor free? Anchor based?


    First, it should be clear that FCOS does not use explicit anchors the way RetinaNet does. FCOS treats every point on each FPN level's feature map as a sample, and decides whether a sample is positive or negative according to whether it falls inside or outside an annotated box (note that FCOS has no ignored samples: every point is either positive or negative). In that sense FCOS is indeed anchor free. However, when FCOS assigns ground truth and decodes predictions at test time, every point on the feature map has to be mapped back to an (x, y) coordinate on the input image, so FCOS is not completely 'free'; more precisely, FCOS is a 'point based' detector. You can think of FCOS as a detector with exactly one implicit anchor (a point) at every feature map location.
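
    As a minimal sketch (my own illustration, not the repository's code), here is how feature-map locations are typically mapped back to input-image coordinates, assuming the FCOS paper's convention of placing each point near the center of its stride cell:

        import torch

        def feature_map_points(h, w, stride):
            # map each (x, y) location of an h x w feature map back to input-image
            # coordinates: (stride // 2 + x * stride, stride // 2 + y * stride)
            shift_x = torch.arange(0, w) * stride + stride // 2
            shift_y = torch.arange(0, h) * stride + stride // 2
            ys, xs = torch.meshgrid(shift_y, shift_x)
            return torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=1)  # [h * w, 2]

        # a 2 x 3 feature map at stride 8 gives points (4, 4), (12, 4), (20, 4), ...
        print(feature_map_points(2, 3, 8))
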
    The DETR detector released in 2020 (https://arxiv.org/pdf/2005.12872.pdf) goes further: it treats object detection as a set prediction problem and uses a Transformer to predict the set of boxes directly, without NMS and without any anchor/point coordinate priors, so that detector really is 'free'. Interested readers can look into it on their own.

    Ground truth assignment in FCOS


    Given the boxes annotated on an input image, we first examine every point on each FPN level's feature map: if a point lies outside all annotated boxes, it becomes a negative sample. At this stage some of the remaining points may lie inside several annotated boxes at once. Next, for each point and each annotated box we compute l, t, r, b (the distances from the point to the left, top, right, and bottom edges of the box) and take their maximum. Using the value ranges below, if that maximum falls within a given range, the box is assigned to the corresponding point on the feature map of the FPN level associated with that range.
    # max(l, t, r, b) value ranges for P3, P4, P5, P6, P7
    INF=100000000
    mi=[[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]]
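
    As a quick worked example (mine, not repository code), here is the level-selection rule applied to one point and one box:

        # toy example: which FPN level does this box get assigned to for this point?
        INF = 100000000
        mi = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]]
        px, py = 100, 100                                # point position on the input image
        x1, y1, x2, y2 = 60, 60, 300, 240                # ground-truth box
        l, t, r, b = px - x1, py - y1, x2 - px, y2 - py  # 40, 40, 200, 140
        m = max(l, t, r, b)                              # 200
        level = next(i for i, (lo, hi) in enumerate(mi) if lo < m < hi)
        print(level)  # 2 -> the box is assigned to this point on P5
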
    

    After the steps above, most points end up assigned to at most one box. However, when two annotated boxes are similar in size, some points may still lie inside both. For such points we compute the areas of the overlapping boxes and always assign the point to the annotated box with the smallest area. In the implementation below I handle these samples with matrix operations: an image usually contains only a few dozen to two or three hundred of them, but assigning their labels with a for loop would slow training down considerably, so be careful. For the classification label, 0 denotes the negative class and 1 to 80 denote the 80 positive classes. The l, t, r, b and centerness labels are computed exactly as in the FCOS paper, without modification; the centerness formula is sketched below.
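
    For reference, a tiny stand-alone version (my own, not repository code) of the centerness target from the FCOS paper, which the assignment code below computes in vectorized form:

        import math

        def centerness_target(l, t, r, b):
            # centerness as defined in the FCOS paper; l, t, r, b are the distances
            # from a positive point to the four sides of its assigned box
            return math.sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))

        print(centerness_target(40, 40, 200, 140))  # ~0.24: this point sits far from the box center
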
    The ground truth assignment code is as follows:
        def get_batch_position_annotations(self, cls_heads, reg_heads,
                                           center_heads, batch_positions,
                                           annotations):
            """
            Assign a ground truth target for each position on feature map
            """
            device = annotations.device
            batch_mi = []
            for reg_head, mi in zip(reg_heads, self.mi):
                mi = torch.tensor(mi).to(device)
                B, H, W, _ = reg_head.shape
                per_level_mi = torch.zeros(B, H, W, 2).to(device)
                per_level_mi = per_level_mi + mi
                batch_mi.append(per_level_mi)
    
            cls_preds,reg_preds,center_preds,all_points_position,all_points_mi=[],[],[],[],[]
            for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
                    cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
                cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
                reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
                center_pred = center_pred.view(center_pred.shape[0], -1,
                                               center_pred.shape[-1])
                per_level_position = per_level_position.view(
                    per_level_position.shape[0], -1, per_level_position.shape[-1])
                per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                                 per_level_mi.shape[-1])
    
                cls_preds.append(cls_pred)
                reg_preds.append(reg_pred)
                center_preds.append(center_pred)
                all_points_position.append(per_level_position)
                all_points_mi.append(per_level_mi)
    
            cls_preds = torch.cat(cls_preds, axis=1)
            reg_preds = torch.cat(reg_preds, axis=1)
            center_preds = torch.cat(center_preds, axis=1)
            all_points_position = torch.cat(all_points_position, axis=1)
            all_points_mi = torch.cat(all_points_mi, axis=1)
    
            batch_targets = []
            for per_image_position, per_image_mi, per_image_annotations in zip(
                    all_points_position, all_points_mi, annotations):
                per_image_annotations = per_image_annotations[
                    per_image_annotations[:, 4] >= 0]
                points_num = per_image_position.shape[0]
    
                if per_image_annotations.shape[0] == 0:
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6], device=device)
                else:
                    annotaion_num = per_image_annotations.shape[0]
                    per_image_gt_bboxes = per_image_annotations[:, 0:4]
                    candidates = torch.zeros([points_num, annotaion_num, 4],
                                             device=device)
                    candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                    per_image_position = per_image_position.unsqueeze(1).repeat(
                        1, annotaion_num, 2)
                    candidates[:, :,
                               0:2] = per_image_position[:, :,
                                                         0:2] - candidates[:, :,
                                                                           0:2]
                    candidates[:, :,
                               2:4] = candidates[:, :,
                                                 2:4] - per_image_position[:, :,
                                                                           2:4]
    
                    candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                    sample_flag = (candidates_min_value[:, :, 0] >
                                   0).int().unsqueeze(-1)
                    # zero out reg targets for points whose center lies outside the gt box
                    candidates = candidates * sample_flag
    
                    # zero out reg targets whose max(l, t, r, b) is outside this level's mi range
                    candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                    per_image_mi = per_image_mi.unsqueeze(1).repeat(
                        1, annotaion_num, 1)
                    m1_negative_flag = (candidates_max_value[:, :, 0] >
                                        per_image_mi[:, :, 0]).int().unsqueeze(-1)
                    candidates = candidates * m1_negative_flag
                    m2_negative_flag = (candidates_max_value[:, :, 0] <
                                        per_image_mi[:, :, 1]).int().unsqueeze(-1)
                    candidates = candidates * m2_negative_flag
    
                    final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                    final_sample_flag = final_sample_flag > 0
                    positive_index = (final_sample_flag == True).nonzero().squeeze(
                        dim=-1)
    
                    # if no positive sample was assigned
                    if len(positive_index) == 0:
                        del candidates
                        # 6:l,t,r,b,class_index,center-ness_gt
                        per_image_targets = torch.zeros([points_num, 6],
                                                        device=device)
                    else:
                        positive_candidates = candidates[positive_index]
    
                        del candidates
    
                        sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                        sample_box_gts = sample_box_gts.repeat(
                            positive_candidates.shape[0], 1, 1)
                        sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                            -1).unsqueeze(0)
                        sample_class_gts = sample_class_gts.repeat(
                            positive_candidates.shape[0], 1, 1)
    
                        # 6:l,t,r,b,class_index,center-ness_gt
                        per_image_targets = torch.zeros([points_num, 6],
                                                        device=device)
    
                        if positive_candidates.shape[1] == 1:
                            # if only one candidate for each positive sample
                            # assign l,t,r,b,class_index,center_ness_gt ground truth
                            # class_index values from 1 to 80 represent the 80 positive classes
                            # class_index value 0 represents the negative class
                            positive_candidates = positive_candidates.squeeze(1)
                            sample_class_gts = sample_class_gts.squeeze(1)
                            per_image_targets[positive_index,
                                              0:4] = positive_candidates
                            per_image_targets[positive_index,
                                              4:5] = sample_class_gts + 1
    
                            l, t, r, b = per_image_targets[
                                positive_index, 0:1], per_image_targets[
                                    positive_index, 1:2], per_image_targets[
                                        positive_index,
                                        2:3], per_image_targets[positive_index,
                                                                3:4]
                            per_image_targets[positive_index, 5:6] = torch.sqrt(
                                (torch.min(l, r) / torch.max(l, r)) *
                                (torch.min(t, b) / torch.max(t, b)))
                        else:
                            # if a positive point has several candidate objects, choose the smallest-area candidate as its ground truth
                            gts_w_h = sample_box_gts[:, :,
                                                     2:4] - sample_box_gts[:, :,
                                                                           0:2]
                            gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                            positive_candidates_value = positive_candidates.sum(
                                axis=2)
    
                            # set all negative candidate areas to 100000000 so that .min() never selects a negative candidate
                            INF = 100000000
                            inf_tensor = torch.ones_like(gts_area) * INF
                            gts_area = torch.where(
                                torch.eq(positive_candidates_value, 0.),
                                inf_tensor, gts_area)
    
                            # get the smallest object candidate index
                            _, min_index = gts_area.min(axis=1)
                            candidate_indexes = (
                                torch.linspace(1, positive_candidates.shape[0],
                                               positive_candidates.shape[0]) -
                                1).long()
                            final_candidate_reg_gts = positive_candidates[
                                candidate_indexes, min_index, :]
                            final_candidate_cls_gts = sample_class_gts[
                                candidate_indexes, min_index]
    
                            # assign l,t,r,b,class_index,center_ness_gt ground truth
                            per_image_targets[positive_index,
                                              0:4] = final_candidate_reg_gts
                            per_image_targets[positive_index,
                                              4:5] = final_candidate_cls_gts + 1
    
                            l, t, r, b = per_image_targets[
                                positive_index, 0:1], per_image_targets[
                                    positive_index, 1:2], per_image_targets[
                                        positive_index,
                                        2:3], per_image_targets[positive_index,
                                                                3:4]
                            per_image_targets[positive_index, 5:6] = torch.sqrt(
                                (torch.min(l, r) / torch.max(l, r)) *
                                (torch.min(t, b) / torch.max(t, b)))
    
                per_image_targets = per_image_targets.unsqueeze(0)
                batch_targets.append(per_image_targets)
    
            batch_targets = torch.cat(batch_targets, axis=0)
            batch_targets = torch.cat([batch_targets, all_points_position], axis=2)
    
            # batch_targets shape:[batch_size, points_num, 8],8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
            return cls_preds, reg_preds, center_preds, batch_targets
    

    Loss computation


    The classification loss is a focal loss, computed exactly as in RetinaNet; the only difference is that the samples are points instead of anchors.
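
    As a quick reminder (a toy example of mine, not repository code), the alpha-balanced focal loss applied per point and per class:

        import torch

        # 3 points, 2 classes; an all-zero target row means a negative point
        preds = torch.tensor([[0.9, 0.1], [0.2, 0.7], [0.5, 0.5]])  # sigmoid outputs
        targets = torch.tensor([[1., 0.], [0., 1.], [0., 0.]])      # one-hot labels
        alpha, gamma = 0.25, 2.0
        alpha_t = torch.where(targets == 1., torch.full_like(preds, alpha),
                              torch.full_like(preds, 1. - alpha))
        p_t = torch.where(targets == 1., preds, 1. - preds)
        focal = -alpha_t * (1. - p_t) ** gamma * torch.log(p_t)  # FL = -alpha_t * (1 - p_t)^gamma * log(p_t)
        print(focal.sum() / 2)                                   # divided by the number of positive points
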
    The classification loss code is as follows:
        def compute_one_image_focal_loss(self, per_image_cls_preds,
                                         per_image_targets):
            """
            compute one image focal loss(cls loss)
            per_image_cls_preds:[points_num,num_classes]
            per_image_targets:[points_num,8]
            """
            per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                              min=self.epsilon,
                                              max=1. - self.epsilon)
            num_classes = per_image_cls_preds.shape[1]
    
            # generate 80 binary ground truth class labels for each point
            loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                          num_classes=num_classes + 1)
            loss_ground_truth = loss_ground_truth[:, 1:]
            loss_ground_truth = loss_ground_truth.float()
    
            alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
            alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                                       alpha_factor, 1. - alpha_factor)
            pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                             1. - per_image_cls_preds)
            focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)
    
            bce_loss = -(
                loss_ground_truth * torch.log(per_image_cls_preds) +
                (1. - loss_ground_truth) * torch.log(1. - per_image_cls_preds))
    
            one_image_focal_loss = focal_weight * bce_loss
    
            one_image_focal_loss = one_image_focal_loss.sum()
            positive_points_num = per_image_targets[
                per_image_targets[:, 4] > 0].shape[0]
            # following the original paper, we divide the focal loss by the number of positive points
            one_image_focal_loss = one_image_focal_loss / positive_points_num
    
            return one_image_focal_loss
    

    In the FCOS paper the regression loss is an IoU loss; here I use GIoU loss instead. Since the regression loss is still computed only on positive samples and the predicted l, t, r, b are positive after the exp, the predicted box always overlaps its ground-truth box, so the non-overlapping case that GIoU's extra penalty term mainly targets never occurs and the two losses behave very similarly here.
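
    For reference, the GIoU definition (Rezatofighi et al., 2019) that the code below implements in vectorized form: GIoU = IoU - (area(C) - area(U)) / area(C), where C is the smallest box enclosing both prediction and ground truth and U is their union; the loss is 1 - GIoU. A minimal per-pair sketch of mine:

        import torch

        def giou(box_a, box_b):
            # boxes in [x1, y1, x2, y2] format
            inter_tl = torch.max(box_a[:2], box_b[:2])
            inter_br = torch.min(box_a[2:], box_b[2:])
            inter_wh = (inter_br - inter_tl).clamp(min=0)
            inter = inter_wh[0] * inter_wh[1]
            area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
            area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
            union = area_a + area_b - inter
            enc_tl = torch.min(box_a[:2], box_b[:2])
            enc_br = torch.max(box_a[2:], box_b[2:])
            enclose = (enc_br - enc_tl).clamp(min=0).prod()
            return inter / union - (enclose - union) / enclose

        a = torch.tensor([0., 0., 10., 10.])
        b = torch.tensor([5., 5., 15., 15.])
        print(1. - giou(a, b))  # GIoU loss for this pair
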
    The regression loss code is as follows:
        def compute_one_image_giou_loss(self, per_image_reg_preds,
                                        per_image_targets):
            """
            compute one image giou loss(reg loss)
            per_image_reg_preds:[points_num,4]
            per_image_targets:[points_num,8]
            """
            # only positive point samples are used to compute the reg loss
            device = per_image_reg_preds.device
            per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
            per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
            positive_points_num = per_image_targets.shape[0]
    
            if positive_points_num == 0:
                return torch.tensor(0.).to(device)
    
            center_ness_targets = per_image_targets[:, 5]
    
            pred_bboxes_xy_min = per_image_targets[:,
                                                   6:8] - per_image_reg_preds[:,
                                                                              0:2]
            pred_bboxes_xy_max = per_image_targets[:,
                                                   6:8] + per_image_reg_preds[:,
                                                                              2:4]
            gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:,
                                                                             0:2]
            gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:,
                                                                             2:4]
    
            pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                    axis=1)
            gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)
    
            overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                             0:2])
            overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                              2:4])
            overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                             overlap_area_top_left,
                                             min=0)
            overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
    
            # compute width and height of predicted boxes and ground-truth boxes
            pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
            gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2] + 1
    
            # compute anchors_area and annotations_area
            pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
            gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]
    
            # compute union_area
            union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
            union_area = torch.clamp(union_area, min=1e-4)
            # compute IoU between predicted boxes and ground-truth boxes
            ious = overlap_area / union_area
    
            enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                             0:2])
            enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                              2:4])
            enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                             enclose_area_top_left,
                                             min=0)
            enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
            enclose_area = torch.clamp(enclose_area, min=1e-4)
    
            gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
            gious_loss = torch.clamp(gious_loss, min=-1.0, max=1.0)
            # use center_ness_targets as the weight of gious loss
            gious_loss = gious_loss * center_ness_targets
            gious_loss = gious_loss.sum() / positive_points_num
            gious_loss = 2. * gious_loss
    
            return gious_loss
    

    The final multiplication by 2 balances the magnitude of the regression loss against the other losses.
    The centerness branch is optimized with a BCE loss. Because the optimization target of the centerness loss is itself somewhat unstable, during training you may see this loss drop a little at first and then stop decreasing for a long time; this is normal and nothing to worry about. The centerness loss code is as follows:
        def compute_one_image_center_ness_loss(self, per_image_center_preds,
                                               per_image_targets):
            """
            compute one image center_ness loss(center ness loss)
            per_image_center_preds:[points_num,1]
            per_image_targets:[points_num,8]
            """
            # only positive point samples are used to compute the center_ness loss
            device = per_image_center_preds.device
            per_image_center_preds = per_image_center_preds[
                per_image_targets[:, 4] > 0]
            per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
            positive_points_num = per_image_targets.shape[0]
    
            if positive_points_num == 0:
                return torch.tensor(0.).to(device)
    
            center_ness_targets = per_image_targets[:, 5:6]
    
            center_ness_loss = -(
                center_ness_targets * torch.log(per_image_center_preds) +
                (1. - center_ness_targets) *
                torch.log(1. - per_image_center_preds))
            center_ness_loss = center_ness_loss.sum() / positive_points_num
    
            return center_ness_loss
    

    Full loss code

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    INF = 100000000
    
    
    class FCOSLoss(nn.Module):
        def __init__(self,
                     image_w,
                     image_h,
                     strides=[8, 16, 32, 64, 128],
                     mi=[[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]],
                     alpha=0.25,
                     gamma=2.,
                     epsilon=1e-4):
            super(FCOSLoss, self).__init__()
            self.alpha = alpha
            self.gamma = gamma
            self.epsilon = epsilon
            self.image_w = image_w
            self.image_h = image_h
            self.strides = strides
            self.mi = mi
    
        def forward(self, cls_heads, reg_heads, center_heads, batch_positions,
                    annotations):
            """
            compute cls loss, reg loss and center-ness loss in one batch
            """
            cls_preds, reg_preds, center_preds, batch_targets = self.get_batch_position_annotations(
                cls_heads, reg_heads, center_heads, batch_positions, annotations)
    
            cls_preds = torch.sigmoid(cls_preds)
            reg_preds = torch.exp(reg_preds)
            center_preds = torch.sigmoid(center_preds)
            batch_targets[:, :, 5:6] = torch.sigmoid(batch_targets[:, :, 5:6])
    
            device = annotations.device
            cls_loss, reg_loss, center_ness_loss = [], [], []
            valid_image_num = 0
            for per_image_cls_preds, per_image_reg_preds, per_image_center_preds, per_image_targets in zip(
                    cls_preds, reg_preds, center_preds, batch_targets):
                positive_points_num = (
                    per_image_targets[per_image_targets[:, 4] > 0]).shape[0]
                if positive_points_num == 0:
                    cls_loss.append(torch.tensor(0.).to(device))
                    reg_loss.append(torch.tensor(0.).to(device))
                    center_ness_loss.append(torch.tensor(0.).to(device))
                else:
                    valid_image_num += 1
                    one_image_cls_loss = self.compute_one_image_focal_loss(
                        per_image_cls_preds, per_image_targets)
                    one_image_reg_loss = self.compute_one_image_giou_loss(
                        per_image_reg_preds, per_image_targets)
                    one_image_center_ness_loss = self.compute_one_image_center_ness_loss(
                        per_image_center_preds, per_image_targets)
    
                    cls_loss.append(one_image_cls_loss)
                    reg_loss.append(one_image_reg_loss)
                    center_ness_loss.append(one_image_center_ness_loss)
    
            cls_loss = sum(cls_loss) / valid_image_num
            reg_loss = sum(reg_loss) / valid_image_num
            center_ness_loss = sum(center_ness_loss) / valid_image_num
    
            return cls_loss, reg_loss, center_ness_loss
    
        def compute_one_image_focal_loss(self, per_image_cls_preds,
                                         per_image_targets):
            """
            compute one image focal loss(cls loss)
            per_image_cls_preds:[points_num,num_classes]
            per_image_targets:[points_num,8]
            """
            per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                              min=self.epsilon,
                                              max=1. - self.epsilon)
            num_classes = per_image_cls_preds.shape[1]
    
            # generate 80 binary ground truth class labels for each point
            loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                          num_classes=num_classes + 1)
            loss_ground_truth = loss_ground_truth[:, 1:]
            loss_ground_truth = loss_ground_truth.float()
    
            alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
            alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                                       alpha_factor, 1. - alpha_factor)
            pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                             1. - per_image_cls_preds)
            focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)
    
            bce_loss = -(
                loss_ground_truth * torch.log(per_image_cls_preds) +
                (1. - loss_ground_truth) * torch.log(1. - per_image_cls_preds))
    
            one_image_focal_loss = focal_weight * bce_loss
    
            one_image_focal_loss = one_image_focal_loss.sum()
            positive_points_num = per_image_targets[
                per_image_targets[:, 4] > 0].shape[0]
            # following the original paper, we divide the focal loss by the number of positive points
            one_image_focal_loss = one_image_focal_loss / positive_points_num
    
            return one_image_focal_loss
    
        def compute_one_image_giou_loss(self, per_image_reg_preds,
                                        per_image_targets):
            """
            compute one image giou loss(reg loss)
            per_image_reg_preds:[points_num,4]
            per_image_targets:[points_num,8]
            """
            # only positive point samples are used to compute the reg loss
            device = per_image_reg_preds.device
            per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
            per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
            positive_points_num = per_image_targets.shape[0]
    
            if positive_points_num == 0:
                return torch.tensor(0.).to(device)
    
            center_ness_targets = per_image_targets[:, 5]
    
            pred_bboxes_xy_min = per_image_targets[:,
                                                   6:8] - per_image_reg_preds[:,
                                                                              0:2]
            pred_bboxes_xy_max = per_image_targets[:,
                                                   6:8] + per_image_reg_preds[:,
                                                                              2:4]
            gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:,
                                                                             0:2]
            gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:,
                                                                             2:4]
    
            pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                    axis=1)
            gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)
    
            overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                             0:2])
            overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                              2:4])
            overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                             overlap_area_top_left,
                                             min=0)
            overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
    
            # compute width and height of predicted boxes and ground-truth boxes
            pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
            gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2] + 1
    
            # compute anchors_area and annotations_area
            pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
            gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]
    
            # compute union_area
            union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
            union_area = torch.clamp(union_area, min=1e-4)
            # compute IoU between predicted boxes and ground-truth boxes
            ious = overlap_area / union_area
    
            enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                             0:2])
            enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                              2:4])
            enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                             enclose_area_top_left,
                                             min=0)
            enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
            enclose_area = torch.clamp(enclose_area, min=1e-4)
    
            gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
            gious_loss = torch.clamp(gious_loss, min=-1.0, max=1.0)
            # use center_ness_targets as the weight of gious loss
            gious_loss = gious_loss * center_ness_targets
            gious_loss = gious_loss.sum() / positive_points_num
            gious_loss = 2. * gious_loss
    
            return gious_loss
    
        def compute_one_image_center_ness_loss(self, per_image_center_preds,
                                               per_image_targets):
            """
            compute one image center_ness loss(center ness loss)
            per_image_center_preds:[points_num,1]
            per_image_targets:[points_num,8]
            """
            # only positive point samples are used to compute the center_ness loss
            device = per_image_center_preds.device
            per_image_center_preds = per_image_center_preds[
                per_image_targets[:, 4] > 0]
            per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
            positive_points_num = per_image_targets.shape[0]
    
            if positive_points_num == 0:
                return torch.tensor(0.).to(device)
    
            center_ness_targets = per_image_targets[:, 5:6]
    
            center_ness_loss = -(
                center_ness_targets * torch.log(per_image_center_preds) +
                (1. - center_ness_targets) *
                torch.log(1. - per_image_center_preds))
            center_ness_loss = center_ness_loss.sum() / positive_points_num
    
            return center_ness_loss
    
        def get_batch_position_annotations(self, cls_heads, reg_heads,
                                           center_heads, batch_positions,
                                           annotations):
            """
            Assign a ground truth target for each position on feature map
            """
            device = annotations.device
            batch_mi = []
            for reg_head, mi in zip(reg_heads, self.mi):
                mi = torch.tensor(mi).to(device)
                B, H, W, _ = reg_head.shape
                per_level_mi = torch.zeros(B, H, W, 2).to(device)
                per_level_mi = per_level_mi + mi
                batch_mi.append(per_level_mi)
    
            cls_preds,reg_preds,center_preds,all_points_position,all_points_mi=[],[],[],[],[]
            for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
                    cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
                cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
                reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
                center_pred = center_pred.view(center_pred.shape[0], -1,
                                               center_pred.shape[-1])
                per_level_position = per_level_position.view(
                    per_level_position.shape[0], -1, per_level_position.shape[-1])
                per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                                 per_level_mi.shape[-1])
    
                cls_preds.append(cls_pred)
                reg_preds.append(reg_pred)
                center_preds.append(center_pred)
                all_points_position.append(per_level_position)
                all_points_mi.append(per_level_mi)
    
            cls_preds = torch.cat(cls_preds, axis=1)
            reg_preds = torch.cat(reg_preds, axis=1)
            center_preds = torch.cat(center_preds, axis=1)
            all_points_position = torch.cat(all_points_position, axis=1)
            all_points_mi = torch.cat(all_points_mi, axis=1)
    
            batch_targets = []
            for per_image_position, per_image_mi, per_image_annotations in zip(
                    all_points_position, all_points_mi, annotations):
                per_image_annotations = per_image_annotations[
                    per_image_annotations[:, 4] >= 0]
                points_num = per_image_position.shape[0]
    
                if per_image_annotations.shape[0] == 0:
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6], device=device)
                else:
                    annotaion_num = per_image_annotations.shape[0]
                    per_image_gt_bboxes = per_image_annotations[:, 0:4]
                    candidates = torch.zeros([points_num, annotaion_num, 4],
                                             device=device)
                    candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                    per_image_position = per_image_position.unsqueeze(1).repeat(
                        1, annotaion_num, 2)
                    candidates[:, :,
                               0:2] = per_image_position[:, :,
                                                         0:2] - candidates[:, :,
                                                                           0:2]
                    candidates[:, :,
                               2:4] = candidates[:, :,
                                                 2:4] - per_image_position[:, :,
                                                                           2:4]
    
                    candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                    sample_flag = (candidates_min_value[:, :, 0] >
                                   0).int().unsqueeze(-1)
                    # zero out reg targets for points whose center lies outside the gt box
                    candidates = candidates * sample_flag
    
                    # zero out reg targets whose max(l, t, r, b) is outside this level's mi range
                    candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                    per_image_mi = per_image_mi.unsqueeze(1).repeat(
                        1, annotaion_num, 1)
                    m1_negative_flag = (candidates_max_value[:, :, 0] >
                                        per_image_mi[:, :, 0]).int().unsqueeze(-1)
                    candidates = candidates * m1_negative_flag
                    m2_negative_flag = (candidates_max_value[:, :, 0] <
                                        per_image_mi[:, :, 1]).int().unsqueeze(-1)
                    candidates = candidates * m2_negative_flag
    
                    final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                    final_sample_flag = final_sample_flag > 0
                    positive_index = (final_sample_flag == True).nonzero().squeeze(
                        dim=-1)
    
                    # if no positive sample was assigned
                    if len(positive_index) == 0:
                        del candidates
                        # 6:l,t,r,b,class_index,center-ness_gt
                        per_image_targets = torch.zeros([points_num, 6],
                                                        device=device)
                    else:
                        positive_candidates = candidates[positive_index]
    
                        del candidates
    
                        sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                        sample_box_gts = sample_box_gts.repeat(
                            positive_candidates.shape[0], 1, 1)
                        sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                            -1).unsqueeze(0)
                        sample_class_gts = sample_class_gts.repeat(
                            positive_candidates.shape[0], 1, 1)
    
                        # 6:l,t,r,b,class_index,center-ness_gt
                        per_image_targets = torch.zeros([points_num, 6],
                                                        device=device)
    
                        if positive_candidates.shape[1] == 1:
                            # if only one candidate for each positive sample
                            # assign l,t,r,b,class_index,center_ness_gt ground truth
                            # class_index values from 1 to 80 represent the 80 positive classes
                            # class_index value 0 represents the negative class
                            positive_candidates = positive_candidates.squeeze(1)
                            sample_class_gts = sample_class_gts.squeeze(1)
                            per_image_targets[positive_index,
                                              0:4] = positive_candidates
                            per_image_targets[positive_index,
                                              4:5] = sample_class_gts + 1
    
                            l, t, r, b = per_image_targets[
                                positive_index, 0:1], per_image_targets[
                                    positive_index, 1:2], per_image_targets[
                                        positive_index,
                                        2:3], per_image_targets[positive_index,
                                                                3:4]
                            per_image_targets[positive_index, 5:6] = torch.sqrt(
                                (torch.min(l, r) / torch.max(l, r)) *
                                (torch.min(t, b) / torch.max(t, b)))
                        else:
                            # if a positive point has several candidate objects, choose the smallest-area candidate as its ground truth
                            gts_w_h = sample_box_gts[:, :,
                                                     2:4] - sample_box_gts[:, :,
                                                                           0:2]
                            gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                            positive_candidates_value = positive_candidates.sum(
                                axis=2)
    
                            # set all negative candidate areas to 100000000 so that .min() never selects a negative candidate
                            INF = 100000000
                            inf_tensor = torch.ones_like(gts_area) * INF
                            gts_area = torch.where(
                                torch.eq(positive_candidates_value, 0.),
                                inf_tensor, gts_area)
    
                            # get the smallest object candidate index
                            _, min_index = gts_area.min(axis=1)
                            candidate_indexes = (
                                torch.linspace(1, positive_candidates.shape[0],
                                               positive_candidates.shape[0]) -
                                1).long()
                            final_candidate_reg_gts = positive_candidates[
                                candidate_indexes, min_index, :]
                            final_candidate_cls_gts = sample_class_gts[
                                candidate_indexes, min_index]
    
                            # assign l,t,r,b,class_index,center_ness_gt ground truth
                            per_image_targets[positive_index,
                                              0:4] = final_candidate_reg_gts
                            per_image_targets[positive_index,
                                              4:5] = final_candidate_cls_gts + 1
    
                            l, t, r, b = per_image_targets[
                                positive_index, 0:1], per_image_targets[
                                    positive_index, 1:2], per_image_targets[
                                        positive_index,
                                        2:3], per_image_targets[positive_index,
                                                                3:4]
                            per_image_targets[positive_index, 5:6] = torch.sqrt(
                                (torch.min(l, r) / torch.max(l, r)) *
                                (torch.min(t, b) / torch.max(t, b)))
    
                per_image_targets = per_image_targets.unsqueeze(0)
                batch_targets.append(per_image_targets)
    
            batch_targets = torch.cat(batch_targets, axis=0)
            batch_targets = torch.cat([batch_targets, all_points_position], axis=2)
    
            # batch_targets shape:[batch_size, points_num, 8],8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
            return cls_preds, reg_preds, center_preds, batch_targets
    
    
    if __name__ == '__main__':
        from fcos import FCOS
        net = FCOS(resnet_type="resnet50")
        image_h, image_w = 600, 600
        cls_heads, reg_heads, center_heads, batch_positions = net(
            torch.autograd.Variable(torch.randn(3, 3, image_h, image_w)))
        annotations = torch.FloatTensor([[[113, 120, 183, 255, 5],
                                          [13, 45, 175, 210, 2]],
                                         [[11, 18, 223, 225, 1],
                                          [-1, -1, -1, -1, -1]],
                                         [[-1, -1, -1, -1, -1],
                                          [-1, -1, -1, -1, -1]]])
        loss = FCOSLoss(image_w, image_h)
        cls_loss, reg_loss, center_loss = loss(cls_heads, reg_heads, center_heads,
                                               batch_positions, annotations)
        print("2222", cls_loss, reg_loss, center_loss)
    
