딥러닝에서 누구나 유사한 이미지 검색

배경



"어느 이미지와 가장 유사한 이미지를 탐색하고 싶다"경우에, 어떻게 딥 러닝을 활용할 수 있을지 지견이 없었으므로, 이하의 2개를 참고로 변경하고 있습니다.

근사 최근 이웃 탐색 라이브러리 비교
기계 학습 모델이라는 "함수"를 사용하여 처음으로 유사한 이미지 검색

annoy는 설치가 번거롭기 때문에 nmslib를 사용합니다.

결과



내용 해설 전에 결과를 나타냅니다.

한 이미지(고양이)와 비슷한 이미지를 데이터베이스(20장의 새끼 고양이, 강아지의 이미지) 내에서 검색한 결과입니다.
1장을 제외하고, 고양이 화상이 비슷하다고 판정되고 있었습니다.

*오판정의 1장은 피물을 하고, 매우 고양이에는 보이지 않기 때문에 오판정도 납득의 결과입니다.



개요



탐색 대상의 이미지군에는 라벨링은 되어 있지 않기 때문에, 교사 있어 학습을 할 수 없습니다.
하나의 접근법으로는 학습된 모델을 활용하여 전체 결합층의 출력 등을 특징량으로 사용할 수 있습니다.

이번에는 VGG19에서 Imagenet을 학습한 가중치를 사용했습니다.

또, tkinter로 데이터베이스와 화상의 선택을 할 수 있도록 하고 있기 때문에,
프로그램 미경험자라도 움직일 수 있을까 생각합니다.

코드


import glob
from pathlib import Path
import tkinter
import tkinter.filedialog

#License
#The MIT License
import keras
from keras.models import Model
from keras.layers import Input, Dense
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input

#License
#These weights are ported from the ones released by VGG at Oxford under the Creative Commons Attribution License.
#https://keras.io/applications/
from keras.applications.vgg19 import VGG19, preprocess_input

#Apache License Version 2.0
#https://github.com/nmslib/nmslib/blob/master/README.md
import nmslib

#https://numpy.org/license.html
import numpy as np

current_path = Path.cwd()

# refer https://qiita.com/wasnot/items/20c4f30a529ae3ed5f52
# refer https://qiita.com/K-jun/items/cab923d49a939a8486fc

def main():
    print("データベースを選択してください")
    print("サブディレクトリ内の画像もすべて検索対象となります")

    data_folder_path = tkinter.filedialog.askdirectory(initialdir = current_path,
                        title = 'choose data folder')

    print("データベースと比較したい画像を選択してください")
    test_img_path = tkinter.filedialog.askopenfilename(initialdir = current_path,
                        title = 'choose test image', filetypes = [('image file', '*.jpeg;*jpg;*png')])

    base_model = VGG19(weights="imagenet")
    base_model.summary()
    #outputsを"fc2"と指定し、2番目の全結合層を出力します
    model = Model(inputs=base_model.input, outputs=base_model.get_layer("fc2").output)

    test_img = image.load_img(test_img_path, target_size=(224, 224))
    x = image.img_to_array(test_img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    test_fc2_features = model.predict(x)

    #選択したフォルダに存在するpng,jpeg,jpgをサブディレクトリも含めて抽出
    png_list  = glob.glob(data_folder_path + "/**/*.png", recursive=True)
    jpeg_list = glob.glob(data_folder_path + "/**/*.jpeg", recursive=True)
    jpg_list  = glob.glob(data_folder_path + "/**/*.jpg", recursive=True)
    image_list = png_list + jpeg_list + jpg_list

    fc2_list = []
    for image_path in image_list:
        img = image.load_img(image_path, target_size=(224, 224))
        x = image.img_to_array(img)
        x = np.expand_dims(x, axis=0)
        x = preprocess_input(x)
        fc2_features = model.predict(x)
        fc2_list.append(fc2_features[0])

    index = nmslib.init(method='hnsw', space='cosinesimil')
    index.addDataPointBatch(fc2_list)
    index.createIndex({'post': 2}, print_progress=True)
    ids, distances = index.knnQuery(test_fc2_features, k=len(image_list))
    result = [image_list[i] for i in ids]

    print(ids)
    print(distances)
    print(result)

    print("選択した画像は " , test_img_path, " です")
    print("選択した画像に似ている順に表示します")
    for i, id in enumerate(ids):
        print(image_list[id], " : 距離: ", distances[i])

main()

좋은 웹페이지 즐겨찾기