""" 재순위화 검색 구현 모듈 """ import logging from typing import List, Dict, Any, Optional, Union, Callable from .base_retriever import BaseRetriever logger = logging.getLogger(__name__) class ReRanker(BaseRetriever): """ 검색 결과 재순위화 검색기 """ def __init__( self, base_retriever: BaseRetriever, rerank_model: Optional[Union[str, Any]] = None, rerank_fn: Optional[Callable] = None, rerank_field: str = "text", rerank_batch_size: int = 32 ): """ ReRanker 초기화 Args: base_retriever: 기본 검색기 인스턴스 rerank_model: 재순위화 모델 (Cross-Encoder) 이름 또는 인스턴스 rerank_fn: 사용자 정의 재순위화 함수 (제공된 경우 rerank_model 대신 사용) rerank_field: 재순위화에 사용할 문서 필드 rerank_batch_size: 재순위화 모델 배치 크기 """ self.base_retriever = base_retriever self.rerank_field = rerank_field self.rerank_batch_size = rerank_batch_size self.rerank_fn = rerank_fn # 재순위화 모델 로드 (사용자 정의 함수가 제공되지 않은 경우) if rerank_fn is None and rerank_model is not None: try: from sentence_transformers import CrossEncoder if isinstance(rerank_model, str): logger.info(f"재순위화 모델 '{rerank_model}' 로드 중...") self.rerank_model = CrossEncoder(rerank_model) else: self.rerank_model = rerank_model except ImportError: logger.warning("sentence-transformers 패키지가 설치되지 않았습니다. pip install sentence-transformers 명령으로 설치하세요.") raise else: self.rerank_model = None def add_documents(self, documents: List[Dict[str, Any]]) -> None: """ 기본 검색기에 문서 추가 Args: documents: 추가할 문서 목록 """ self.base_retriever.add_documents(documents) def search(self, query: str, top_k: int = 5, first_stage_k: int = 30, **kwargs) -> List[Dict[str, Any]]: """ 2단계 검색 수행: 기본 검색 + 재순위화 Args: query: 검색 쿼리 top_k: 최종적으로 반환할 상위 결과 수 first_stage_k: 첫 번째 단계에서 검색할 결과 수 **kwargs: 추가 검색 매개변수 Returns: 재순위화된 검색 결과 목록 """ # 첫 번째 단계: 기본 검색기로 more_k 문서 검색 logger.info(f"기본 검색기로 {first_stage_k}개 문서 검색 중...") initial_results = self.base_retriever.search(query, top_k=first_stage_k, **kwargs) if not initial_results: logger.warning("첫 번째 단계 검색 결과가 없습니다.") return [] if len(initial_results) < first_stage_k: logger.info(f"요청한 {first_stage_k}개보다 적은 {len(initial_results)}개 결과를 검색했습니다.") # 사용자 정의 재순위화 함수가 제공된 경우 if self.rerank_fn is not None: logger.info("사용자 정의 함수로 재순위화 중...") reranked_results = self.rerank_fn(query, initial_results) return reranked_results[:top_k] # 재순위화 모델이 로드된 경우 elif self.rerank_model is not None: logger.info(f"CrossEncoder 모델로 재순위화 중...") # 텍스트 쌍 생성 text_pairs = [] for doc in initial_results: if self.rerank_field not in doc: logger.warning(f"문서에 필드 '{self.rerank_field}'가 없습니다.") continue text_pairs.append([query, doc[self.rerank_field]]) # 모델로 점수 계산 scores = self.rerank_model.predict( text_pairs, batch_size=self.rerank_batch_size, show_progress_bar=True if len(text_pairs) > 10 else False ) # 결과 재정렬 for idx, doc in enumerate(initial_results[:len(scores)]): doc["rerank_score"] = float(scores[idx]) reranked_results = sorted( initial_results[:len(scores)], key=lambda x: x.get("rerank_score", 0), reverse=True ) return reranked_results[:top_k] # 재순위화 없이 초기 결과 반환 else: logger.info("재순위화 모델/함수가 없어 초기 검색 결과를 그대로 반환합니다.") return initial_results[:top_k]