ptorich에서efficientdet를 사용해 보십시오

efficientdet의pytric로 구현된rwightman/efficientdet-pytorch을 사용하여 검증 데이터의 추론 결과를 보여 줍니다.이 프로젝트에서 학습하는 방법, 사용자 정의 데이터 집합을 처리하는 방법 등에 대해 기재했지만 왜 시위와 example을 게재하지 않았는가.
우선 이 프로젝트의 맨 끝에 jupter notebook의 note를 만듭니다.
from effdet import create_model, create_dataset, create_loader
from effdet.data import resolve_input_config
import torch
import matplotlib.pyplot as plt
import cv2
import os
이 창고validate.py에서 추론 모드를 사용했기 때문에 이 프로그램의 선택할 수 있는 기본값,create 참조모델의 매개 변수를 결정합니다.checkpoint_path 스스로 공부하러 뛰어가면 아래 칸의 이름이 된다.
bench = create_model(
    'efficientdet_d0', # d0 ~ d7
    bench_task='predict',
    num_classes=20,
    pretrained=False,
    redundant_bias=None,
    soft_nms=None,
    checkpoint_path='./output/train/yyyymmdd-hhnnss-efficientdet_d0/checkpoint-n.pth.tar',
    checkpoint_ema='use_ema',
)
bench = bench.cuda() #cudaを使う
bench.eval() #推論モード
VOC 데이터 세트의 경우 다음과 같이 검증 데이터를 로드합니다._에서 훈련을 받은 측이 훈련 데이터가 되었다.
_, dataset = create_dataset('voc0712', './VOCdevkit')
데이터 로더의 옵션도 validate.py의 기본 옵션에 달려 있다.이번 가설은 너에게 사진 한 장을 줄 것이니 batch_size 1을 선택해라.
model_config = bench.config
input_config = resolve_input_config({}, model_config)
loader = create_loader(
        dataset,
        input_size=input_config['input_size'],
        batch_size=1,
        use_prefetcher=True,
        interpolation='bilinear',
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=1,
        pin_mem=False)
추론 모드에서 출력된 포위함 좌표는 입력 이미지의 크기에 따라 확률 100위 이내로 되돌아온다.cv2.circle로 cv2를 표시합니다.rectangle로 그림을 둘러싸면 결과를 확인할 수 있습니다.
IMG_DIR = '/path/to/img_dir'
parser = dataset.parser
with torch.no_grad():
    for input, target in loader:
        img_name = parser.img_infos[int(target['img_idx'][0])]['file_name']
        output = bench(input, img_info=target)[0]
        img = cv2.imread(os.path.join(IMG_DIR, img_name), cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        for i in range(output.size(0)):
            if output[i, 4] < 0.3: # 0.3は閾値です。適当に変えてください
                break
            xmin, ymin, xmax, ymax, pred, label = output[i]
            #cv2.rectangle(img, pt1=(int(xmin), int(ymin)), pt2=(int(xmax), int(ymax)), color=(255, 0, 0), thickness=4)
            cx = int(xmin) + int((float(xmax) - float(xmin)) / 2)
            cy = int(ymin) + int((float(ymax) - float(ymin)) / 2)
            cv2.circle(img, (cx, cy), 5, (0, 255, 0), thickness=-1)
        plt.imshow(img)
        plt.show()

좋은 웹페이지 즐겨찾기