Spaces:
No application file
No application file
"""
Re-ranking retrieval implementation module.
"""
| import logging | |
| from typing import List, Dict, Any, Optional, Union, Callable | |
| from .base_retriever import BaseRetriever | |
| logger = logging.getLogger(__name__) | |
class ReRanker(BaseRetriever):
    """
    Two-stage retriever that re-orders the results of a base retriever.

    The first stage delegates to ``base_retriever``. The second stage re-orders
    those candidates with, in order of precedence, a user-supplied
    ``rerank_fn``, a cross-encoder style ``rerank_model``, or nothing at all
    (the first-stage order is returned unchanged).
    """

    # NOTE(review): the log message literals below are reproduced exactly as
    # they appear in the file. They look mis-encoded (presumably Korean UTF-8
    # text decoded with the wrong codec) - confirm against the original file.

    def __init__(
        self,
        base_retriever: BaseRetriever,
        rerank_model: Optional[Union[str, Any]] = None,
        rerank_fn: Optional[Callable] = None,
        rerank_field: str = "text",
        rerank_batch_size: int = 32
    ):
        """
        Initialize the ReRanker.

        Args:
            base_retriever: Retriever used for the first-stage search and for
                document storage.
            rerank_model: Either a model name, which is loaded with
                ``sentence_transformers.CrossEncoder``, or an already built
                model object. The object is used as-is and only needs a
                ``predict(pairs, batch_size=..., show_progress_bar=...)``
                method. Ignored when ``rerank_fn`` is given.
            rerank_fn: Custom re-ranking function called as
                ``rerank_fn(query, results)``; takes precedence over
                ``rerank_model``.
            rerank_field: Key of the document dict whose value is paired with
                the query when scoring with ``rerank_model``.
            rerank_batch_size: Batch size passed to ``rerank_model.predict``.

        Raises:
            ImportError: If ``rerank_model`` is a model name and the
                sentence-transformers package is not installed.
        """
        self.base_retriever = base_retriever
        self.rerank_field = rerank_field
        self.rerank_batch_size = rerank_batch_size
        self.rerank_fn = rerank_fn
        self.rerank_model = None

        # A custom function wins over a model, so nothing needs to be loaded.
        if rerank_fn is not None or rerank_model is None:
            return

        # A ready-made model object is used directly; sentence-transformers is
        # only required when a model has to be loaded by name.
        if not isinstance(rerank_model, str):
            self.rerank_model = rerank_model
            return

        try:
            from sentence_transformers import CrossEncoder
        except ImportError:
            logger.warning("sentence-transformers ν¨ν€μ§κ° μ€μΉλμ§ μμμ΅λλ€. pip install sentence-transformers λͺ λ ΉμΌλ‘ μ€μΉνμΈμ.")
            raise

        logger.info(f"μ¬μμν λͺ¨λΈ '{rerank_model}' λ‘λ μ€...")
        self.rerank_model = CrossEncoder(rerank_model)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """
        Add documents to the underlying base retriever.

        Args:
            documents: Documents to add; forwarded unchanged to
                ``base_retriever.add_documents``.
        """
        self.base_retriever.add_documents(documents)

    def search(self, query: str, top_k: int = 5, first_stage_k: int = 30, **kwargs) -> List[Dict[str, Any]]:
        """
        Run the two-stage search: base retrieval followed by re-ranking.

        Args:
            query: Search query.
            top_k: Maximum number of results to return.
            first_stage_k: Number of candidates requested from the base
                retriever.
            **kwargs: Extra arguments forwarded to ``base_retriever.search``.

        Returns:
            At most ``top_k`` results. With ``rerank_fn`` this is a slice of
            whatever that function returns. With ``rerank_model`` the results
            are sorted by descending ``rerank_score``, and candidates that do
            not contain ``rerank_field`` are left out. With neither, the
            first-stage results are returned in their original order.
        """
        # Stage 1: fetch candidates from the base retriever.
        logger.info(f"κΈ°λ³Έ κ²μκΈ°λ‘ {first_stage_k}κ° λ¬Έμ κ²μ μ€...")
        initial_results = self.base_retriever.search(query, top_k=first_stage_k, **kwargs)

        if not initial_results:
            logger.warning("첫 λ²μ§Έ λ¨κ³ κ²μ κ²°κ³Όκ° μμ΅λλ€.")
            return []

        if len(initial_results) < first_stage_k:
            logger.info(f"μμ²ν {first_stage_k}κ°λ³΄λ€ μ μ {len(initial_results)}κ° κ²°κ³Όλ₯Ό κ²μνμ΅λλ€.")

        # Stage 2a: a user-supplied function has the highest precedence.
        if self.rerank_fn is not None:
            logger.info("μ¬μ©μ μ μ ν¨μλ‘ μ¬μμν μ€...")
            reranked_results = self.rerank_fn(query, initial_results)
            return reranked_results[:top_k]

        # Stage 2c: nothing to re-rank with, so keep the first-stage order.
        if self.rerank_model is None:
            logger.info("μ¬μμν λͺ¨λΈ/ν¨μκ° μμ΄ μ΄κΈ° κ²μ κ²°κ³Όλ₯Ό κ·Έλλ‘ λ°νν©λλ€.")
            return initial_results[:top_k]

        # Stage 2b: score the candidates with the model.
        logger.info("CrossEncoder λͺ¨λΈλ‘ μ¬μμν μ€...")
        return self._rerank_with_model(query, initial_results)[:top_k]

    def _rerank_with_model(self, query: str, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Score candidates with ``self.rerank_model`` and sort by that score.

        Each scored document dict gets a ``rerank_score`` key written in place.
        Documents without ``self.rerank_field`` cannot be scored and are
        excluded from the returned list.
        """
        # Track the scorable documents next to their text pairs so each score
        # is attached to the document it was computed for, even when some
        # candidates are skipped.
        scorable_docs = []
        text_pairs = []
        for doc in candidates:
            if self.rerank_field not in doc:
                logger.warning(f"λ¬Έμμ νλ '{self.rerank_field}'κ° μμ΅λλ€.")
                continue
            scorable_docs.append(doc)
            text_pairs.append([query, doc[self.rerank_field]])

        # Nothing to score: skip the model call instead of passing it an
        # empty batch.
        if not text_pairs:
            return []

        scores = self.rerank_model.predict(
            text_pairs,
            batch_size=self.rerank_batch_size,
            show_progress_bar=len(text_pairs) > 10
        )

        for doc, score in zip(scorable_docs, scores):
            doc["rerank_score"] = float(score)

        # sorted() is stable, so equally scored documents keep their
        # first-stage order.
        return sorted(
            scorable_docs,
            key=lambda doc: doc["rerank_score"],
            reverse=True
        )