【요약】 Transformer를 이용한 물체 검출 모델 「End-to-End Object Detection with Transformers」

소개



「End-to-End Object Detection with Transformers」(DETR)가 신경이 쓰였으므로, 논문을 읽고 조금 동작 확인도 해 보았습니다. 간결하게 기록으로 남겨 둡니다.
[ 논문 , Github ]

DETR이란(요약)



· Facebook AI Research가 올해 5월에 공개한 모델

・자연 언어 처리 분야에서 유명한 Transformer를 처음으로 물체 검출에 활용

· 아래 그림과 같이 CNN + Transformer의 간단한 네트워크 구성

・NMS나 AnchorBox의 디폴트치 등, 사람 손에 의한 조정이 필요한 부분을 배제해 「End-to-End」인 물체 검출을 실현

· 상기를 실현하기위한 포인트로서 "Bipartite Matching Loss"와 "Parallel Decoding"의 효과를 주장

· 물체 검출뿐만 아니라 세그멘테이션 작업에도 적용 가능



추론 코드 예



논문에서 코드를 인용합니다. 다음과 같이 모델 정의에서 추론 처리까지 40행 정도로 간단하게 쓸 수 있습니다.
import torch
from torch import nn
from torchvision.models import resnet50

class DETR(nn.Module):

    def __init__(self, num_classes, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers):
        super().__init__()
        # We take only convolutional layers from ResNet-50 model
        self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        self.transformer = nn.Transformer(hidden_dim, nheads,
                                          num_encoder_layers, num_decoder_layers)
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        x = self.backbone(inputs)
        h = self.conv(x)
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1))
        return self.linear_class(h), self.linear_bbox(h).sigmoid()

detr = DETR(num_classes=91, hidden_dim=256, nheads=8,
            num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)

※이 논문의 코드의대로라면, 공개되고 있는 학습 완료 모델을 로드할 수 없습니다. (model 정의의 작성 방법이 다르기 때문에)
실제로 학습된 모델에서 검출을 실시하는 경우는, 본가 Github의 detr_demo.ipynb와 같이 하는 것이 좋다.

동작 확인



학습된 모델을 사용하여 실제로 동작을 확인해 보았습니다. 환경은 다음과 같습니다.
· OS : 우분투 18.04.4 LTS
・GPU : GeForce RTX 2060 SUPER(8GB) x1
· PyTorch 1.5.1/torchvision 0.6.0

검출 처리의 구현은 본가의 detr_demo.ipynb를 참고로 해, OpenCV로 web 카메라로부터 캡쳐한 화상에 대해서 검출을 실시했습니다. 사용한 모델은 ResNet-50 기반 DETR입니다.

다음은 실제 검출 결과입니다. 성공적으로 검색이 수행되고 있음을 확인할 수 있습니다. (COCO의 클래스에 포함되지 않는 것도 촬영 대상으로 해 버리고 있습니다만)


자신의 환경에서는, 추론 처리 자체는 45msec(약 22 FPS) 정도로 돌고 있는 것 같았습니다.
※문헌치는 V100상에서 28FPS



DETR의 공부와 동작 확인을 실시했습니다. 새로운 타입의 수법이므로, 정밀도나 속도의 면에서 앞으로 더욱 발전해 나갈 것이라고 생각합니다.
Transformer는 자연 언어 처리 전용이라고 하는 이미지를 가지고 있었습니다만, 최근에는 이미지를 취급하는 모델에의 도입도 진행되고 있군요. 앞으로 더 공부하고 싶습니다. (Image GPT도 움직이고 싶습니다.)

참고



・「DETR」Transformer의 물체 검출 데뷔
htps : // 메이 m. 이 m / lsc-psd /에서 tr t ran s fur r % 3 % 81 % Ae % 7 % 89 % A 9 % E 4 % BD % 93 % E 6 % A 4 % 9 C % 5% 87% 3% 83% 87% 3% 83% 93% 3% 83% 5% 3% 83% BC-dc18 582 c1
·End-to-End Object Detection with Transformers (DETR)의 해설
htps : // 이 m/사 s가 wy/이고 ms/61fb64d848df9f6b53d1
・Transformer를 물체 검출에 채용! 화제의 DETR을 상세 해설!
h tps : ///그래서 ps 있어. jp/2020/07/에서 tr/

좋은 웹페이지 즐겨찾기