jeongsoo's picture
init
6575706
"""
Re-ranking retrieval implementation module.
"""
import logging
from typing import List, Dict, Any, Optional, Union, Callable
from .base_retriever import BaseRetriever
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ReRanker(BaseRetriever):
    """Two-stage retriever that re-ranks the results of a base retriever.

    The first stage delegates to ``base_retriever``; the second stage
    re-orders those candidates either with a user-supplied function or with
    a model exposing a ``predict(pairs, batch_size=..., show_progress_bar=...)``
    method (for example a sentence-transformers ``CrossEncoder``).
    """

    def __init__(
        self,
        base_retriever: BaseRetriever,
        rerank_model: Optional[Union[str, Any]] = None,
        rerank_fn: Optional[Callable] = None,
        rerank_field: str = "text",
        rerank_batch_size: int = 32
    ):
        """Initialize the re-ranker.

        Args:
            base_retriever: Retriever used for the first-stage search.
            rerank_model: Cross-encoder model name (``str``) or an already
                constructed model instance. Ignored when ``rerank_fn`` is
                given.
            rerank_fn: Custom re-ranking callable, invoked as
                ``rerank_fn(query, results)``; takes precedence over
                ``rerank_model``.
            rerank_field: Document key whose value is paired with the query
                when scoring with the model.
            rerank_batch_size: Batch size passed to the model's ``predict``.

        Raises:
            ImportError: If ``rerank_model`` is a model name and the
                ``sentence_transformers`` package is not installed.
        """
        self.base_retriever = base_retriever
        self.rerank_field = rerank_field
        self.rerank_batch_size = rerank_batch_size
        self.rerank_fn = rerank_fn
        self.rerank_model = None

        # A model is only relevant when no custom function overrides it.
        if rerank_fn is None and rerank_model is not None:
            if isinstance(rerank_model, str):
                self.rerank_model = self._load_cross_encoder(rerank_model)
            else:
                # A ready-made instance needs no import of its library here.
                self.rerank_model = rerank_model

    @staticmethod
    def _load_cross_encoder(model_name: str) -> Any:
        """Load a sentence-transformers ``CrossEncoder`` by name."""
        try:
            from sentence_transformers import CrossEncoder
        except ImportError:
            logger.warning("sentence-transformers ํŒจํ‚ค์ง€๊ฐ€ ์„ค์น˜๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. pip install sentence-transformers ๋ช…๋ น์œผ๋กœ ์„ค์น˜ํ•˜์„ธ์š”.")
            raise
        logger.info(f"์žฌ์ˆœ์œ„ํ™” ๋ชจ๋ธ '{model_name}' ๋กœ๋“œ ์ค‘...")
        return CrossEncoder(model_name)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """Add documents to the underlying base retriever.

        Args:
            documents: Documents to add.
        """
        self.base_retriever.add_documents(documents)

    def search(self, query: str, top_k: int = 5, first_stage_k: int = 30, **kwargs) -> List[Dict[str, Any]]:
        """Run the two-stage search: base retrieval followed by re-ranking.

        Args:
            query: Search query.
            top_k: Number of results to return after re-ranking.
            first_stage_k: Number of candidates requested from the base
                retriever.
            **kwargs: Extra arguments forwarded to the base retriever's
                ``search``.

        Returns:
            At most ``top_k`` results. With a custom function, the leading
            slice of whatever it returns; with a model, the scored documents
            sorted by descending ``rerank_score``; otherwise the first-stage
            results unchanged.
        """
        # First stage: fetch candidates from the base retriever.
        logger.info(f"๊ธฐ๋ณธ ๊ฒ€์ƒ‰๊ธฐ๋กœ {first_stage_k}๊ฐœ ๋ฌธ์„œ ๊ฒ€์ƒ‰ ์ค‘...")
        initial_results = self.base_retriever.search(query, top_k=first_stage_k, **kwargs)

        if not initial_results:
            logger.warning("์ฒซ ๋ฒˆ์งธ ๋‹จ๊ณ„ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
            return []

        if len(initial_results) < first_stage_k:
            logger.info(f"์š”์ฒญํ•œ {first_stage_k}๊ฐœ๋ณด๋‹ค ์ ์€ {len(initial_results)}๊ฐœ ๊ฒฐ๊ณผ๋ฅผ ๊ฒ€์ƒ‰ํ–ˆ์Šต๋‹ˆ๋‹ค.")

        # A custom re-ranking function takes precedence over the model.
        if self.rerank_fn is not None:
            logger.info("์‚ฌ์šฉ์ž ์ •์˜ ํ•จ์ˆ˜๋กœ ์žฌ์ˆœ์œ„ํ™” ์ค‘...")
            reranked_results = self.rerank_fn(query, initial_results)
            return reranked_results[:top_k]

        if self.rerank_model is not None:
            logger.info("CrossEncoder ๋ชจ๋ธ๋กœ ์žฌ์ˆœ์œ„ํ™” ์ค‘...")
            return self._rerank_with_model(query, initial_results)[:top_k]

        # Neither a function nor a model: return the first-stage ranking.
        logger.info("์žฌ์ˆœ์œ„ํ™” ๋ชจ๋ธ/ํ•จ์ˆ˜๊ฐ€ ์—†์–ด ์ดˆ๊ธฐ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ๊ทธ๋Œ€๋กœ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.")
        return initial_results[:top_k]

    def _rerank_with_model(self, query: str, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Score candidates with the model and sort them by descending score.

        Documents lacking ``rerank_field`` are logged and left out of the
        result. Each scored document gets a ``rerank_score`` key written in
        place.
        """
        # Keep the scored documents themselves so every score is written
        # back to the document it was computed for, even when some
        # candidates are skipped.
        scorable = []
        for doc in candidates:
            if self.rerank_field not in doc:
                logger.warning(f"๋ฌธ์„œ์— ํ•„๋“œ '{self.rerank_field}'๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
                continue
            scorable.append(doc)

        # Nothing to score: avoid calling the model with an empty batch.
        if not scorable:
            return []

        text_pairs = [[query, doc[self.rerank_field]] for doc in scorable]
        scores = self.rerank_model.predict(
            text_pairs,
            batch_size=self.rerank_batch_size,
            show_progress_bar=len(text_pairs) > 10
        )

        for doc, score in zip(scorable, scores):
            doc["rerank_score"] = float(score)

        # sorted() is stable, so ties keep their first-stage order.
        return sorted(scorable, key=lambda doc: doc["rerank_score"], reverse=True)