[Cook Ding Carves the Ox] Implementing FCOS from Scratch (Part 2): Ground Truth Assignment and Loss Computation
Anchor free? Anchor based?
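First, it should be clear that FCOS, unlike RetinaNet, does not use explicit anchors. FCOS treats every point on each FPN level's feature map as a sample, and decides whether a sample is positive or negative depending on whether it falls inside or outside an annotated box (note that FCOS has no ignored samples). In that sense FCOS is indeed anchor free. However, during ground truth assignment and at test time, every point on the feature map has to be mapped back to (x, y) coordinates on the input image, so FCOS is not completely "free" either. More precisely, FCOS is a "point based" object detector: we can think of it as a detector with exactly one implicit anchor at each point of the feature map.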
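To make the "point based" view concrete, here is a minimal sketch of how the points of one FPN level can be mapped back to input-image coordinates. It assumes the mapping given in the FCOS paper, (floor(s/2) + x*s, floor(s/2) + y*s) for stride s; the helper name `level_positions` is hypothetical, and this part of the series does not show how the author actually builds `batch_positions`.

```python
def level_positions(feat_h, feat_w, stride):
    """Return input-image (x, y) coordinates for every cell of one FPN level.

    Hypothetical helper following the FCOS paper's mapping
    (floor(s / 2) + col * s, floor(s / 2) + row * s), in row-major order.
    """
    offset = stride // 2
    return [(offset + col * stride, offset + row * stride)
            for row in range(feat_h)
            for col in range(feat_w)]


# A 2 x 3 feature map at stride 8 (P3): each point sits at the centre of its 8x8 cell.
print(level_positions(2, 3, 8))
# [(4, 4), (12, 4), (20, 4), (4, 12), (12, 12), (20, 12)]
```

The (x, y) order matters: the assignment code computes l and t by subtracting a box's (x1, y1) from each position.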
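DETR (https://arxiv.org/pdf/2005.12872.pdf), an object detector released in 2020, treats object detection as a set prediction problem and uses a Transformer to predict the set of boxes directly. It uses neither NMS nor prior coordinates such as anchors or points, which makes the detector truly "free". Interested readers can look into it on their own.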
Ground truth assignment in FCOS
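Given the boxes annotated in an input image, we first check every point on each FPN level's feature map: if a point lies outside all annotated boxes, it is a negative sample. Some of the remaining points may lie inside several annotated boxes at once. Next, for each point and each annotated box, we take the maximum of l, t, r, b (the distances from the point to the box's left, top, right and bottom edges). If this maximum falls within the range of an FPN level listed below, the box is assigned to the corresponding point on that level's feature map.
# ranges for P3, P4, P5, P6, P7
INF = 100000000
mi = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]]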
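The range check can be sketched for a single point/box pair in plain Python. This is an illustrative helper (`accepting_levels` is a hypothetical name) that mirrors the strict comparisons in the code below: the point must lie strictly inside the box, and max(l, t, r, b) must satisfy mi[0] < max < mi[1].

```python
INF = 100000000
# regression ranges for P3, P4, P5, P6, P7 (same values as mi above)
MI = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]]


def accepting_levels(l, t, r, b, mi=MI):
    """Return indices of the FPN levels (0 = P3 ... 4 = P7) that accept this pair.

    Hypothetical helper: mirrors the strict inequalities used in
    get_batch_position_annotations.
    """
    if min(l, t, r, b) <= 0:
        # the point is not strictly inside the box -> negative on every level
        return []
    m = max(l, t, r, b)
    return [i for i, (low, high) in enumerate(mi) if low < m < high]


print(accepting_levels(10, 20, 30, 40))   # [0] -> P3
print(accepting_levels(100, 5, 5, 5))     # [1] -> P4
print(accepting_levels(64, 5, 5, 5))      # []  -> exactly 64 matches no range
```

Because both comparisons are strict, a pair whose maximum distance is exactly 64, 128, 256 or 512 falls into no range. With floating-point box coordinates this is rare, but it is worth knowing if you change the comparisons.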
  
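After this step, most positive points are assigned to exactly one box. However, when two annotated boxes are similar in size, some points lie inside both boxes and pass the range check for both. For such a point we compare the areas of the overlapping candidate boxes and always assign the point to the box with the smallest area. In the implementation below, I assign labels to these samples with matrix operations. An image usually contains only a few dozen to two or three hundred such samples, but assigning their labels with a for loop can still slow training down considerably, so keep this in mind. For the classification labels, 0 denotes a negative sample and 1 to 80 denote the 80 positive classes. The l, t, r, b and centerness targets are computed exactly as defined in the FCOS paper, without modification.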
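The smallest-area rule can be expressed with a handful of tensor operations. The sketch below is a simplified, standalone version of that step; the function name and the toy boxes are illustrative, not taken from the original code. As in the assignment code, an invalid (point, box) pair is marked by an all-zero l, t, r, b row, and its area is replaced by INF before `argmin`.

```python
import torch

INF = 100000000.0


def pick_smallest_area_gt(reg_targets, gt_boxes):
    """For each positive point, return the index of its smallest-area candidate box.

    reg_targets: [P, N, 4] l, t, r, b for every (point, gt box) pair;
                 an all-zero row marks an invalid pair.
    gt_boxes:    [N, 4] boxes as x1, y1, x2, y2.
    """
    wh = gt_boxes[:, 2:4] - gt_boxes[:, 0:2]
    areas = (wh[:, 0] * wh[:, 1]).unsqueeze(0).expand(reg_targets.shape[0], -1)
    valid = reg_targets.sum(dim=-1) > 0                      # [P, N]
    areas = torch.where(valid, areas, torch.full_like(areas, INF))
    return areas.argmin(dim=1)                               # [P]


gt_boxes = torch.tensor([[0., 0., 100., 100.],   # large box
                         [10., 10., 50., 50.]])  # small box inside it
# point (30, 30) lies in both boxes; point (80, 80) lies only in the large one
reg_targets = torch.tensor([[[30., 30., 70., 70.], [20., 20., 20., 20.]],
                            [[80., 80., 20., 20.], [0., 0., 0., 0.]]])
best = pick_smallest_area_gt(reg_targets, gt_boxes)
# gather the chosen l, t, r, b row for every point
chosen = reg_targets[torch.arange(reg_targets.shape[0]), best]
print(best.tolist())    # [1, 0]
print(chosen.tolist())  # [[20.0, 20.0, 20.0, 20.0], [80.0, 80.0, 20.0, 20.0]]
```

Indexing with `torch.arange` and the `argmin` result picks one candidate per row without any Python loop over points.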
The ground truth assignment code is as follows:
    def get_batch_position_annotations(self, cls_heads, reg_heads,
                                       center_heads, batch_positions,
                                       annotations):
        """
        Assign a ground truth target for each position on feature map
        """
        device = annotations.device
        batch_mi = []
        for reg_head, mi in zip(reg_heads, self.mi):
            mi = torch.tensor(mi).to(device)
            B, H, W, _ = reg_head.shape
            per_level_mi = torch.zeros(B, H, W, 2).to(device)
            per_level_mi = per_level_mi + mi
            batch_mi.append(per_level_mi)
        cls_preds,reg_preds,center_preds,all_points_position,all_points_mi=[],[],[],[],[]
        for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
                cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
            cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
            reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
            center_pred = center_pred.view(center_pred.shape[0], -1,
                                           center_pred.shape[-1])
            per_level_position = per_level_position.view(
                per_level_position.shape[0], -1, per_level_position.shape[-1])
            per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                             per_level_mi.shape[-1])
            cls_preds.append(cls_pred)
            reg_preds.append(reg_pred)
            center_preds.append(center_pred)
            all_points_position.append(per_level_position)
            all_points_mi.append(per_level_mi)
        cls_preds = torch.cat(cls_preds, axis=1)
        reg_preds = torch.cat(reg_preds, axis=1)
        center_preds = torch.cat(center_preds, axis=1)
        all_points_position = torch.cat(all_points_position, axis=1)
        all_points_mi = torch.cat(all_points_mi, axis=1)
        batch_targets = []
        for per_image_position, per_image_mi, per_image_annotations in zip(
                all_points_position, all_points_mi, annotations):
            per_image_annotations = per_image_annotations[
                per_image_annotations[:, 4] >= 0]
            points_num = per_image_position.shape[0]
            if per_image_annotations.shape[0] == 0:
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6], device=device)
            else:
                annotaion_num = per_image_annotations.shape[0]
                per_image_gt_bboxes = per_image_annotations[:, 0:4]
                candidates = torch.zeros([points_num, annotaion_num, 4],
                                         device=device)
                candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                per_image_position = per_image_position.unsqueeze(1).repeat(
                    1, annotaion_num, 2)
                candidates[:, :,
                           0:2] = per_image_position[:, :,
                                                     0:2] - candidates[:, :,
                                                                       0:2]
                candidates[:, :,
                           2:4] = candidates[:, :,
                                             2:4] - per_image_position[:, :,
                                                                       2:4]
                candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                sample_flag = (candidates_min_value[:, :, 0] >
                               0).int().unsqueeze(-1)
                # zero out reg targets of (point, gt) pairs whose point lies outside the gt box
                candidates = candidates * sample_flag
                # zero out reg targets of pairs whose max(l,t,r,b) is outside this level's mi range
                candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                per_image_mi = per_image_mi.unsqueeze(1).repeat(
                    1, annotaion_num, 1)
                m1_negative_flag = (candidates_max_value[:, :, 0] >
                                    per_image_mi[:, :, 0]).int().unsqueeze(-1)
                candidates = candidates * m1_negative_flag
                m2_negative_flag = (candidates_max_value[:, :, 0] <
                                    per_image_mi[:, :, 1]).int().unsqueeze(-1)
                candidates = candidates * m2_negative_flag
                final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                final_sample_flag = final_sample_flag > 0
                positive_index = (final_sample_flag == True).nonzero().squeeze(
                    dim=-1)
                # if no assign positive sample
                if len(positive_index) == 0:
                    del candidates
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                else:
                    positive_candidates = candidates[positive_index]
                    del candidates
                    sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                    sample_box_gts = sample_box_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                        -1).unsqueeze(0)
                    sample_class_gts = sample_class_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                    if positive_candidates.shape[1] == 1:
                        # if only one candidate for each positive sample
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        # class_index value from 1 to 80 represent 80 positive classes
                        # class_index value 0 represents the negative class
                        positive_candidates = positive_candidates.squeeze(1)
                        sample_class_gts = sample_class_gts.squeeze(1)
                        per_image_targets[positive_index,
                                          0:4] = positive_candidates
                        per_image_targets[positive_index,
                                          4:5] = sample_class_gts + 1
                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index,
                                    2:3], per_image_targets[positive_index,
                                                            3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
                    else:
                        # if a positive point has several object candidates, choose the smallest-area candidate as its ground truth
                        gts_w_h = sample_box_gts[:, :,
                                                 2:4] - sample_box_gts[:, :,
                                                                       0:2]
                        gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                        positive_candidates_value = positive_candidates.sum(
                            axis=2)
                        # make sure all negative candidates areas==100000000,thus .min() operation wouldn't choose negative candidates
                        INF = 100000000
                        inf_tensor = torch.ones_like(gts_area) * INF
                        gts_area = torch.where(
                            torch.eq(positive_candidates_value, 0.),
                            inf_tensor, gts_area)
                        # get the smallest object candidate index
                        _, min_index = gts_area.min(axis=1)
                        # row indices 0..P-1, created on the same device as the other tensors
                        candidate_indexes = torch.arange(
                            positive_candidates.shape[0], device=device)
                        final_candidate_reg_gts = positive_candidates[
                            candidate_indexes, min_index, :]
                        final_candidate_cls_gts = sample_class_gts[
                            candidate_indexes, min_index]
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        per_image_targets[positive_index,
                                          0:4] = final_candidate_reg_gts
                        per_image_targets[positive_index,
                                          4:5] = final_candidate_cls_gts + 1
                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index,
                                    2:3], per_image_targets[positive_index,
                                                            3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
            per_image_targets = per_image_targets.unsqueeze(0)
            batch_targets.append(per_image_targets)
        batch_targets = torch.cat(batch_targets, axis=0)
        batch_targets = torch.cat([batch_targets, all_points_position], axis=2)
        # batch_targets shape:[batch_size, points_num, 8],8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
        return cls_preds, reg_preds, center_preds, batch_targets
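For reference, these are the target definitions from the FCOS paper behind the regression targets (columns 0 to 3) and the centerness target (column 5) written above, for a point (x, y) assigned to a box (x_0, y_0, x_1, y_1), together with a worked example:

```latex
l^* = x - x_0,\quad t^* = y - y_0,\quad r^* = x_1 - x,\quad b^* = y_1 - y

\text{centerness}^* = \sqrt{\frac{\min(l^*, r^*)}{\max(l^*, r^*)} \times \frac{\min(t^*, b^*)}{\max(t^*, b^*)}}

% worked example: point (20, 30) inside the box (10, 10, 50, 50)
l^* = 10,\ t^* = 20,\ r^* = 30,\ b^* = 20
\;\Rightarrow\;
\text{centerness}^* = \sqrt{\frac{10}{30} \times \frac{20}{20}} = \sqrt{1/3} \approx 0.577
```

A point at the exact centre of its box gets centerness 1, and the value drops toward 0 as the point approaches an edge.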
  
Loss computation
The classification loss is focal loss, computed exactly as in RetinaNet; the only difference is that the samples are now points instead of anchors.
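As a reminder of what the focal loss computes, here is a scalar sketch for a single class score, using the same alpha / gamma convention as the code below (alpha for positives, 1 - alpha for negatives). The function name is illustrative only.

```python
import math


def focal_loss(p, is_positive, alpha=0.25, gamma=2.0):
    """Binary focal loss for one class probability p = sigmoid(logit)."""
    p_t = p if is_positive else 1.0 - p             # probability of the true label
    alpha_t = alpha if is_positive else 1.0 - alpha  # class-balancing weight
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


easy_negative = focal_loss(0.01, is_positive=False)  # background point, confidently rejected
hard_positive = focal_loss(0.1, is_positive=True)    # object point, badly missed
print(easy_negative, hard_positive)
```

The (1 - p_t)^gamma factor shrinks the loss of the easy negative to a tiny fraction of the hard positive's, which is why the many background points in FCOS do not swamp the few positives. With alpha = 1 and gamma = 0 the expression reduces to plain binary cross-entropy.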
The classification loss code is as follows:
    def compute_one_image_focal_loss(self, per_image_cls_preds,
                                     per_image_targets):
        """
        compute one image focal loss(cls loss)
        per_image_cls_preds:[points_num,num_classes]
        per_image_targets:[points_num,8]
        """
        per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                          min=self.epsilon,
                                          max=1. - self.epsilon)
        num_classes = per_image_cls_preds.shape[1]
        # generate 80 binary ground truth classes for each point
        loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                      num_classes=num_classes + 1)
        loss_ground_truth = loss_ground_truth[:, 1:]
        loss_ground_truth = loss_ground_truth.float()
        alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
        alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                                   alpha_factor, 1. - alpha_factor)
        pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                         1. - per_image_cls_preds)
        focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)
        bce_loss = -(
            loss_ground_truth * torch.log(per_image_cls_preds) +
            (1. - loss_ground_truth) * torch.log(1. - per_image_cls_preds))
        one_image_focal_loss = focal_weight * bce_loss
        one_image_focal_loss = one_image_focal_loss.sum()
        positive_points_num = per_image_targets[
            per_image_targets[:, 4] > 0].shape[0]
        # following the original paper, divide the focal loss by the number of positive points
        one_image_focal_loss = one_image_focal_loss / positive_points_num
        return one_image_focal_loss
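The line `F.one_hot(..., num_classes=num_classes + 1)[:, 1:]` is what makes the "0 = negative" label convention work: class index 0 becomes an all-zero row once the first column is dropped. A tiny check with 3 classes instead of 80 (a toy setting chosen for illustration):

```python
import torch
import torch.nn.functional as F

num_classes = 3
# per-point class targets: 0 = negative, 1..num_classes = positive classes
labels = torch.tensor([0, 1, 3])
targets = F.one_hot(labels, num_classes=num_classes + 1)[:, 1:].float()
print(targets)
# rows: negative point -> all zeros; class 1 -> first column; class 3 -> last column
```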
  
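The FCOS paper uses IoU loss for regression; here I use GIoU loss instead. The regression loss is computed only on positive points, and both the predicted box and the ground truth box are built around the same point, so the two boxes always overlap. GIoU's main advantage, a useful gradient for non-overlapping boxes, therefore never comes into play here. The two losses are still not strictly identical, though: GIoU loss adds the penalty term (C - U) / C, where C is the area of the smallest enclosing box and U is the union area, and this term is nonzero whenever the enclosing box is larger than the union (for example, when the two boxes are offset diagonally).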
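The difference between the two losses is easy to check numerically. This standalone sketch (plain Python, continuous coordinates, hypothetical helper names) computes both losses for pairs of boxes that each contain a common point:

```python
def box_area(box):
    """Area of a box given as (x1, y1, x2, y2)."""
    return (box[2] - box[0]) * (box[3] - box[1])


def iou_and_giou_loss(pred, gt):
    """Return (IoU loss, GIoU loss) for two (x1, y1, x2, y2) boxes."""
    inter_w = max(min(pred[2], gt[2]) - max(pred[0], gt[0]), 0.0)
    inter_h = max(min(pred[3], gt[3]) - max(pred[1], gt[1]), 0.0)
    overlap = inter_w * inter_h
    union = box_area(pred) + box_area(gt) - overlap
    iou = overlap / union
    # smallest box enclosing both pred and gt
    enclose = ((max(pred[2], gt[2]) - min(pred[0], gt[0])) *
               (max(pred[3], gt[3]) - min(pred[1], gt[1])))
    return 1.0 - iou, 1.0 - iou + (enclose - union) / enclose


# diagonally offset boxes (both contain (1.5, 1.5)): overlap 1, union 7, enclosing box 9
print(iou_and_giou_loss((0., 0., 2., 2.), (1., 1., 3., 3.)))
# prediction nested inside the gt: enclosing box == union, so the losses match
print(iou_and_giou_loss((1., 1., 2., 2.), (0., 0., 3., 3.)))
```

In the first case the GIoU loss is larger by (9 - 7) / 9 = 2/9 and even exceeds 1; in the nested case the two losses coincide.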
The regression loss code is as follows:
    def compute_one_image_giou_loss(self, per_image_reg_preds,
                                    per_image_targets):
        """
        compute one image giou loss(reg loss)
        per_image_reg_preds:[points_num,4]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute reg loss
        device = per_image_reg_preds.device
        per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]
        if positive_points_num == 0:
            return torch.tensor(0.).to(device)
        center_ness_targets = per_image_targets[:, 5]
        pred_bboxes_xy_min = per_image_targets[:,
                                               6:8] - per_image_reg_preds[:,
                                                                          0:2]
        pred_bboxes_xy_max = per_image_targets[:,
                                               6:8] + per_image_reg_preds[:,
                                                                          2:4]
        gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:,
                                                                         0:2]
        gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:,
                                                                         2:4]
        pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                axis=1)
        gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)
        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
        # compute widths and heights of pred and gt boxes; no "+1" pixel offset,
        # so the areas stay consistent with overlap_area and enclose_area
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2]
        gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2]
        # compute pred_bboxes_area and gt_bboxes_area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]
        # compute union_area
        union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # compute ious between the predicted and gt boxes of the positive points
        ious = overlap_area / union_area
        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=0)
        enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
        enclose_area = torch.clamp(enclose_area, min=1e-4)
        gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
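        # GIoU loss lies in [0, 2]; clamping at max=1.0 would zero the
        # gradient of badly localised positives
        gious_loss = torch.clamp(gious_loss, min=0., max=2.)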
        # use center_ness_targets as the weight of gious loss
        gious_loss = gious_loss * center_ness_targets
        gious_loss = gious_loss.sum() / positive_points_num
        gious_loss = 2. * gious_loss
        return gious_loss
  
The final multiplication by 2 balances the magnitude of the regression loss against the other losses.
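Centerness is optimized with BCE loss. The centerness target is not very stable to optimize, so in actual training the centerness loss may drop a little at first and then stay flat for a long time. This is normal and nothing to worry about: the centerness targets are soft values in (0, 1] rather than hard 0/1 labels, so the BCE loss cannot reach zero, and its lower bound is the binary entropy of the targets. The centerness loss code is as follows:
    def compute_one_image_center_ness_loss(self, per_image_center_preds,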
                                           per_image_targets):
        """
        compute one image center_ness loss(center ness loss)
        per_image_center_preds:[points_num,1]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute center_ness loss
        device = per_image_center_preds.device
        per_image_center_preds = per_image_center_preds[
            per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]
        if positive_points_num == 0:
            return torch.tensor(0.).to(device)
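        center_ness_targets = per_image_targets[:, 5:6]
        # keep predictions away from 0 and 1 so that torch.log never returns -inf
        per_image_center_preds = torch.clamp(per_image_center_preds,
                                             min=self.epsilon,
                                             max=1. - self.epsilon)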
        center_ness_loss = -(
            center_ness_targets * torch.log(per_image_center_preds) +
            (1. - center_ness_targets) *
            torch.log(1. - per_image_center_preds))
        center_ness_loss = center_ness_loss.sum() / positive_points_num
        return center_ness_loss
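The plateau is easy to see with a scalar example. For a soft target such as sqrt(1/3) (about 0.577), the BCE loss is minimised at p = target, and that minimum is the binary entropy of the target, not zero. This is a plain-Python illustration, not part of the training code.

```python
import math


def bce(p, target):
    """Binary cross-entropy for one prediction p in (0, 1) and a soft target."""
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


target = math.sqrt(1.0 / 3.0)  # a typical soft centerness target (~0.577)
best = bce(target, target)     # perfect prediction: p == target
grid = [bce(i / 100, target) for i in range(1, 100)]
print(best, min(grid))         # the floor is about 0.68, never 0
```

This is also why the centerness loss stays relatively high in absolute terms even when the branch is learning well.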
  
Full loss code:
import torch
import torch.nn as nn
import torch.nn.functional as F
INF = 100000000
class FCOSLoss(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 strides=[8, 16, 32, 64, 128],
                 mi=[[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]],
                 alpha=0.25,
                 gamma=2.,
                 epsilon=1e-4):
        super(FCOSLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.image_w = image_w
        self.image_h = image_h
        self.strides = strides
        self.mi = mi
    def forward(self, cls_heads, reg_heads, center_heads, batch_positions,
                annotations):
        """
        compute cls loss, reg loss and center-ness loss in one batch
        """
        cls_preds, reg_preds, center_preds, batch_targets = self.get_batch_position_annotations(
            cls_heads, reg_heads, center_heads, batch_positions, annotations)
        cls_preds = torch.sigmoid(cls_preds)
        reg_preds = torch.exp(reg_preds)
        center_preds = torch.sigmoid(center_preds)
        # centerness targets are already in (0, 1] and are used as-is (no sigmoid)
        device = annotations.device
        cls_loss, reg_loss, center_ness_loss = [], [], []
        valid_image_num = 0
        for per_image_cls_preds, per_image_reg_preds, per_image_center_preds, per_image_targets in zip(
                cls_preds, reg_preds, center_preds, batch_targets):
            positive_points_num = (
                per_image_targets[per_image_targets[:, 4] > 0]).shape[0]
            if positive_points_num == 0:
                cls_loss.append(torch.tensor(0.).to(device))
                reg_loss.append(torch.tensor(0.).to(device))
                center_ness_loss.append(torch.tensor(0.).to(device))
            else:
                valid_image_num += 1
                one_image_cls_loss = self.compute_one_image_focal_loss(
                    per_image_cls_preds, per_image_targets)
                one_image_reg_loss = self.compute_one_image_giou_loss(
                    per_image_reg_preds, per_image_targets)
                one_image_center_ness_loss = self.compute_one_image_center_ness_loss(
                    per_image_center_preds, per_image_targets)
                cls_loss.append(one_image_cls_loss)
                reg_loss.append(one_image_reg_loss)
                center_ness_loss.append(one_image_center_ness_loss)
        cls_loss = sum(cls_loss) / valid_image_num
        reg_loss = sum(reg_loss) / valid_image_num
        center_ness_loss = sum(center_ness_loss) / valid_image_num
        return cls_loss, reg_loss, center_ness_loss
    def compute_one_image_focal_loss(self, per_image_cls_preds,
                                     per_image_targets):
        """
        compute one image focal loss(cls loss)
        per_image_cls_preds:[points_num,num_classes]
        per_image_targets:[points_num,8]
        """
        per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                          min=self.epsilon,
                                          max=1. - self.epsilon)
        num_classes = per_image_cls_preds.shape[1]
        # generate 80 binary ground truth classes for each point
        loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                      num_classes=num_classes + 1)
        loss_ground_truth = loss_ground_truth[:, 1:]
        loss_ground_truth = loss_ground_truth.float()
        alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
        alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                                   alpha_factor, 1. - alpha_factor)
        pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                         1. - per_image_cls_preds)
        focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)
        bce_loss = -(
            loss_ground_truth * torch.log(per_image_cls_preds) +
            (1. - loss_ground_truth) * torch.log(1. - per_image_cls_preds))
        one_image_focal_loss = focal_weight * bce_loss
        one_image_focal_loss = one_image_focal_loss.sum()
        positive_points_num = per_image_targets[
            per_image_targets[:, 4] > 0].shape[0]
        # following the original paper, divide the focal loss by the number of positive points
        one_image_focal_loss = one_image_focal_loss / positive_points_num
        return one_image_focal_loss
    def compute_one_image_giou_loss(self, per_image_reg_preds,
                                    per_image_targets):
        """
        compute one image giou loss(reg loss)
        per_image_reg_preds:[points_num,4]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute reg loss
        device = per_image_reg_preds.device
        per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]
        if positive_points_num == 0:
            return torch.tensor(0.).to(device)
        center_ness_targets = per_image_targets[:, 5]
        pred_bboxes_xy_min = per_image_targets[:,
                                               6:8] - per_image_reg_preds[:,
                                                                          0:2]
        pred_bboxes_xy_max = per_image_targets[:,
                                               6:8] + per_image_reg_preds[:,
                                                                          2:4]
        gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:,
                                                                         0:2]
        gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:,
                                                                         2:4]
        pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                axis=1)
        gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)
        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
        # compute widths and heights of pred and gt boxes; no "+1" pixel offset,
        # so the areas stay consistent with overlap_area and enclose_area
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2]
        gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2]
        # compute pred_bboxes_area and gt_bboxes_area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]
        # compute union_area
        union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # compute ious between the predicted and gt boxes of the positive points
        ious = overlap_area / union_area
        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=0)
        enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
        enclose_area = torch.clamp(enclose_area, min=1e-4)
        gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
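        # GIoU loss lies in [0, 2]; clamping at max=1.0 would zero the
        # gradient of badly localised positives
        gious_loss = torch.clamp(gious_loss, min=0., max=2.)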
        # use center_ness_targets as the weight of gious loss
        gious_loss = gious_loss * center_ness_targets
        gious_loss = gious_loss.sum() / positive_points_num
        gious_loss = 2. * gious_loss
        return gious_loss
    def compute_one_image_center_ness_loss(self, per_image_center_preds,
                                           per_image_targets):
        """
        compute one image center_ness loss(center ness loss)
        per_image_center_preds:[points_num,1]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute center_ness loss
        device = per_image_center_preds.device
        per_image_center_preds = per_image_center_preds[
            per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]
        if positive_points_num == 0:
            return torch.tensor(0.).to(device)
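        center_ness_targets = per_image_targets[:, 5:6]
        # keep predictions away from 0 and 1 so that torch.log never returns -inf
        per_image_center_preds = torch.clamp(per_image_center_preds,
                                             min=self.epsilon,
                                             max=1. - self.epsilon)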
        center_ness_loss = -(
            center_ness_targets * torch.log(per_image_center_preds) +
            (1. - center_ness_targets) *
            torch.log(1. - per_image_center_preds))
        center_ness_loss = center_ness_loss.sum() / positive_points_num
        return center_ness_loss
    def get_batch_position_annotations(self, cls_heads, reg_heads,
                                       center_heads, batch_positions,
                                       annotations):
        """
        Assign a ground truth target for each position on feature map
        """
        device = annotations.device
        batch_mi = []
        for reg_head, mi in zip(reg_heads, self.mi):
            mi = torch.tensor(mi).to(device)
            B, H, W, _ = reg_head.shape
            per_level_mi = torch.zeros(B, H, W, 2).to(device)
            per_level_mi = per_level_mi + mi
            batch_mi.append(per_level_mi)
        cls_preds,reg_preds,center_preds,all_points_position,all_points_mi=[],[],[],[],[]
        for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
                cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
            cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
            reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
            center_pred = center_pred.view(center_pred.shape[0], -1,
                                           center_pred.shape[-1])
            per_level_position = per_level_position.view(
                per_level_position.shape[0], -1, per_level_position.shape[-1])
            per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                             per_level_mi.shape[-1])
            cls_preds.append(cls_pred)
            reg_preds.append(reg_pred)
            center_preds.append(center_pred)
            all_points_position.append(per_level_position)
            all_points_mi.append(per_level_mi)
        cls_preds = torch.cat(cls_preds, axis=1)
        reg_preds = torch.cat(reg_preds, axis=1)
        center_preds = torch.cat(center_preds, axis=1)
        all_points_position = torch.cat(all_points_position, axis=1)
        all_points_mi = torch.cat(all_points_mi, axis=1)
        batch_targets = []
        for per_image_position, per_image_mi, per_image_annotations in zip(
                all_points_position, all_points_mi, annotations):
            per_image_annotations = per_image_annotations[
                per_image_annotations[:, 4] >= 0]
            points_num = per_image_position.shape[0]
            if per_image_annotations.shape[0] == 0:
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6], device=device)
            else:
                annotaion_num = per_image_annotations.shape[0]
                per_image_gt_bboxes = per_image_annotations[:, 0:4]
                candidates = torch.zeros([points_num, annotaion_num, 4],
                                         device=device)
                candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                per_image_position = per_image_position.unsqueeze(1).repeat(
                    1, annotaion_num, 2)
                # l, t = x - x1, y - y1
                candidates[:, :, 0:2] = (per_image_position[:, :, 0:2] -
                                         candidates[:, :, 0:2])
                # r, b = x2 - x, y2 - y
                candidates[:, :, 2:4] = (candidates[:, :, 2:4] -
                                         per_image_position[:, :, 2:4])
                candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                sample_flag = (candidates_min_value[:, :, 0] >
                               0).int().unsqueeze(-1)
                # zero out the reg targets of points whose center lies outside the gt box
                candidates = candidates * sample_flag
                # zero out the reg targets whose max(l, t, r, b) falls outside this level's mi range
                candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                per_image_mi = per_image_mi.unsqueeze(1).repeat(
                    1, annotaion_num, 1)
                m1_negative_flag = (candidates_max_value[:, :, 0] >
                                    per_image_mi[:, :, 0]).int().unsqueeze(-1)
                candidates = candidates * m1_negative_flag
                m2_negative_flag = (candidates_max_value[:, :, 0] <
                                    per_image_mi[:, :, 1]).int().unsqueeze(-1)
                candidates = candidates * m2_negative_flag
                final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                final_sample_flag = final_sample_flag > 0
                positive_index = (final_sample_flag == True).nonzero().squeeze(
                    dim=-1)
                # if no assign positive sample
                if len(positive_index) == 0:
                    del candidates
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                else:
                    positive_candidates = candidates[positive_index]
                    del candidates
                    sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                    sample_box_gts = sample_box_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                        -1).unsqueeze(0)
                    sample_class_gts = sample_class_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                    if positive_candidates.shape[1] == 1:
                        # the image has a single gt box, so every positive point
                        # has exactly one candidate; assign l,t,r,b,class_index,
                        # center_ness_gt directly.
                        # class_index values 1 to 80 represent the 80 positive classes;
                        # class_index 0 represents the negative class
                        positive_candidates = positive_candidates.squeeze(1)
                        sample_class_gts = sample_class_gts.squeeze(1)
                        per_image_targets[positive_index,
                                          0:4] = positive_candidates
                        per_image_targets[positive_index,
                                          4:5] = sample_class_gts + 1
                        l = per_image_targets[positive_index, 0:1]
                        t = per_image_targets[positive_index, 1:2]
                        r = per_image_targets[positive_index, 2:3]
                        b = per_image_targets[positive_index, 3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
                    else:
                        # a positive point may fall inside several candidate boxes;
                        # choose the candidate with the smallest area as its ground truth
                        gts_w_h = sample_box_gts[:, :,
                                                 2:4] - sample_box_gts[:, :,
                                                                       0:2]
                        gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                        positive_candidates_value = positive_candidates.sum(
                            axis=2)
                        # give every non-candidate box an area of INF so that .min() never selects it
                        INF = 100000000
                        inf_tensor = torch.ones_like(gts_area) * INF
                        gts_area = torch.where(
                            torch.eq(positive_candidates_value, 0.),
                            inf_tensor, gts_area)
                        # get the smallest object candidate index
                        _, min_index = gts_area.min(axis=1)
                        # row indexes 0..P-1, created on the same device as the targets
                        candidate_indexes = torch.arange(
                            positive_candidates.shape[0], device=device)
                        final_candidate_reg_gts = positive_candidates[
                            candidate_indexes, min_index, :]
                        final_candidate_cls_gts = sample_class_gts[
                            candidate_indexes, min_index]
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        per_image_targets[positive_index,
                                          0:4] = final_candidate_reg_gts
                        per_image_targets[positive_index,
                                          4:5] = final_candidate_cls_gts + 1
                        l = per_image_targets[positive_index, 0:1]
                        t = per_image_targets[positive_index, 1:2]
                        r = per_image_targets[positive_index, 2:3]
                        b = per_image_targets[positive_index, 3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
            per_image_targets = per_image_targets.unsqueeze(0)
            batch_targets.append(per_image_targets)
        batch_targets = torch.cat(batch_targets, axis=0)
        batch_targets = torch.cat([batch_targets, all_points_position], axis=2)
        # batch_targets shape:[batch_size, points_num, 8],8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
        return cls_preds, reg_preds, center_preds, batch_targets
if __name__ == '__main__':
    from fcos import FCOS
    net = FCOS(resnet_type="resnet50")
    image_h, image_w = 600, 600
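    cls_heads, reg_heads, center_heads, batch_positions = net(
        torch.randn(3, 3, image_h, image_w))
    # annotations: [batch_size, max_boxes, 5] as x1, y1, x2, y2, class_id;
    # rows filled with -1 are padding (the third image has no boxes)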
    annotations = torch.FloatTensor([[[113, 120, 183, 255, 5],
                                      [13, 45, 175, 210, 2]],
                                     [[11, 18, 223, 225, 1],
                                      [-1, -1, -1, -1, -1]],
                                     [[-1, -1, -1, -1, -1],
                                      [-1, -1, -1, -1, -1]]])
    loss = FCOSLoss(image_w, image_h)
    cls_loss, reg_loss, center_loss = loss(cls_heads, reg_heads, center_heads,
                                           batch_positions, annotations)
    print("cls_loss:", cls_loss, "reg_loss:", reg_loss, "center_ness_loss:",
          center_loss)
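The `get_batch_position_annotations` function above applies several rules: the inside-box test, the per-level max(l, t, r, b) range check, the smallest-area tie-break, and the centerness target. To check these rules by hand, here is a minimal pure-Python sketch for a single point on a single FPN level. It is not the article's vectorized PyTorch code. The function name `assign_point` and its arguments are hypothetical. The sketch follows the same strict inequalities as the tensor code (min(l, t, r, b) > 0 and m_lo < max(l, t, r, b) < m_hi), and it reuses the first image's boxes from the `__main__` test above.

```python
import math

# Scale ranges (m_lo, m_hi) for P3..P7, copied from the article's `mi` list.
INF = 100000000
MI = [(-1, 64), (64, 128), (128, 256), (256, 512), (512, INF)]


def assign_point(x, y, annotations, mi_range):
    """Hypothetical helper: FCOS target for one point on one FPN level.

    annotations: list of (x1, y1, x2, y2, class_id), class_id in [0, 79].
    Returns (l, t, r, b, class_index, centerness); class_index 0 means a
    negative sample, positives use class_id + 1 (1..80).
    """
    lo, hi = mi_range
    candidates = []
    for x1, y1, x2, y2, cls in annotations:
        l, t, r, b = x - x1, y - y1, x2 - x, y2 - y
        # the point must lie strictly inside the box
        if min(l, t, r, b) <= 0:
            continue
        # max(l, t, r, b) must fall inside this level's range
        if not (lo < max(l, t, r, b) < hi):
            continue
        area = (x2 - x1) * (y2 - y1)
        candidates.append((area, (l, t, r, b), cls))
    if not candidates:
        return (0, 0, 0, 0, 0, 0.0)
    # ambiguous point: keep the candidate box with the smallest area
    _, (l, t, r, b), cls = min(candidates, key=lambda c: c[0])
    centerness = math.sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))
    return (l, t, r, b, cls + 1, centerness)


# boxes of the first image in the __main__ test above
boxes = [(113, 120, 183, 255, 5), (13, 45, 175, 210, 2)]
print(assign_point(150, 150, boxes, MI[0]))  # negative on P3: all zeros
print(assign_point(150, 150, boxes, MI[1]))  # first box (class_index 6) on P4
print(assign_point(150, 125, boxes, MI[2]))  # both boxes qualify; smaller one wins
```

This per-point loop is only meant for checking small cases by hand. The tensor code above handles all points and boxes of an image at once, because assigning labels with Python for loops would slow training considerably.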
You may freely share or copy this text, but please keep the URL of this document as the reference URL.

Licensed under CC BY-SA 2.5, CC BY-SA 3.0, and CC BY-SA 4.0.
입력 그림에 표시된 여러 개의 상자에 대해 먼저 FPN의 각 단계의 FPN의feature map에 표시된 모든 점을 판단한다. 만약에 어떤 점이 모든 표시 상자 밖에 있다면 이 점은 마이너스 샘플이 된다.이때 나머지 점 중 일부는 여러 치수 상자 안에 동시에 있을 수 있습니다.그리고 각 점의 각 표시 상자에 대한 l, t, r, b(이 점의 거리 상자는 왼쪽, 위, 오른쪽, 아래의 거리)의 최대 값을 취하고 아래의 값 영역 범위에 따라 최대 값이 어느 범위 안에 떨어지면 이 범위에 대응하는 FPN level의feature map의 대응점에 이 상자를 분배한다.
#         P3、P4、P5、P6、P7     
INF=100000000
mi=[[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]]
위의 단계를 거친 후, 대부분의 점은 단지 하나의 틀에만 분배될 것이다.그러나 두 치수 상자의 크기 차이가 많지 않을 때 두 상자 안에 있는 점이 있습니다.이러한 점에 대해 우리는 중첩 상자의 면적을 계산한 후에 항상 이 점들을 면적이 가장 작은 치수 상자에 분배한다.아래의 실현 코드에서 이 부분의 견본에 대해 나는 행렬 계산 형식으로 라벨 분배를 했다.한 장의 그림에 있는 샘플은 일반적으로 수십 ~ 이삼백 정도에 불과하지만 이 부분의 샘플에 대해 for순환을 사용하여 라벨을 분배하면 훈련 속도가 매우 느려질 수 있으므로 주의해야 한다.분류 라벨에 대해 0을 마이너스 샘플로 하고 1부터 80까지는 80개의 정류로 한다.l, t, r, b와centerness 라벨은 FCOS 논문에서 공식적으로 계산된 것으로 수정되지 않았다.
ground truth 할당 코드는 다음과 같습니다.
    def get_batch_position_annotations(self, cls_heads, reg_heads,
                                       center_heads, batch_positions,
                                       annotations):
        """
        Assign a ground truth target for each position on feature map
        """
        device = annotations.device
        batch_mi = []
        for reg_head, mi in zip(reg_heads, self.mi):
            mi = torch.tensor(mi).to(device)
            B, H, W, _ = reg_head.shape
            per_level_mi = torch.zeros(B, H, W, 2).to(device)
            per_level_mi = per_level_mi + mi
            batch_mi.append(per_level_mi)
        cls_preds,reg_preds,center_preds,all_points_position,all_points_mi=[],[],[],[],[]
        for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
                cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
            cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
            reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
            center_pred = center_pred.view(center_pred.shape[0], -1,
                                           center_pred.shape[-1])
            per_level_position = per_level_position.view(
                per_level_position.shape[0], -1, per_level_position.shape[-1])
            per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                             per_level_mi.shape[-1])
            cls_preds.append(cls_pred)
            reg_preds.append(reg_pred)
            center_preds.append(center_pred)
            all_points_position.append(per_level_position)
            all_points_mi.append(per_level_mi)
        cls_preds = torch.cat(cls_preds, axis=1)
        reg_preds = torch.cat(reg_preds, axis=1)
        center_preds = torch.cat(center_preds, axis=1)
        all_points_position = torch.cat(all_points_position, axis=1)
        all_points_mi = torch.cat(all_points_mi, axis=1)
        batch_targets = []
        for per_image_position, per_image_mi, per_image_annotations in zip(
                all_points_position, all_points_mi, annotations):
            per_image_annotations = per_image_annotations[
                per_image_annotations[:, 4] >= 0]
            points_num = per_image_position.shape[0]
            if per_image_annotations.shape[0] == 0:
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6], device=device)
            else:
                annotaion_num = per_image_annotations.shape[0]
                per_image_gt_bboxes = per_image_annotations[:, 0:4]
                candidates = torch.zeros([points_num, annotaion_num, 4],
                                         device=device)
                candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                per_image_position = per_image_position.unsqueeze(1).repeat(
                    1, annotaion_num, 2)
                candidates[:, :,
                           0:2] = per_image_position[:, :,
                                                     0:2] - candidates[:, :,
                                                                       0:2]
                candidates[:, :,
                           2:4] = candidates[:, :,
                                             2:4] - per_image_position[:, :,
                                                                       2:4]
                candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                sample_flag = (candidates_min_value[:, :, 0] >
                               0).int().unsqueeze(-1)
                # get all negative reg targets which points ctr out of gt box
                candidates = candidates * sample_flag
                # get all negative reg targets which assign ground turth not in range of mi
                candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                per_image_mi = per_image_mi.unsqueeze(1).repeat(
                    1, annotaion_num, 1)
                m1_negative_flag = (candidates_max_value[:, :, 0] >
                                    per_image_mi[:, :, 0]).int().unsqueeze(-1)
                candidates = candidates * m1_negative_flag
                m2_negative_flag = (candidates_max_value[:, :, 0] <
                                    per_image_mi[:, :, 1]).int().unsqueeze(-1)
                candidates = candidates * m2_negative_flag
                final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                final_sample_flag = final_sample_flag > 0
                positive_index = (final_sample_flag == True).nonzero().squeeze(
                    dim=-1)
                # if no assign positive sample
                if len(positive_index) == 0:
                    del candidates
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                else:
                    positive_candidates = candidates[positive_index]
                    del candidates
                    sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                    sample_box_gts = sample_box_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                        -1).unsqueeze(0)
                    sample_class_gts = sample_class_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                    if positive_candidates.shape[1] == 1:
                        # if only one candidate for each positive sample
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        # class_index value from 1 to 80 represent 80 positive classes
                        # class_index value 0 represenet negative class
                        positive_candidates = positive_candidates.squeeze(1)
                        sample_class_gts = sample_class_gts.squeeze(1)
                        per_image_targets[positive_index,
                                          0:4] = positive_candidates
                        per_image_targets[positive_index,
                                          4:5] = sample_class_gts + 1
                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index,
                                    2:3], per_image_targets[positive_index,
                                                            3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
                    else:
                        # if a positive point sample have serveral object candidates,then choose the smallest area object candidate as the ground turth for this positive point sample
                        gts_w_h = sample_box_gts[:, :,
                                                 2:4] - sample_box_gts[:, :,
                                                                       0:2]
                        gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                        positive_candidates_value = positive_candidates.sum(
                            axis=2)
                        # make sure all negative candidates areas==100000000,thus .min() operation wouldn't choose negative candidates
                        INF = 100000000
                        inf_tensor = torch.ones_like(gts_area) * INF
                        gts_area = torch.where(
                            torch.eq(positive_candidates_value, 0.),
                            inf_tensor, gts_area)
                        # get the smallest object candidate index
                        _, min_index = gts_area.min(axis=1)
                        candidate_indexes = (
                            torch.linspace(1, positive_candidates.shape[0],
                                           positive_candidates.shape[0]) -
                            1).long()
                        final_candidate_reg_gts = positive_candidates[
                            candidate_indexes, min_index, :]
                        final_candidate_cls_gts = sample_class_gts[
                            candidate_indexes, min_index]
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        per_image_targets[positive_index,
                                          0:4] = final_candidate_reg_gts
                        per_image_targets[positive_index,
                                          4:5] = final_candidate_cls_gts + 1
                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index,
                                    2:3], per_image_targets[positive_index,
                                                            3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
            per_image_targets = per_image_targets.unsqueeze(0)
            batch_targets.append(per_image_targets)
        batch_targets = torch.cat(batch_targets, axis=0)
        batch_targets = torch.cat([batch_targets, all_points_position], axis=2)
        # batch_targets shape:[batch_size, points_num, 8],8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
        return cls_preds, reg_preds, center_preds, batch_targets
loss 계산  
분류loss는 focal loss를 사용하는데 계산 과정은 RetinaNet과 완전히 같다. 단지 견본이 Anchor에서 Point로 바뀌었을 뿐이다.
분류 loss 코드는 다음과 같습니다.    def compute_one_image_focal_loss(self, per_image_cls_preds,
                                     per_image_targets):
        """
        compute one image focal loss(cls loss)
        per_image_cls_preds:[points_num,num_classes]
        per_image_targets:[points_num,8]
        """
        per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                          min=self.epsilon,
                                          max=1. - self.epsilon)
        num_classes = per_image_cls_preds.shape[1]
        # generate 80 binary ground truth classes for each anchor
        loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                      num_classes=num_classes + 1)
        loss_ground_truth = loss_ground_truth[:, 1:]
        loss_ground_truth = loss_ground_truth.float()
        alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
        alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                                   alpha_factor, 1. - alpha_factor)
        pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                         1. - per_image_cls_preds)
        focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)
        bce_loss = -(
            loss_ground_truth * torch.log(per_image_cls_preds) +
            (1. - loss_ground_truth) * torch.log(1. - per_image_cls_preds))
        one_image_focal_loss = focal_weight * bce_loss
        one_image_focal_loss = one_image_focal_loss.sum()
        positive_points_num = per_image_targets[
            per_image_targets[:, 4] > 0].shape[0]
        # according to the original paper,We divide the focal loss by the number of positive sample anchors
        one_image_focal_loss = one_image_focal_loss / positive_points_num
        return one_image_focal_loss
  
FCOS 논문에서 회귀loss는 IoU loss를 채택했다.여기는 제가 직접 GIOU loss를 사용합니다.회귀loss는 여전히 정본만 계산하기 때문에 예측상자와 진실상자가 교차하지 않는 상황이 존재하지 않는다. 이때 GIoU loss와 IoU loss는 완전히 동일하다.
회귀loss 코드는 다음과 같습니다.    def compute_one_image_giou_loss(self, per_image_reg_preds,
                                    per_image_targets):
        """
        compute one image giou loss(reg loss)
        per_image_reg_preds:[points_num,4]
        per_image_targets:[anchor_num,8]
        """
        # only use positive points sample to compute reg loss
        device = per_image_reg_preds.device
        per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]
        if positive_points_num == 0:
            return torch.tensor(0.).to(device)
        center_ness_targets = per_image_targets[:, 5]
        pred_bboxes_xy_min = per_image_targets[:,
                                               6:8] - per_image_reg_preds[:,
                                                                          0:2]
        pred_bboxes_xy_max = per_image_targets[:,
                                               6:8] + per_image_reg_preds[:,
                                                                          2:4]
        gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:,
                                                                         0:2]
        gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:,
                                                                         2:4]
        pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                axis=1)
        gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)
        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
        # anchors and annotations convert format to [x1,y1,w,h]
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2] + 1
        # compute anchors_area and annotations_area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]
        # compute union_area
        union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # compute ious between one image anchors and one image annotations
        ious = overlap_area / union_area
        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=0)
        enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
        enclose_area = torch.clamp(enclose_area, min=1e-4)
        gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
        gious_loss = torch.clamp(gious_loss, min=-1.0, max=1.0)
        # use center_ness_targets as the weight of gious loss
        gious_loss = gious_loss * center_ness_targets
        gious_loss = gious_loss.sum() / positive_points_num
        gious_loss = 2. * gious_loss
        return gious_loss
  
마지막으로 2를 곱하는 것은 회귀loss와 기타loss의 수량급을 균형 있게 하기 위한 것이다.
centerness는 bce loss를 사용하여 최적화합니다.centernessloss의 최적화 목표가 불안정하기 때문에 실제 훈련할 때loss 초기에 조금 떨어진 후에 장기적으로 더 이상 떨어지지 않는 상황이 나타날 수 있습니다. 이것은 정상적인 것이니 걱정할 필요가 없습니다.centerness loss 코드는 다음과 같습니다.    def compute_one_image_center_ness_loss(self, per_image_center_preds,
                                           per_image_targets):
        """
        compute one image center_ness loss(center ness loss)
        per_image_center_preds:[points_num,4]
        per_image_targets:[anchor_num,8]
        """
        # only use positive points sample to compute center_ness loss
        device = per_image_center_preds.device
        per_image_center_preds = per_image_center_preds[
            per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]
        if positive_points_num == 0:
            return torch.tensor(0.).to(device)
        center_ness_targets = per_image_targets[:, 5:6]
        center_ness_loss = -(
            center_ness_targets * torch.log(per_image_center_preds) +
            (1. - center_ness_targets) *
            torch.log(1. - per_image_center_preds))
        center_ness_loss = center_ness_loss.sum() / positive_points_num
        return center_ness_loss
  
전체 loss 코드  import torch
import torch.nn as nn
import torch.nn.functional as F
INF = 100000000
class FCOSLoss(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 strides=[8, 16, 32, 64, 128],
                 mi=[[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]],
                 alpha=0.25,
                 gamma=2.,
                 epsilon=1e-4):
        super(FCOSLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.image_w = image_w
        self.image_h = image_h
        self.strides = strides
        self.mi = mi
    def forward(self, cls_heads, reg_heads, center_heads, batch_positions,
                annotations):
        """
        compute cls loss, reg loss and center-ness loss in one batch
        """
        cls_preds, reg_preds, center_preds, batch_targets = self.get_batch_position_annotations(
            cls_heads, reg_heads, center_heads, batch_positions, annotations)
        cls_preds = torch.sigmoid(cls_preds)
        reg_preds = torch.exp(reg_preds)
        center_preds = torch.sigmoid(center_preds)
        batch_targets[:, :, 5:6] = torch.sigmoid(batch_targets[:, :, 5:6])
        device = annotations.device
        cls_loss, reg_loss, center_ness_loss = [], [], []
        valid_image_num = 0
        for per_image_cls_preds, per_image_reg_preds, per_image_center_preds, per_image_targets in zip(
                cls_preds, reg_preds, center_preds, batch_targets):
            positive_points_num = (
                per_image_targets[per_image_targets[:, 4] > 0]).shape[0]
            if positive_points_num == 0:
                cls_loss.append(torch.tensor(0.).to(device))
                reg_loss.append(torch.tensor(0.).to(device))
                center_ness_loss.append(torch.tensor(0.).to(device))
            else:
                valid_image_num += 1
                one_image_cls_loss = self.compute_one_image_focal_loss(
                    per_image_cls_preds, per_image_targets)
                one_image_reg_loss = self.compute_one_image_giou_loss(
                    per_image_reg_preds, per_image_targets)
                one_image_center_ness_loss = self.compute_one_image_center_ness_loss(
                    per_image_center_preds, per_image_targets)
                cls_loss.append(one_image_cls_loss)
                reg_loss.append(one_image_reg_loss)
                center_ness_loss.append(one_image_center_ness_loss)
        cls_loss = sum(cls_loss) / valid_image_num
        reg_loss = sum(reg_loss) / valid_image_num
        center_ness_loss = sum(center_ness_loss) / valid_image_num
        return cls_loss, reg_loss, center_ness_loss
    def compute_one_image_focal_loss(self, per_image_cls_preds,
                                     per_image_targets):
        """
        compute one image focal loss(cls loss)
        per_image_cls_preds:[points_num,num_classes]
        per_image_targets:[points_num,8]
        """
        per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                          min=self.epsilon,
                                          max=1. - self.epsilon)
        num_classes = per_image_cls_preds.shape[1]
        # generate 80 binary ground truth classes for each anchor
        loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                      num_classes=num_classes + 1)
        loss_ground_truth = loss_ground_truth[:, 1:]
        loss_ground_truth = loss_ground_truth.float()
        alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
        alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                                   alpha_factor, 1. - alpha_factor)
        pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                         1. - per_image_cls_preds)
        focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)
        bce_loss = -(
            loss_ground_truth * torch.log(per_image_cls_preds) +
            (1. - loss_ground_truth) * torch.log(1. - per_image_cls_preds))
        one_image_focal_loss = focal_weight * bce_loss
        one_image_focal_loss = one_image_focal_loss.sum()
        positive_points_num = per_image_targets[
            per_image_targets[:, 4] > 0].shape[0]
        # according to the original paper,We divide the focal loss by the number of positive sample anchors
        one_image_focal_loss = one_image_focal_loss / positive_points_num
        return one_image_focal_loss
    def compute_one_image_giou_loss(self, per_image_reg_preds,
                                    per_image_targets):
        """
        compute one image giou loss(reg loss)
        per_image_reg_preds:[points_num,4]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute reg loss
        device = per_image_reg_preds.device
        per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]
        if positive_points_num == 0:
            return torch.tensor(0.).to(device)
        center_ness_targets = per_image_targets[:, 5]
        pred_bboxes_xy_min = per_image_targets[:,
                                               6:8] - per_image_reg_preds[:,
                                                                          0:2]
        pred_bboxes_xy_max = per_image_targets[:,
                                               6:8] + per_image_reg_preds[:,
                                                                          2:4]
        gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:,
                                                                         0:2]
        gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:,
                                                                         2:4]
        pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                axis=1)
        gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)
        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
        # widths and heights of predicted and ground-truth boxes (+1 pixel convention)
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2] + 1
        # areas of predicted and ground-truth boxes
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]
        # union area
        union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # IoU between each predicted box and its own ground-truth box
        ious = overlap_area / union_area
        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=0)
        enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
        enclose_area = torch.clamp(enclose_area, min=1e-4)
        gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
        # 1 - GIoU lies in [0, 2]
        gious_loss = torch.clamp(gious_loss, min=0., max=2.)
        # use center_ness_targets as the weight of gious loss
        gious_loss = gious_loss * center_ness_targets
        gious_loss = gious_loss.sum() / positive_points_num
        gious_loss = 2. * gious_loss
        return gious_loss
    def compute_one_image_center_ness_loss(self, per_image_center_preds,
                                           per_image_targets):
        """
        compute one image center_ness loss(center ness loss)
        per_image_center_preds:[points_num,1]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute center_ness loss
        device = per_image_center_preds.device
        per_image_center_preds = per_image_center_preds[
            per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]
        if positive_points_num == 0:
            return torch.tensor(0.).to(device)
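        center_ness_targets = per_image_targets[:, 5:6]
        # clamp predictions like the focal loss so that log() stays finite
        per_image_center_preds = torch.clamp(per_image_center_preds,
                                             min=self.epsilon,
                                             max=1. - self.epsilon)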
        center_ness_loss = -(
            center_ness_targets * torch.log(per_image_center_preds) +
            (1. - center_ness_targets) *
            torch.log(1. - per_image_center_preds))
        center_ness_loss = center_ness_loss.sum() / positive_points_num
        return center_ness_loss
    def get_batch_position_annotations(self, cls_heads, reg_heads,
                                       center_heads, batch_positions,
                                       annotations):
        """
        Assign a ground truth target for each position on feature map
        """
        device = annotations.device
        batch_mi = []
        for reg_head, mi in zip(reg_heads, self.mi):
            mi = torch.tensor(mi).to(device)
            B, H, W, _ = reg_head.shape
            per_level_mi = torch.zeros(B, H, W, 2).to(device)
            per_level_mi = per_level_mi + mi
            batch_mi.append(per_level_mi)
        cls_preds, reg_preds, center_preds = [], [], []
        all_points_position, all_points_mi = [], []
        for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
                cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
            cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
            reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
            center_pred = center_pred.view(center_pred.shape[0], -1,
                                           center_pred.shape[-1])
            per_level_position = per_level_position.view(
                per_level_position.shape[0], -1, per_level_position.shape[-1])
            per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                             per_level_mi.shape[-1])
            cls_preds.append(cls_pred)
            reg_preds.append(reg_pred)
            center_preds.append(center_pred)
            all_points_position.append(per_level_position)
            all_points_mi.append(per_level_mi)
        cls_preds = torch.cat(cls_preds, axis=1)
        reg_preds = torch.cat(reg_preds, axis=1)
        center_preds = torch.cat(center_preds, axis=1)
        all_points_position = torch.cat(all_points_position, axis=1)
        all_points_mi = torch.cat(all_points_mi, axis=1)
        batch_targets = []
        for per_image_position, per_image_mi, per_image_annotations in zip(
                all_points_position, all_points_mi, annotations):
            per_image_annotations = per_image_annotations[
                per_image_annotations[:, 4] >= 0]
            points_num = per_image_position.shape[0]
            if per_image_annotations.shape[0] == 0:
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6], device=device)
            else:
                annotation_num = per_image_annotations.shape[0]
                per_image_gt_bboxes = per_image_annotations[:, 0:4]
                candidates = torch.zeros([points_num, annotation_num, 4],
                                         device=device)
                candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                # (x, y) -> (x, y, x, y) for every annotation
                per_image_position = per_image_position.unsqueeze(1).repeat(
                    1, annotation_num, 2)
                # l = x - x1, t = y - y1
                candidates[:, :, 0:2] = (per_image_position[:, :, 0:2] -
                                         candidates[:, :, 0:2])
                # r = x2 - x, b = y2 - y
                candidates[:, :, 2:4] = (candidates[:, :, 2:4] -
                                         per_image_position[:, :, 2:4])
                candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                sample_flag = (candidates_min_value[:, :, 0] >
                               0).int().unsqueeze(-1)
                # zero out the reg targets of points whose centers fall outside the gt box
                candidates = candidates * sample_flag
                # zero out the reg targets whose max(l,t,r,b) is outside this level's mi range
                candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                per_image_mi = per_image_mi.unsqueeze(1).repeat(
                    1, annotation_num, 1)
                m1_negative_flag = (candidates_max_value[:, :, 0] >
                                    per_image_mi[:, :, 0]).int().unsqueeze(-1)
                candidates = candidates * m1_negative_flag
                m2_negative_flag = (candidates_max_value[:, :, 0] <
                                    per_image_mi[:, :, 1]).int().unsqueeze(-1)
                candidates = candidates * m2_negative_flag
                final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                final_sample_flag = final_sample_flag > 0
                positive_index = torch.nonzero(final_sample_flag).squeeze(dim=-1)
                # if no point is assigned as a positive sample
                if len(positive_index) == 0:
                    del candidates
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                else:
                    positive_candidates = candidates[positive_index]
                    del candidates
                    sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                    sample_box_gts = sample_box_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                        -1).unsqueeze(0)
                    sample_class_gts = sample_class_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                    if positive_candidates.shape[1] == 1:
                        # if only one candidate for each positive sample
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        # class_index values 1 to 80 represent the 80 positive classes
                        # class_index value 0 represents the negative class
                        positive_candidates = positive_candidates.squeeze(1)
                        sample_class_gts = sample_class_gts.squeeze(1)
                        per_image_targets[positive_index,
                                          0:4] = positive_candidates
                        per_image_targets[positive_index,
                                          4:5] = sample_class_gts + 1
                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index,
                                    2:3], per_image_targets[positive_index,
                                                            3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
                    else:
                        # if a positive point has several object candidates, choose the candidate with the smallest area as its ground truth
                        gts_w_h = sample_box_gts[:, :,
                                                 2:4] - sample_box_gts[:, :,
                                                                       0:2]
                        gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                        positive_candidates_value = positive_candidates.sum(
                            axis=2)
                        # set the areas of negative candidates to INF so that .min() never selects them
                        INF = 100000000
                        inf_tensor = torch.ones_like(gts_area) * INF
                        gts_area = torch.where(
                            torch.eq(positive_candidates_value, 0.),
                            inf_tensor, gts_area)
                        # get the smallest object candidate index
                        _, min_index = gts_area.min(axis=1)
                        candidate_indexes = torch.arange(
                            positive_candidates.shape[0], device=device)
                        final_candidate_reg_gts = positive_candidates[
                            candidate_indexes, min_index, :]
                        final_candidate_cls_gts = sample_class_gts[
                            candidate_indexes, min_index]
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        per_image_targets[positive_index,
                                          0:4] = final_candidate_reg_gts
                        per_image_targets[positive_index,
                                          4:5] = final_candidate_cls_gts + 1
                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index,
                                    2:3], per_image_targets[positive_index,
                                                            3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
            per_image_targets = per_image_targets.unsqueeze(0)
            batch_targets.append(per_image_targets)
        batch_targets = torch.cat(batch_targets, axis=0)
        batch_targets = torch.cat([batch_targets, all_points_position], axis=2)
        # batch_targets shape:[batch_size, points_num, 8],8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
        return cls_preds, reg_preds, center_preds, batch_targets
if __name__ == '__main__':
    from fcos import FCOS
    net = FCOS(resnet_type="resnet50")
    image_h, image_w = 600, 600
    cls_heads, reg_heads, center_heads, batch_positions = net(
        torch.randn(3, 3, image_h, image_w))
    annotations = torch.FloatTensor([[[113, 120, 183, 255, 5],
                                      [13, 45, 175, 210, 2]],
                                     [[11, 18, 223, 225, 1],
                                      [-1, -1, -1, -1, -1]],
                                     [[-1, -1, -1, -1, -1],
                                      [-1, -1, -1, -1, -1]]])
    loss = FCOSLoss(image_w, image_h)
    cls_loss, reg_loss, center_loss = loss(cls_heads, reg_heads, center_heads,
                                           batch_positions, annotations)
    print("cls_loss:", cls_loss, "reg_loss:", reg_loss, "center_loss:", center_loss)
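The vectorized assignment in `get_batch_position_annotations` is easier to follow when it is written out for one point. What follows is a minimal plain-Python sketch of the same rule, with a few assumptions. Boxes are given as `(x1, y1, x2, y2, class_id)` with `class_id` starting at 0. `level` indexes P3..P7 in the `mi` table. The helper name `assign_point` is hypothetical and is not part of the original code. A point is positive for a box only if it lies strictly inside that box and `max(l, t, r, b)` falls inside the level's range. If several boxes qualify, the box with the smallest area wins.

```python
# Hypothetical single-point version of the vectorized assignment (sketch only).
INF = 100000000
# scale range (m_{i-1}, m_i) for P3, P4, P5, P6, P7
MI = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]]


def assign_point(x, y, level, boxes):
    """Return (l, t, r, b, class_index) for point (x, y) on FPN level `level`.

    boxes: list of (x1, y1, x2, y2, class_id) with class_id starting at 0.
    class_index 0 marks a negative point; positives use class_id + 1.
    """
    low, high = MI[level]
    best = None  # (area, l, t, r, b, class_index)
    for x1, y1, x2, y2, class_id in boxes:
        l, t, r, b = x - x1, y - y1, x2 - x, y2 - y
        # the point must lie strictly inside the box
        if min(l, t, r, b) <= 0:
            continue
        # the largest distance must fall inside this level's scale range
        if not (low < max(l, t, r, b) < high):
            continue
        area = (x2 - x1) * (y2 - y1)
        # ambiguous point: keep the box with the smallest area
        if best is None or area < best[0]:
            best = (area, l, t, r, b, class_id + 1)
    if best is None:
        return (0, 0, 0, 0, 0)
    return best[1:]
```

Take the two boxes from the first image of the `__main__` test. The point (150, 150) is assigned to the first box on P4 (`level=1`) and to the second box on P5 (`level=2`). This is how the `mi` ranges split objects of different sizes across FPN levels.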
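The center-ness target in column 5 of `per_image_targets` is `sqrt(min(l, r)/max(l, r) * min(t, b)/max(t, b))`. It already lies in [0, 1]: it equals 1 at the box center and approaches 0 near an edge. The center-ness head is trained with binary cross-entropy against this target. The plain-Python sketch below reproduces both pieces for scalars. The helper names `centerness` and `bce` are hypothetical. Clamping the prediction to `[eps, 1 - eps]` is an assumption added to keep the logarithms finite.

```python
import math


def centerness(l, t, r, b):
    """FCOS center-ness target for a positive point with distances l, t, r, b."""
    return math.sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))


def bce(target, pred, eps=1e-4):
    """Binary cross-entropy between a soft target in [0, 1] and a probability."""
    # clamp the prediction to [eps, 1 - eps] so both logs stay finite
    pred = min(max(pred, eps), 1. - eps)
    return -(target * math.log(pred) + (1. - target) * math.log(1. - pred))
```

Consider the point (150, 150) and the box (113, 120, 183, 255). The distances (l, t, r, b) = (37, 30, 33, 105) give a target of about 0.50. The point is close to the horizontal center but near the top of the box.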
You may freely share or copy this text, but please keep this document's URL as the reference URL.
Licensed under CC BY-SA 2.5, CC BY-SA 3.0, and CC BY-SA 4.0.
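FCOS decodes a box from a point (x, y) and its four distances as (x - l, y - t, x + r, y + b). `compute_one_image_giou_loss` does the same with columns 6:8 of the targets. The sketch below computes 1 - GIoU for a single box pair. It is written in plain Python, and the helper names `decode_ltrb` and `giou_loss` are hypothetical. For simplicity it uses continuous areas without the +1 pixel offset of the loss code. GIoU lies in [-1, 1], so the loss lies in [0, 2]. Unlike plain IoU, it still provides a signal when the boxes do not overlap.

```python
def decode_ltrb(x, y, l, t, r, b):
    """Turn a point and its (l, t, r, b) distances into (x1, y1, x2, y2)."""
    return (x - l, y - t, x + r, y + b)


def giou_loss(pred, gt):
    """1 - GIoU for two boxes in (x1, y1, x2, y2) format (continuous areas)."""
    # intersection rectangle, clamped at zero when the boxes do not overlap
    iw = max(min(pred[2], gt[2]) - max(pred[0], gt[0]), 0.)
    ih = max(min(pred[3], gt[3]) - max(pred[1], gt[1]), 0.)
    inter = iw * ih
    area_pred = (pred[2] - pred[0]) * (pred[3] - pred[1])
    area_gt = (gt[2] - gt[0]) * (gt[3] - gt[1])
    union = area_pred + area_gt - inter
    iou = inter / union
    # area of the smallest box enclosing both boxes
    enclose = ((max(pred[2], gt[2]) - min(pred[0], gt[0])) *
               (max(pred[3], gt[3]) - min(pred[1], gt[1])))
    giou = iou - (enclose - union) / enclose
    return 1. - giou
```

The identical pair (0, 0, 10, 10) and (0, 0, 10, 10) gives 0. The half-overlapping pair (0, 0, 10, 10) and (5, 0, 15, 10) gives 2/3. The disjoint pair (0, 0, 1, 1) and (9, 9, 10, 10) gives 1.98, so values between 1 and 2 are legitimate for this loss.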
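`compute_one_image_focal_loss` applies the sigmoid focal loss to every (point, class) pair and divides the sum by the number of positive points. The scalar sketch below mirrors that computation in plain Python, using the defaults `alpha=0.25` and `gamma=2.`. The helper names `focal_term` and `image_focal_loss` are hypothetical. The `max(num_pos, 1)` guard is an added assumption; the `forward` method in this post handles this case by skipping images that have no positive points.

```python
import math


def focal_term(p, is_positive, alpha=0.25, gamma=2., eps=1e-4):
    """Sigmoid focal loss for one class probability p of one point."""
    # clamp like the loss code so log() stays finite
    p = min(max(p, eps), 1. - eps)
    pt = p if is_positive else 1. - p
    alpha_t = alpha if is_positive else 1. - alpha
    # -log(pt) is the binary cross-entropy for this class
    return -alpha_t * (1. - pt) ** gamma * math.log(pt)


def image_focal_loss(scores, labels, num_classes):
    """scores: per-point lists of num_classes probabilities.
    labels: per-point class_index, 0 = negative, 1..num_classes = positive.
    """
    total = 0.
    for point_scores, label in zip(scores, labels):
        for c in range(num_classes):
            total += focal_term(point_scores[c], label == c + 1)
    num_pos = sum(1 for label in labels if label > 0)
    # guard against division by zero (assumption, see lead-in)
    return total / max(num_pos, 1)
```

An easy negative scored 0.01 contributes roughly 5e-7 times as much as a hard negative scored 0.9. This down-weighting keeps the many background points from dominating the summed loss.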
import torch
import torch.nn as nn
import torch.nn.functional as F
INF = 100000000
class FCOSLoss(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 strides=[8, 16, 32, 64, 128],
                 mi=[[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]],
                 alpha=0.25,
                 gamma=2.,
                 epsilon=1e-4):
        super(FCOSLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.image_w = image_w
        self.image_h = image_h
        self.strides = strides
        self.mi = mi
    def forward(self, cls_heads, reg_heads, center_heads, batch_positions,
                annotations):
        """
        compute cls loss, reg loss and center-ness loss in one batch
        """
        cls_preds, reg_preds, center_preds, batch_targets = self.get_batch_position_annotations(
            cls_heads, reg_heads, center_heads, batch_positions, annotations)
        cls_preds = torch.sigmoid(cls_preds)
        reg_preds = torch.exp(reg_preds)
        center_preds = torch.sigmoid(center_preds)
        # center-ness targets (column 5) are already in [0, 1] and are used directly as BCE targets
        device = annotations.device
        cls_loss, reg_loss, center_ness_loss = [], [], []
        valid_image_num = 0
        for per_image_cls_preds, per_image_reg_preds, per_image_center_preds, per_image_targets in zip(
                cls_preds, reg_preds, center_preds, batch_targets):
            positive_points_num = (
                per_image_targets[per_image_targets[:, 4] > 0]).shape[0]
            if positive_points_num == 0:
                cls_loss.append(torch.tensor(0.).to(device))
                reg_loss.append(torch.tensor(0.).to(device))
                center_ness_loss.append(torch.tensor(0.).to(device))
            else:
                valid_image_num += 1
                one_image_cls_loss = self.compute_one_image_focal_loss(
                    per_image_cls_preds, per_image_targets)
                one_image_reg_loss = self.compute_one_image_giou_loss(
                    per_image_reg_preds, per_image_targets)
                one_image_center_ness_loss = self.compute_one_image_center_ness_loss(
                    per_image_center_preds, per_image_targets)
                cls_loss.append(one_image_cls_loss)
                reg_loss.append(one_image_reg_loss)
                center_ness_loss.append(one_image_center_ness_loss)
        cls_loss = sum(cls_loss) / valid_image_num
        reg_loss = sum(reg_loss) / valid_image_num
        center_ness_loss = sum(center_ness_loss) / valid_image_num
        return cls_loss, reg_loss, center_ness_loss
    def compute_one_image_focal_loss(self, per_image_cls_preds,
                                     per_image_targets):
        """
        compute one image focal loss(cls loss)
        per_image_cls_preds:[points_num,num_classes]
        per_image_targets:[points_num,8]
        """
        per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                          min=self.epsilon,
                                          max=1. - self.epsilon)
        num_classes = per_image_cls_preds.shape[1]
        # one-hot encode class_index (0 = background), then drop column 0 to get num_classes binary targets per point
        loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                      num_classes=num_classes + 1)
        loss_ground_truth = loss_ground_truth[:, 1:]
        loss_ground_truth = loss_ground_truth.float()
        alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
        alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                                   alpha_factor, 1. - alpha_factor)
        pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                         1. - per_image_cls_preds)
        focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)
        bce_loss = -(
            loss_ground_truth * torch.log(per_image_cls_preds) +
            (1. - loss_ground_truth) * torch.log(1. - per_image_cls_preds))
        one_image_focal_loss = focal_weight * bce_loss
        one_image_focal_loss = one_image_focal_loss.sum()
        positive_points_num = per_image_targets[
            per_image_targets[:, 4] > 0].shape[0]
        # following the original paper, normalize the focal loss by the number of positive points
        one_image_focal_loss = one_image_focal_loss / positive_points_num
        return one_image_focal_loss
    def compute_one_image_giou_loss(self, per_image_reg_preds,
                                    per_image_targets):
        """
        compute one image giou loss(reg loss)
        per_image_reg_preds:[points_num,4]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute reg loss
        device = per_image_reg_preds.device
        per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]
        if positive_points_num == 0:
            return torch.tensor(0.).to(device)
        center_ness_targets = per_image_targets[:, 5]
        pred_bboxes_xy_min = per_image_targets[:,
                                               6:8] - per_image_reg_preds[:,
                                                                          0:2]
        pred_bboxes_xy_max = per_image_targets[:,
                                               6:8] + per_image_reg_preds[:,
                                                                          2:4]
        gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:,
                                                                         0:2]
        gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:,
                                                                         2:4]
        pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                axis=1)
        gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)
        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
        # widths and heights of predicted and ground-truth boxes (+1 pixel convention)
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2] + 1
        # areas of predicted and ground-truth boxes
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]
        # union area
        union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # IoU between each predicted box and its own ground-truth box
        ious = overlap_area / union_area
        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:,
                                                                         0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:,
                                                                          2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=0)
        enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
        enclose_area = torch.clamp(enclose_area, min=1e-4)
        gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
        # GIoU lies in [-1, 1], so the GIoU loss (1 - GIoU) lies in [0, 2]
        gious_loss = torch.clamp(gious_loss, min=0.0, max=2.0)
        # use center_ness_targets as the weight of gious loss
        gious_loss = gious_loss * center_ness_targets
        gious_loss = gious_loss.sum() / positive_points_num
        # weight the GIoU regression loss by a factor of 2
        gious_loss = 2. * gious_loss
        return gious_loss
    def compute_one_image_center_ness_loss(self, per_image_center_preds,
                                           per_image_targets):
        """
        compute one image center_ness loss (binary cross-entropy)
        per_image_center_preds:[points_num,1], probabilities in (0, 1)
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute center_ness loss
        device = per_image_center_preds.device
        per_image_center_preds = per_image_center_preds[
            per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]
        if positive_points_num == 0:
            return torch.tensor(0.).to(device)
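        center_ness_targets = per_image_targets[:, 5:6]
        # clamp the probabilities so that log() never receives exactly 0 or 1
        per_image_center_preds = torch.clamp(per_image_center_preds,
                                             min=1e-4,
                                             max=1. - 1e-4)
        center_ness_loss = -(
            center_ness_targets * torch.log(per_image_center_preds) +
            (1. - center_ness_targets) *
            torch.log(1. - per_image_center_preds))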
        center_ness_loss = center_ness_loss.sum() / positive_points_num
        return center_ness_loss
    def get_batch_position_annotations(self, cls_heads, reg_heads,
                                       center_heads, batch_positions,
                                       annotations):
        """
        Assign a ground truth target for each position on feature map
        """
        device = annotations.device
        batch_mi = []
        for reg_head, mi in zip(reg_heads, self.mi):
            mi = torch.tensor(mi).to(device)
            B, H, W, _ = reg_head.shape
            per_level_mi = torch.zeros(B, H, W, 2).to(device)
            per_level_mi = per_level_mi + mi
            batch_mi.append(per_level_mi)
        cls_preds, reg_preds, center_preds = [], [], []
        all_points_position, all_points_mi = [], []
        for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
                cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
            # reshape (unlike view) also works when the head outputs were
            # permuted and are therefore non-contiguous
            cls_pred = cls_pred.reshape(cls_pred.shape[0], -1,
                                        cls_pred.shape[-1])
            reg_pred = reg_pred.reshape(reg_pred.shape[0], -1,
                                        reg_pred.shape[-1])
            center_pred = center_pred.reshape(center_pred.shape[0], -1,
                                              center_pred.shape[-1])
            per_level_position = per_level_position.reshape(
                per_level_position.shape[0], -1, per_level_position.shape[-1])
            per_level_mi = per_level_mi.reshape(per_level_mi.shape[0], -1,
                                                per_level_mi.shape[-1])
            cls_preds.append(cls_pred)
            reg_preds.append(reg_pred)
            center_preds.append(center_pred)
            all_points_position.append(per_level_position)
            all_points_mi.append(per_level_mi)
        cls_preds = torch.cat(cls_preds, axis=1)
        reg_preds = torch.cat(reg_preds, axis=1)
        center_preds = torch.cat(center_preds, axis=1)
        all_points_position = torch.cat(all_points_position, axis=1)
        all_points_mi = torch.cat(all_points_mi, axis=1)
        batch_targets = []
        for per_image_position, per_image_mi, per_image_annotations in zip(
                all_points_position, all_points_mi, annotations):
            per_image_annotations = per_image_annotations[
                per_image_annotations[:, 4] >= 0]
            points_num = per_image_position.shape[0]
            if per_image_annotations.shape[0] == 0:
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6], device=device)
            else:
                annotaion_num = per_image_annotations.shape[0]
                per_image_gt_bboxes = per_image_annotations[:, 0:4]
                candidates = torch.zeros([points_num, annotaion_num, 4],
                                         device=device)
                candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                per_image_position = per_image_position.unsqueeze(1).repeat(
                    1, annotaion_num, 2)
                candidates[:, :,
                           0:2] = per_image_position[:, :,
                                                     0:2] - candidates[:, :,
                                                                       0:2]
                candidates[:, :,
                           2:4] = candidates[:, :,
                                             2:4] - per_image_position[:, :,
                                                                       2:4]
                candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                sample_flag = (candidates_min_value[:, :, 0] >
                               0).int().unsqueeze(-1)
                # zero out the targets of points whose centers lie outside the gt box
                candidates = candidates * sample_flag
                # zero out the targets whose max(l,t,r,b) falls outside this level's mi range
                candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                per_image_mi = per_image_mi.unsqueeze(1).repeat(
                    1, annotaion_num, 1)
                m1_negative_flag = (candidates_max_value[:, :, 0] >
                                    per_image_mi[:, :, 0]).int().unsqueeze(-1)
                candidates = candidates * m1_negative_flag
                m2_negative_flag = (candidates_max_value[:, :, 0] <
                                    per_image_mi[:, :, 1]).int().unsqueeze(-1)
                candidates = candidates * m2_negative_flag
                final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                final_sample_flag = final_sample_flag > 0
                positive_index = torch.nonzero(final_sample_flag).squeeze(
                    dim=-1)
                # no positive sample was assigned in this image
                if len(positive_index) == 0:
                    del candidates
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                else:
                    positive_candidates = candidates[positive_index]
                    del candidates
                    sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                    sample_box_gts = sample_box_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                        -1).unsqueeze(0)
                    sample_class_gts = sample_class_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                    if positive_candidates.shape[1] == 1:
                        # this image has only one gt box, so every positive
                        # point has exactly one candidate
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        # class_index values 1 to 80 represent the 80 positive classes
                        # class_index value 0 represents the negative class
                        positive_candidates = positive_candidates.squeeze(1)
                        sample_class_gts = sample_class_gts.squeeze(1)
                        per_image_targets[positive_index,
                                          0:4] = positive_candidates
                        per_image_targets[positive_index,
                                          4:5] = sample_class_gts + 1
                        l, t, r, b = per_image_targets[positive_index,
                                                       0:4].split(1, dim=1)
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
                    else:
                        # a positive point may lie inside several gt boxes; choose
                        # the candidate with the smallest area as its ground truth
                        gts_w_h = sample_box_gts[:, :,
                                                 2:4] - sample_box_gts[:, :,
                                                                       0:2]
                        gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                        positive_candidates_value = positive_candidates.sum(
                            axis=2)
                        # set the area of every invalid candidate to INF so that
                        # .min() never selects it
                        INF = 100000000
                        inf_tensor = torch.ones_like(gts_area) * INF
                        gts_area = torch.where(
                            torch.eq(positive_candidates_value, 0.),
                            inf_tensor, gts_area)
                        # get the smallest object candidate index
                        _, min_index = gts_area.min(axis=1)
                        # row index 0..P-1 of each positive point, created on
                        # the same device as min_index
                        candidate_indexes = torch.arange(
                            positive_candidates.shape[0], device=device)
                        final_candidate_reg_gts = positive_candidates[
                            candidate_indexes, min_index, :]
                        final_candidate_cls_gts = sample_class_gts[
                            candidate_indexes, min_index]
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        per_image_targets[positive_index,
                                          0:4] = final_candidate_reg_gts
                        per_image_targets[positive_index,
                                          4:5] = final_candidate_cls_gts + 1
                        l, t, r, b = per_image_targets[positive_index,
                                                       0:4].split(1, dim=1)
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
            per_image_targets = per_image_targets.unsqueeze(0)
            batch_targets.append(per_image_targets)
        batch_targets = torch.cat(batch_targets, axis=0)
        batch_targets = torch.cat([batch_targets, all_points_position], axis=2)
        # batch_targets shape:[batch_size, points_num, 8],8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
        return cls_preds, reg_preds, center_preds, batch_targets
if __name__ == '__main__':
    from fcos import FCOS
    net = FCOS(resnet_type="resnet50")
    image_h, image_w = 600, 600
    # torch.autograd.Variable is deprecated; plain tensors work directly
    cls_heads, reg_heads, center_heads, batch_positions = net(
        torch.randn(3, 3, image_h, image_w))
    annotations = torch.FloatTensor([[[113, 120, 183, 255, 5],
                                      [13, 45, 175, 210, 2]],
                                     [[11, 18, 223, 225, 1],
                                      [-1, -1, -1, -1, -1]],
                                     [[-1, -1, -1, -1, -1],
                                      [-1, -1, -1, -1, -1]]])
    loss = FCOSLoss(image_w, image_h)
    cls_loss, reg_loss, center_loss = loss(cls_heads, reg_heads, center_heads,
                                           batch_positions, annotations)
    print("cls_loss:", cls_loss, "reg_loss:", reg_loss, "center_loss:",
          center_loss)
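The GIoU part of the loss listing can be checked in isolation. The following is a minimal standalone sketch, not the author's code: `giou_loss` and `ltrb_to_xyxy` are hypothetical helper names, boxes are assumed to be `[x1, y1, x2, y2]` float tensors matched one-to-one, and all areas use the continuous convention without `+1`. `ltrb_to_xyxy` mirrors how the listing turns a point `(x, y)` plus `l, t, r, b` distances back into a box.

```python
import torch


def ltrb_to_xyxy(points: torch.Tensor, ltrb: torch.Tensor) -> torch.Tensor:
    """Decode [N, 2] points and [N, 4] (l, t, r, b) distances into [N, 4] boxes."""
    return torch.cat([points - ltrb[:, :2], points + ltrb[:, 2:]], dim=1)


def giou_loss(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor,
              eps: float = 1e-7) -> torch.Tensor:
    """Per-pair GIoU loss (1 - GIoU) for boxes in [x1, y1, x2, y2] format."""
    # Intersection: max of top-left corners, min of bottom-right corners.
    inter_tl = torch.max(pred_boxes[:, :2], gt_boxes[:, :2])
    inter_br = torch.min(pred_boxes[:, 2:], gt_boxes[:, 2:])
    inter_wh = (inter_br - inter_tl).clamp(min=0)
    inter = inter_wh[:, 0] * inter_wh[:, 1]

    # Areas use the same convention (no +1) as the intersection.
    pred_wh = (pred_boxes[:, 2:] - pred_boxes[:, :2]).clamp(min=0)
    gt_wh = (gt_boxes[:, 2:] - gt_boxes[:, :2]).clamp(min=0)
    union = pred_wh[:, 0] * pred_wh[:, 1] + gt_wh[:, 0] * gt_wh[:, 1] - inter
    union = union.clamp(min=eps)
    iou = inter / union

    # Smallest box C enclosing both boxes.
    enc_tl = torch.min(pred_boxes[:, :2], gt_boxes[:, :2])
    enc_br = torch.max(pred_boxes[:, 2:], gt_boxes[:, 2:])
    enc_wh = (enc_br - enc_tl).clamp(min=0)
    enclose = (enc_wh[:, 0] * enc_wh[:, 1]).clamp(min=eps)

    giou = iou - (enclose - union) / enclose
    # GIoU lies in [-1, 1], so the loss lies in [0, 2].
    return 1.0 - giou
```

For example, boxes `[0, 0, 2, 2]` and `[1, 1, 3, 3]` overlap in a 1×1 square, so IoU = 1/7, the enclosing box is 3×3, and the loss is 1 - (1/7 - 2/9) = 68/63 ≈ 1.079. Two disjoint unit boxes `[0, 0, 1, 1]` and `[2, 0, 3, 1]` still get a loss of 4/3, which grows as they move apart; plain IoU loss would stay at 1 and give no useful gradient in that case.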
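The two center-ness pieces in the listing can also be sketched on their own. This is a minimal sketch under stated assumptions: `centerness_target` and `centerness_bce_loss` are hypothetical names, the predictions are assumed to already be sigmoid probabilities (the listing's `torch.log` calls imply this), and the `l, t, r, b` values of positive points are assumed to be strictly positive, which the "strictly inside the box" rule guarantees. The target is the FCOS formula sqrt(min(l, r) / max(l, r) × min(t, b) / max(t, b)).

```python
import torch
import torch.nn.functional as F


def centerness_target(ltrb: torch.Tensor) -> torch.Tensor:
    """Center-ness target for positive points; ltrb is [N, 4] with entries > 0."""
    l, t, r, b = ltrb.unbind(dim=1)
    ratio_lr = torch.min(l, r) / torch.max(l, r)
    ratio_tb = torch.min(t, b) / torch.max(t, b)
    return torch.sqrt(ratio_lr * ratio_tb)


def centerness_bce_loss(pred_prob: torch.Tensor, target: torch.Tensor,
                        eps: float = 1e-4) -> torch.Tensor:
    """Binary cross-entropy between predicted and target center-ness,
    summed and divided by the number of positive points."""
    # Clamp so that log() never receives exactly 0 or 1.
    p = pred_prob.clamp(min=eps, max=1.0 - eps)
    loss = -(target * torch.log(p) + (1.0 - target) * torch.log(1.0 - p))
    return loss.sum() / max(target.numel(), 1)


def reference_bce(pred_prob: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """PyTorch's built-in BCE (mean reduction), for comparison."""
    return F.binary_cross_entropy(pred_prob, target)
```

A point at the exact box center (`l = r`, `t = b`) gets a target of 1.0, `ltrb = [1, 1, 4, 4]` gives 0.25, and `[1, 2, 3, 2]` gives sqrt(1/3) ≈ 0.577. When no clamping kicks in, the hand-written loss equals `F.binary_cross_entropy` with its default mean reduction. If the center-ness head returned raw logits instead, `F.binary_cross_entropy_with_logits` would be a numerically safer alternative; that is a design option, not what the listing does.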
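The vectorized assignment in `get_batch_position_annotations` is easier to follow in a reduced, single-image form. The sketch below is a hypothetical re-expression, not the author's API: the function name `assign_fcos_targets` and its argument layout are assumptions. It applies the same three rules as the listing: the point must lie strictly inside the box, max(l, t, r, b) must fall strictly within the level's `(m_lo, m_hi)` range, and ties go to the smallest-area box. Unlike the listing, it uses boolean masks instead of multiplying the candidates by 0/1 flags.

```python
import torch

INF = 100000000.0  # sentinel area for invalid (point, box) pairs, as in the listing


def assign_fcos_targets(points, point_ranges, gt_boxes, gt_classes):
    """Single-image FCOS-style target assignment (simplified sketch).

    points:       [P, 2] (x, y) positions in input-image coordinates
    point_ranges: [P, 2] (m_lo, m_hi) of the FPN level each point comes from
    gt_boxes:     [G, 4] (x1, y1, x2, y2)
    gt_classes:   [G]    0-based class indices
    Returns ltrb [P, 4], cls [P] (0 = negative, k + 1 = class k), ctr [P].
    """
    num_points, num_gts = points.shape[0], gt_boxes.shape[0]
    ltrb = torch.zeros(num_points, 4)
    cls = torch.zeros(num_points, dtype=torch.long)
    ctr = torch.zeros(num_points)
    if num_gts == 0:
        return ltrb, cls, ctr

    x, y = points[:, 0:1], points[:, 1:2]  # [P, 1], broadcast against [G]
    # Distances from every point to the four sides of every box: [P, G, 4].
    dist = torch.stack([x - gt_boxes[:, 0], y - gt_boxes[:, 1],
                        gt_boxes[:, 2] - x, gt_boxes[:, 3] - y], dim=-1)

    inside = dist.min(dim=-1).values > 0          # point strictly inside the box
    max_dist = dist.max(dim=-1).values            # [P, G]
    in_range = ((max_dist > point_ranges[:, 0:1]) &
                (max_dist < point_ranges[:, 1:2]))
    valid = inside & in_range                     # [P, G]

    # Among the valid boxes of each point, keep the one with the smallest area.
    areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])
    areas = areas.unsqueeze(0).expand(num_points, num_gts)
    areas = torch.where(valid, areas, torch.full_like(areas, INF))
    min_area, best = areas.min(dim=1)             # [P]
    pos = min_area < INF                          # points with at least one valid box

    rows = torch.arange(num_points)[pos]
    ltrb[pos] = dist[rows, best[pos]]
    cls[pos] = gt_classes[best[pos]].long() + 1
    l, t, r, b = ltrb[pos].unbind(dim=1)
    ctr[pos] = torch.sqrt((torch.min(l, r) / torch.max(l, r)) *
                          (torch.min(t, b) / torch.max(t, b)))
    return ltrb, cls, ctr
```

On a toy input with a 20×20 box containing a 10×10 box, a P3 point at their shared center is assigned to the smaller box, a point near the corner of the large box matches only the large box, and the same center point on a level with range `(64, 128)` stays negative because its largest distance is far below 64.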
You may freely share or copy this text, but please keep this document's URL as the reference URL.
Licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.