SSD의 하드 앤 드롭 결과를 시각화해 보십시오.
"SSD의 정DBox를 보여 줍니다."코드의 재등재는 다음과 같다. jaccard 계수가 가장 큰 포위함과 기본 상자의 위치 정보
loc_t
에서 이 포위함에 대응하는 라벨 정보와 conf_t
의 기본 상자의 인덱스를 일치시킨다.loc_t
는 함수 정의의 관계에 match
를 정의했지만 이 변수는 이번에 사용하지 않습니다.image_size = 300
batch_size = 1
dataset = VOCDetection(root = '/path/to/root',
transform=BaseTransform(image_size, MEANS))
data_loader = data.DataLoader(dataset, batch_size = batch_size, num_workers=0, shuffle=False, collate_fn = detection_collate)
torch.set_default_tensor_type('torch.cuda.FloatTensor')
images, targets = next(iter(data_loader))
images = images.cuda()
targets = [ann.cuda() for ann in targets]
priors = PriorBox(voc).forward()
num_priors = priors.size(0)
loc_t = torch.Tensor(batch_size, num_priors, priors.size(1))
conf_t = torch.LongTensor(batch_size, num_priors)
match(threshold = 0.5,
truths = targets[0][:, :-1].data,
priors = priors.data,
variances = [0.1, 0.2],
labels = targets[0][:, -1].data,
loc_t = loc_t,
conf_t = conf_t,
idx = 0)
표시 학습에서는 배경을 직관적으로 판단하기 위해 부정적 기본 상자에 대한 손실 정보를 가져올 필요도 있지만, SSD에서는 부정적 기본 상자에 잘 표시되지 않는 내용만 선택한다.이 작업을 하드웨어라고 합니다.단단한 부정적 발굴을 하려면 모델에 표시를 해야 하기 때문에 훈련 모드에서 SSD를 생성해 이미지의 표시
conf_data
를 계산하도록 한다.net = build_ssd('train', 300, len(VOC_CLASSES) + 1)
# init.pthはtrain.pyの学習開始時の重み
net.load_state_dict(torch.load("/path/to/init.pth"))
net.vgg.load_state_dict(torch.load(f"{os.getcwd()}/weights/vgg16_reducedfc.pth"))
net = net.cuda()
net.train()
_, conf_data, _ = net(images)
conf_data
계산에 사용된 SSD의 무게는 원칙적으로 어떤 무게든 가능하지만 여기서 학습을 시작할 때의 무게를 사용합니다.다음 코드는 이른바 하드웨어 소극적 발굴로 표시된 교사 데이터
conf_t
에 대해 교차엔트로피 평가모델의 예측conf_data
으로 얼마나 편차가 있는지 정렬하여 손실이 큰 위치를 얻는다.#ポジティブデフォルトボックスに対して選ぶネガティブデフォルトボックスの個数が何倍あるか
#学習処理では3に設定されていましたが、最後画像として描画する都合のために2にしています
dbox_ratio = 2
conf_t = Variable(conf_t.cuda())
pos = conf_t > 1
batch_conf = conf_data.view(-1, 21)
loss_c = F.cross_entropy(batch_conf, conf_t.view(-1), reduction='none')
num_pos = pos.long().sum(1, keepdim = True)
loss_c = loss_c.view(1, -1)
loss_c[pos] = 0
_, loss_idx = loss_c.sort(1, descending = True)
_, idx_rank = loss_idx.sort(1)
num_neg = torch.clamp(dbox_ratio * num_pos, max = pos.size(1) - 1)
neg = idx_rank < num_neg.expand_as(idx_rank)
기본 상자에서 neg
인덱스 요소의 소극적인 기본 상자를 선택합니다.다음 코드로 이 그림들을 표시합니다.image = (images[0].to('cpu').detach().numpy().transpose(1, 2, 0) + (MEANS[2], MEANS[1], MEANS[0])).astype(np.uint8).copy()
image = cv2.resize(image, (image_size, image_size))
plt.figure(figsize=(80, 180))
indices = [i for i, v in enumerate(list(neg.to('cpu').detach().numpy().copy()[0])) if v]
for i, idx in enumerate(indices):
img = image.copy()
cx_d, cy_d, w_d, h_d = priors[idx].to('cpu').detach().numpy().copy()
xmin_d = int((cx_d - w_d / 2) * image_size)
ymin_d = int((cy_d - h_d / 2) * image_size)
xmax_d = int((cx_d + w_d / 2) * image_size)
ymax_d = int((cy_d + h_d / 2) * image_size)
pt1_d = (xmin_d, ymin_d)
pt2_d = (xmax_d, ymax_d)
cv2.rectangle(img, pt1=pt1_d, pt2=pt2_d, color=(0, 255, 0), thickness=2)
for box in targets[0]:
label = int(box[4])
pt1 = (int(box[0] * image_size), int(box[1] * image_size) )
pt2 = (int(box[2] * image_size), int(box[3] * image_size))
jaccard = calc_jaccard((pt1_d[0], pt1_d[1], pt2_d[0], pt2_d[1]), (pt1[0], pt1[1], pt2[0], pt2[1]))
cv2.putText(img, f"{jaccard: .2f}", (pt1[0] + 5 , pt1[1] - 5), cv2.FONT_HERSHEY_PLAIN, 1.5, (0, 255, 0), 2)
cv2.rectangle(img, pt1=pt1, pt2=pt2, color=(255, 0, 0), thickness=2)
plt.subplot(9 , 4, i + 1)
plt.axis('off')
plt.imshow(img)
plt.show()
작은 기본 상자를 선택하는 경향이 있는 것 같지만 왜 그런지 모르겠다.베이킹 박스는 기본 모델로 vgg16에서 배운 모델을 사용하고dog 등 동물의 무게가 높아지며 교차 엔트로피가 커지는 경향이 있기 때문에 선택된 기본 박스입니다.마지막으로 학습이 끝난 후의 패턴이 경향을 바꿀 수 있는지 확인해 보세요.이번에 학습이 끝난 모형으로 사용된 무게는 아래 URL에서 다운로드할 수 있다.
그 결과 학습 전보다 더 큰 기본 상자를 선택했고dog분류가 자연스럽다고 생각하는 것도 포함됐다.이렇게 보면 SSD에 규정된 규칙과 이미지 분류 vggg16의 판정은 소극적인 기본 상자에서 항상 대립되어 보인다.
여기에 사용된 원본 이미지는 여기. (Pixabay License: 상업용 무료 귀속 표시 필요 없음) 640x426 사이즈로 다운로드한 것이다.
Reference
이 문제에 관하여(SSD의 하드 앤 드롭 결과를 시각화해 보십시오.), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://zenn.dev/nnabeyang/articles/cdeb00f52bd0b9f0ffa1텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)