Boostcamp week12 Inference, Retrieval, QAtrainer
Baseline의 구조에 대해 요약하려한다.
Inference
1. Main function
main함수에서는 다음과 같은 순서로 코드가 실행된다.
- Argument Parser
- Logger, Dataset 정의
- Pretrained Reader model 호출
- 다음과 같이 실행
a.
eval_retrieval = True
:run_sparse_retrieval()
b.do_eval = True
/do_predict = True
:run_mrc()
2. run_sparse_retrieval
run_sparse_retrieval함수는 주어진 쿼리에 대해 topk개의 passage를 Sparse retrieval을 이용해 반환하는 함수이다.
- SparseRetrieval 클래스 선언
- 저장된 SparseRetrieval 호출
- Faiss 유무 확인, Retreive
a.use_faiss = True
: Faiss방식으로 Retrieve
b.use_faiss = False
: Exhaustive Retrieve- Predict / Eval
a.do_predict = True
: dataset feature ->context
,id
,question
b.do_eval = True
: dataset feature ->answers
,context
,id
,question
- Dataset 반환
3. run_mrc
run_mrc는 불러온 데이터에 대해 Reader가 작동하는 부분이다.
학습을 하는 코드가 아니기 때문에 Preprocess에서 validation에 해당하는 부분만 사용한다.
단, Post process에서는 다음과 같이 구분한다.
1.do_preict = True
: id, predictions 반환
2.do_eval = True
: id 별 prediction과 answer 평가
3. QA trainer에 각각 predict(), evaluate() 실행
Trainer_QA
QuestionAnsweringTrainer(Trainer)
HF의 Trainer를 상속받은 Custom trainer로, evaluate과 predict함수를 선언해줬다.
먼저, 입력 파라미터로 Trainer와 다른 점은 아래 두 파라미터이다.
eval_examples
: feature로 쪼개지기 전 데이터, example
post_process_function
또한 evaluate과 predict는 모두 prediction_loop
함수를 이용하고 있는데, Docs를 참고하자.
prediction loop 함수를 간단히 정리하면 다음과 같다.
1. prediction_loss_only 파라미터(Default는TrainingArgument
에서 False로 되어있다.)를 받는다.
2. prediction_loss_only가 True이면 예측값을 모으지 않는다.
3. prediction_loss_only가 False이면 예측값을 모아 Metric을 계산한다.
4. Evaluation 또는 prediction을 진행한다.아직 Prediction Loss Only를 왜 하는 것인지는 정확하게 이해하지 못했다.
1. evaluate()
- dataset, dataloader, eval_examples, compute_metrics를 선언한다.(혹은 인자로 받는다)
prediction_loop()
를 실행한 결과를 output으로 저장한다.- eval_examples, eval_dataset,
output.prediction
, argument를 인자로 받아 Post processing을 진행한다.- post process의 결과값으로 metric을 계산하여 logging한다.
- metric을 반환한다. -> 반환된 값은 Inference.py에서 log, save 된다.
2. predict()
- trainer의
get_test_dataloader
함수를 이용해 test dataloader선언한다.(HF trainer에는get_train_dataloader
,get_eval_dataloader
,get_test_dataloader
로 데이터로더를 선언할 수 있다.)- evaluation과 동일하게
prediction_loop()
를 실행한 결과를 output으로 저장한다.- 만약 post_processor나, compute_metric이 정의되지 않았다면 output을 그대로 반환한다.
- 그게 아니라면 Post_processing을 통해 예측값을 도출하고 이를 반환한다.
- 예측값은 postprocess_qa_predictions에서 json형태로 저장된다.
Retrieval
Sparse Retrieval Class
Retrieve를 진행하는 클래스로, Inference과정에서 호출하는 클래스이다.
0. init
- 아래 코드로 context를 정의한다(corpus)
self.contexts = list( dict.fromkeys([v["text"] for v in wiki.values()]) )
- tfidf 객체를 생성한다.
p_embedding
,indexer
를 None으로 initialize한다.
1. get_sparse_embedding
Passage Embedding을 만들고 TFIDF와 Embedding을 pickle로 저장한다. 만약 미리 저장된 파일이 있으면 저장된 pickle을 불러온다.
- 저장된 tfidf 파일이 경로에 있다면 불러온다.
- 경로에 없다면 fit_transform()한다.
2. build_faiss
속성으로 저장되어 있는 Passage Embedding을 Faiss indexer에 fitting 시켜놓는다. 이렇게 저장된 indexer는 get_relevant_doc
에서 유사도를 계산하는데 사용된다.
- 저장되어있는 indexer파일을 사용한다.
Faiss는 Build하는데 시간이 오래 걸리기 때문에 매번 새롭게 build하는 것은 비효율적이다. 그렇기 때문에 build된 index 파일을 저정하고 다음에 사용할 때 불러온다.
- 저장되어있지 않다면 IndexFlatL2, IVFScalerQuantizer를 이용해 train, add한다.
3. retrieve
str이나 Dataset으로 이루어진 Query를 받고, str 형태인 하나의 query만 받으면 get_relevant_doc
을 통해 유사도를 구한다. Dataset 형태는 query를 포함한 Huggingface.Dataset을 받고,get_relevant_doc_bulk
를 통해 유사도를 구한다.
- Query가 string일 때
get_relevant_doc
의 결과를 받아 스코어와 context를 반환한다.- Query가 dataset일 때
get_relevant_doc_bulk
의 결과를 받아 question, id, context_id, context를 반환하고 만약 context와 question, answer까지 포함된 데이터를 사용한다면 Ground Truth인 context와 answer까지 추가로 반환한다.- 이러한 정보들을 가지고 있는 데이터프레임을 최종 반환한다.
4. get_relevant_doc
위에서 봤듯 하나의 query에 대한 유사도를 계산하는 함수이다.
- query를 tfidf vectorizer를 이용해 transform한다.
- query 벡터와 passage 임베딩 벡터의 dot product를 통해 result를 출력한다.
- dim을 맞추기 위해 squeeze한 뒤 argsort를 한다.
- Topk개에 대한 score, index를 최종 반환한다.
5. get_relevant_doc_bulk
위에서 봤듯 여러개의 query에 대한 유사도를 계산해 Query개수만큼 score, index들을 반환한다.
6. retrieve_faiss
retrieve와 동일하게 구성되어있으나 get_relevant_doc_faiss
, get_relevant_doc_bulk_faiss
를 이용한다.
7. get_relevant_doc_faiss
- query를 임베딩한다
- 미리 선언해둔 indexer를 통해 search하고, score와 index를 반환한다.
8. get_relevant_doc_bulk_faiss
위와 동일하며, 여러개의 query에 대한 결과를 반환한다.
main : Retrieval.py 실행시켰을 때
Retrieval을 실행시키면 train/valid의 query를 모두 합쳐서 retrieve 결과를 보여준다.
Author And Source
이 문제에 관하여(Boostcamp week12 Inference, Retrieval, QAtrainer), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://velog.io/@dayday/Boostcamp-week12-Inference-Retrieval-QAtrainer저자 귀속: 원작자 정보가 원작자 URL에 포함되어 있으며 저작권은 원작자 소유입니다.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)