RAG6_AgenticAI / retrieval /reranker.py
jeongsoo's picture
init
58907af
"""
μž¬μˆœμœ„ν™” 검색 κ΅¬ν˜„ λͺ¨λ“ˆ
"""
import logging
from typing import List, Dict, Any, Optional, Union, Callable
from .base_retriever import BaseRetriever
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ReRanker(BaseRetriever):
    """
    Two-stage retriever that re-ranks the results of another retriever.

    Candidates are fetched from ``base_retriever`` and then re-ordered either
    by a user-supplied ``rerank_fn`` or by a Cross-Encoder style model that
    exposes ``predict``. When neither is configured, the first-stage results
    are returned unchanged (truncated to ``top_k``).
    """

    def __init__(
        self,
        base_retriever: BaseRetriever,
        rerank_model: Optional[Union[str, Any]] = None,
        rerank_fn: Optional[Callable] = None,
        rerank_field: str = "text",
        rerank_batch_size: int = 32
    ):
        """
        Initialize the ReRanker.

        Args:
            base_retriever: Retriever instance used for the first stage.
            rerank_model: Re-ranking (Cross-Encoder) model, given either as a
                model name to load or as an already-built instance exposing
                ``predict``.
            rerank_fn: User-defined re-ranking function called as
                ``rerank_fn(query, results)``. When provided it is used
                instead of ``rerank_model``.
            rerank_field: Document field whose value is paired with the query
                for re-ranking.
            rerank_batch_size: Batch size passed to the model's ``predict``.

        Raises:
            ImportError: If ``rerank_model`` is a model name and the
                ``sentence-transformers`` package is not installed.
        """
        self.base_retriever = base_retriever
        self.rerank_field = rerank_field
        self.rerank_batch_size = rerank_batch_size
        self.rerank_fn = rerank_fn
        self.rerank_model = None

        # A user-defined function takes precedence; no model is kept then.
        if rerank_fn is not None or rerank_model is None:
            return

        if not isinstance(rerank_model, str):
            # An already-built model is used as-is, so it must not require
            # sentence-transformers to be importable.
            self.rerank_model = rerank_model
            return

        # Only loading a model by name needs the optional dependency; keep
        # the try body limited to the import itself.
        try:
            from sentence_transformers import CrossEncoder
        except ImportError:
            logger.warning("sentence-transformers νŒ¨ν‚€μ§€κ°€ μ„€μΉ˜λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€. pip install sentence-transformers λͺ…λ ΉμœΌλ‘œ μ„€μΉ˜ν•˜μ„Έμš”.")
            raise

        logger.info(f"μž¬μˆœμœ„ν™” λͺ¨λΈ '{rerank_model}' λ‘œλ“œ 쀑...")
        self.rerank_model = CrossEncoder(rerank_model)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """
        Add documents to the base retriever.

        Args:
            documents: List of documents to add.
        """
        self.base_retriever.add_documents(documents)

    def search(self, query: str, top_k: int = 5, first_stage_k: int = 30, **kwargs) -> List[Dict[str, Any]]:
        """
        Run the two-stage search: base retrieval followed by re-ranking.

        Args:
            query: Search query.
            top_k: Number of top results to return in the end.
            first_stage_k: Number of candidates to fetch in the first stage.
                Raised to ``top_k`` when smaller, so the candidate pool can
                always satisfy the requested result count.
            **kwargs: Extra search parameters forwarded to the base retriever.

        Returns:
            The re-ranked results, at most ``top_k`` of them. With model-based
            re-ranking, documents lacking ``rerank_field`` are dropped and
            each returned document carries a ``"rerank_score"`` entry.
        """
        # Never fetch fewer candidates than the caller ultimately asked for.
        first_stage_k = max(first_stage_k, top_k)

        # First stage: fetch the candidate pool from the base retriever.
        logger.info(f"κΈ°λ³Έ κ²€μƒ‰κΈ°λ‘œ {first_stage_k}개 λ¬Έμ„œ 검색 쀑...")
        initial_results = self.base_retriever.search(query, top_k=first_stage_k, **kwargs)

        if not initial_results:
            logger.warning("첫 번째 단계 검색 κ²°κ³Όκ°€ μ—†μŠ΅λ‹ˆλ‹€.")
            return []

        if len(initial_results) < first_stage_k:
            logger.info(f"μš”μ²­ν•œ {first_stage_k}κ°œλ³΄λ‹€ 적은 {len(initial_results)}개 κ²°κ³Όλ₯Ό κ²€μƒ‰ν–ˆμŠ΅λ‹ˆλ‹€.")

        # A user-defined re-ranking function wins over the model.
        if self.rerank_fn is not None:
            logger.info("μ‚¬μš©μž μ •μ˜ ν•¨μˆ˜λ‘œ μž¬μˆœμœ„ν™” 쀑...")
            return self.rerank_fn(query, initial_results)[:top_k]

        # Otherwise use the loaded re-ranking model, if any.
        if self.rerank_model is not None:
            logger.info("CrossEncoder λͺ¨λΈλ‘œ μž¬μˆœμœ„ν™” 쀑...")
            return self._rerank_with_model(query, initial_results)[:top_k]

        # No re-ranker configured: return the first-stage order unchanged.
        logger.info("μž¬μˆœμœ„ν™” λͺ¨λΈ/ν•¨μˆ˜κ°€ μ—†μ–΄ 초기 검색 κ²°κ³Όλ₯Ό κ·ΈλŒ€λ‘œ λ°˜ν™˜ν•©λ‹ˆλ‹€.")
        return initial_results[:top_k]

    def _rerank_with_model(self, query: str, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Score candidates with the re-ranking model and sort them best-first.

        Args:
            query: Search query paired with each candidate's text.
            candidates: First-stage results to score.

        Returns:
            The candidates that contain ``rerank_field``, sorted by descending
            ``"rerank_score"``. Ties keep their first-stage order.
        """
        # Keep the scorable documents themselves so every score is attached to
        # the document it was computed for, even when earlier ones are skipped.
        scorable = []
        for doc in candidates:
            if self.rerank_field not in doc:
                logger.warning(f"λ¬Έμ„œμ— ν•„λ“œ '{self.rerank_field}'κ°€ μ—†μŠ΅λ‹ˆλ‹€.")
                continue
            scorable.append(doc)

        # Nothing to score: avoid calling the model with an empty batch.
        if not scorable:
            return []

        text_pairs = [[query, doc[self.rerank_field]] for doc in scorable]
        scores = self.rerank_model.predict(
            text_pairs,
            batch_size=self.rerank_batch_size,
            show_progress_bar=len(text_pairs) > 10
        )

        # NOTE: the score is written onto the dicts returned by the base
        # retriever (in place), as before.
        for doc, score in zip(scorable, scores):
            doc["rerank_score"] = float(score)

        return sorted(scorable, key=lambda doc: doc["rerank_score"], reverse=True)