[Faiss] 😆 Fast and easy similarity computation!


😀 Hello! Today's topic is a package that makes similarity and KNN computation fast and easy:
😎 "Faiss". I'll briefly summarize the parts of it that I actually used.

Background

😊 I was writing code to submit to the Kaggle competition below, putting the concepts I had studied to use!
The task in this competition is to identify individual whales and dolphins by their dorsal fin.
💧 The dataset has only a few images per individual while the number of classes is quite large, so many solutions treated it as a kind of "Face Recognition Task".
✨ So I figured I would also try ArcFace, which works well for face recognition :)

https://www.kaggle.com/competitions/happy-whale-and-dolphin

😋 That's right. These are the two posts below that I wrote up earlier!
ViT was used as the model architecture, and ArcFace as the loss function.

ViT
https://velog.io/@gtpgg1013/pytorch-Image-Classification-Using-ViT
ArcFace
https://velog.io/@gtpgg1013/%EB%85%BC%EB%AC%B8%EB%A6%AC%EB%B7%B0-ArcFace-Additive-Angular-Margin-Loss-for-Deep-Face-Recognition

However!

Haha.. well, well 🤣
Since this is not the approach I had normally used, taking the maximum value from a softmax layer to produce the result,
some extra post-processing turned out to be necessary before I could submit results.

(Note)
In an ordinary classification task, you take the values from the softmax layer and apply argmax, and the result can be used directly as the predicted class index.
🤔 A model trained with the ArcFace loss, however, embeds each image into a vector of a fixed dimension at inference time. So you run both the training and test data through the model to produce embeddings, and then identify individuals using the similarity between those embeddings.
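To make the difference concrete, here is a minimal sketch in NumPy. The scores, embeddings, and ids (`whale_a` and so on) are made-up values for illustration and do not come from the competition data, and cosine similarity is just one possible similarity measure.

```python
import numpy as np

# Ordinary classifier: one score per class, so argmax gives the class index directly.
logits = np.array([0.2, 2.5, -1.0])           # hypothetical scores for 3 classes
probs = np.exp(logits - logits.max())
probs /= probs.sum()                           # softmax
predicted_class = int(np.argmax(probs))

# Embedding model: no class scores, only vectors, so we compare against known embeddings.
train_embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]], dtype=np.float32)
train_ids = np.array(["whale_a", "whale_b", "whale_c"])   # hypothetical individual ids
test_embedding = np.array([0.1, 0.9], dtype=np.float32)

# Cosine similarity between the test embedding and every training embedding.
norms = np.linalg.norm(train_embeddings, axis=1) * np.linalg.norm(test_embedding)
cosine = train_embeddings @ test_embedding / norms

# The prediction is the id attached to the most similar training embedding.
predicted_id = str(train_ids[int(np.argmax(cosine))])
```

The second half is the extra post-processing step: the model output alone does not name an individual until it is matched against embeddings whose ids are known.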

😒 Of course, I could have just written a for loop to compute the similarities...
😋 But I was sure someone had already built something good for this, and I found it in a solution I was using as a reference!
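For comparison, the "just write a for loop" version might look roughly like the sketch below. `brute_force_search` is a hypothetical helper written for this post, not code from the referenced solution, and it uses squared L2 distance as the measure.

```python
import numpy as np

def brute_force_search(db_embeddings, query_embeddings, k):
    """Return (distances, indices) of the k nearest db rows for each query row."""
    all_distances, all_indices = [], []
    for q in query_embeddings:
        # Squared L2 distance from this query to every db row.
        dists = np.array([np.sum((q - row) ** 2) for row in db_embeddings])
        nearest = np.argsort(dists)[:k]        # indices of the k smallest distances
        all_indices.append(nearest)
        all_distances.append(dists[nearest])
    return np.array(all_distances), np.array(all_indices)

db = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 3.0], [5.0, 5.0]], dtype=np.float32)
query = np.array([[0.9, 0.1], [0.0, 2.0]], dtype=np.float32)
D, I = brute_force_search(db, query, k=2)
# I is [[1, 0], [2, 0]]: db row 1 is closest to the first query, db row 2 to the second.
```

This works, but every query is compared against every db row in Python, which gets slow quickly as the number of embeddings grows.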

Faiss!

github : https://github.com/facebookresearch/faiss
wiki : https://github.com/facebookresearch/faiss/wiki

✨ Faiss is a library built by Facebook for efficient similarity search and clustering. It is written in C++ and can also run on GPUs, so it is said to be both versatile and efficient :)

🙌 I used L2 distance to compute the distance between embeddings :)

install

pip install faiss-gpu

simple faiss api example

😀 Faiss works in three steps: create an index => add the db to the index => search the db with a query.
😎 index.search returns the k most similar db entries for each query, ordered from most to least similar, where k is the function's argument.
😋 I holds the db indices, and D holds the distance between the query and the db entry at each of those indices. (For IndexFlatL2, D is the squared L2 distance.)

import numpy as np
import faiss

# Search target: db (100 vectors of dimension 32)
db = np.array(np.random.random((100, 32)), np.float32)
# Queries to run against db (2 vectors of dimension 32)
query = np.array(np.random.random((2, 32)), np.float32)

# Similarity search
def create_and_search_index(embedding_size, db_embeddings, query_embeddings, k):
    # Create a Faiss index for the given embedding size (32 here)
    index = faiss.IndexFlatL2(embedding_size)
    # Add db to the index
    index.add(db_embeddings)
    # Search for the k most similar entries
    # I: indices into db / D: squared L2 distance between each query and the db entry at that index
    D, I = index.search(query_embeddings, k=k)

    return D, I

D, I = create_and_search_index(32, db, query, 5)

D
>>> array([[2.2877078, 2.4738002, 2.5330005, 2.8156984, 2.8736565],
       [2.5150692, 2.5936663, 3.0449693, 3.14107  , 3.198256 ]],
      dtype=float32)

# That is, the db index closest to the first query is 35, and the db index closest to the second query is 10.
I
>>> array([[35, 10, 83, 85, 55],
       [10, 84, 89, 44, 92]])
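The indices in I still have to be turned into individuals before anything can be submitted. One way this might be done is sketched below, assuming a `db_ids` array that stores the id of each db row. Both `db_ids` and the helper `neighbors_to_ids` are hypothetical names for illustration, and the small `I` here is hand-written rather than taken from the run above.

```python
import numpy as np

# Hypothetical id for each db row; in practice this would come from the training labels.
db_ids = np.array(["whale_a", "whale_b", "whale_a", "whale_c", "whale_b"])

# Neighbor indices shaped like the I returned by index.search (2 queries, k=3).
I = np.array([[0, 2, 1],
              [4, 3, 1]])

def neighbors_to_ids(indices, ids):
    """For each query, list neighbor ids nearest-first, dropping repeated ids."""
    results = []
    for row in indices:
        ordered = []
        for idx in row:
            label = str(ids[idx])
            if label not in ordered:
                ordered.append(label)
        results.append(ordered)
    return results

predictions = neighbors_to_ids(I, db_ids)
# predictions is [["whale_a", "whale_b"], ["whale_b", "whale_c"]]
```

Dropping repeated ids matters because several db images can belong to the same individual, so the k nearest rows may map to fewer than k distinct ids.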

🎉 Used as shown above, Faiss makes similarity computation simple :)
😉 Besides L2 distance, Faiss supports various other distances for computing similarity. See the page below for details.

https://github.com/facebookresearch/faiss/wiki/Faiss-indexes

๊ธ€์„ ์ •๋ฆฌํ•˜๋ฉฐ

That was a summary of Faiss, a powerful package that took care of my similarity computation in one go.
😊 I can definitely see it being useful in recommendation systems and in any process that involves a lot of similarity comparisons.
🐱‍👤 I expect I'll soon get a chance to try the GPU support as well, and to use other distances depending on what the situation calls for.
Thank you for reading this far. Have a great day!

์ข‹์€ ์›นํŽ˜์ด์ง€ ์ฆ๊ฒจ์ฐพ๊ธฐ