"""
์ฌ์์ํ ๊ฒ์ ๊ตฌํ ๋ชจ๋
"""
import logging
from typing import List, Dict, Any, Optional, Union, Callable
from .base_retriever import BaseRetriever
logger = logging.getLogger(__name__)


class ReRanker(BaseRetriever):
    """
    Retriever that re-ranks the results of a base retriever.
    """

    def __init__(
        self,
        base_retriever: BaseRetriever,
        rerank_model: Optional[Union[str, Any]] = None,
        rerank_fn: Optional[Callable] = None,
        rerank_field: str = "text",
        rerank_batch_size: int = 32
    ):
        """
        Initialize the ReRanker.

        Args:
            base_retriever: Base retriever instance.
            rerank_model: Re-ranking (Cross-Encoder) model name or instance.
            rerank_fn: Custom re-ranking function (used instead of rerank_model when provided).
            rerank_field: Document field used for re-ranking.
            rerank_batch_size: Batch size for the re-ranking model.
        """
        self.base_retriever = base_retriever
        self.rerank_field = rerank_field
        self.rerank_batch_size = rerank_batch_size
        self.rerank_fn = rerank_fn

        # Load the re-ranking model (only when no custom function is provided)
        if rerank_fn is None and rerank_model is not None:
            try:
                from sentence_transformers import CrossEncoder

                if isinstance(rerank_model, str):
                    logger.info(f"Loading re-ranking model '{rerank_model}'...")
                    self.rerank_model = CrossEncoder(rerank_model)
                else:
                    self.rerank_model = rerank_model
            except ImportError:
                logger.warning(
                    "The sentence-transformers package is not installed. "
                    "Install it with: pip install sentence-transformers"
                )
                raise
        else:
            self.rerank_model = None

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """
        Add documents to the base retriever.

        Args:
            documents: Documents to add.
        """
        self.base_retriever.add_documents(documents)

    def search(self, query: str, top_k: int = 5, first_stage_k: int = 30, **kwargs) -> List[Dict[str, Any]]:
        """
        Run two-stage retrieval: base search followed by re-ranking.

        Args:
            query: Search query.
            top_k: Number of top results to return after re-ranking.
            first_stage_k: Number of results to retrieve in the first stage.
            **kwargs: Additional search parameters.

        Returns:
            Re-ranked list of search results.
        """
        # First stage: retrieve first_stage_k documents with the base retriever
        logger.info(f"Retrieving {first_stage_k} documents with the base retriever...")
        initial_results = self.base_retriever.search(query, top_k=first_stage_k, **kwargs)

        if not initial_results:
            logger.warning("The first-stage search returned no results.")
            return []
        if len(initial_results) < first_stage_k:
            logger.info(f"Retrieved only {len(initial_results)} results, fewer than the requested {first_stage_k}.")

        # A custom re-ranking function takes priority when provided
        if self.rerank_fn is not None:
            logger.info("Re-ranking with the custom function...")
            reranked_results = self.rerank_fn(query, initial_results)
            return reranked_results[:top_k]

        # Otherwise use the loaded re-ranking model
        elif self.rerank_model is not None:
            logger.info("Re-ranking with the CrossEncoder model...")

            # Build (query, document text) pairs; skip documents missing the field
            text_pairs = []
            docs_to_rerank = []
            for doc in initial_results:
                if self.rerank_field not in doc:
                    logger.warning(f"Document is missing the field '{self.rerank_field}'.")
                    continue
                text_pairs.append([query, doc[self.rerank_field]])
                docs_to_rerank.append(doc)

            # Score the pairs with the model
            scores = self.rerank_model.predict(
                text_pairs,
                batch_size=self.rerank_batch_size,
                show_progress_bar=len(text_pairs) > 10
            )

            # Attach scores to the scored documents and re-sort by descending score
            for doc, score in zip(docs_to_rerank, scores):
                doc["rerank_score"] = float(score)
            reranked_results = sorted(
                docs_to_rerank,
                key=lambda x: x.get("rerank_score", 0),
                reverse=True
            )
            return reranked_results[:top_k]

        # No re-ranking model or function: return the initial results as-is
        else:
            logger.info("No re-ranking model/function available; returning the initial search results.")
            return initial_results[:top_k]
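

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API). It assumes
# BaseRetriever requires only add_documents() and search(); the in-memory
# KeywordOverlapRetriever and the overlap_rerank function below are
# hypothetical stand-ins that show the expected wiring and the rerank_fn
# signature: (query, results) -> reordered results.
if __name__ == "__main__":

    class KeywordOverlapRetriever(BaseRetriever):
        """Toy first-stage retriever that scores documents by keyword overlap."""

        def __init__(self):
            self.docs: List[Dict[str, Any]] = []

        def add_documents(self, documents: List[Dict[str, Any]]) -> None:
            self.docs.extend(documents)

        def search(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
            terms = set(query.lower().split())
            scored = [
                {**doc, "score": len(terms & set(doc.get("text", "").lower().split()))}
                for doc in self.docs
            ]
            return sorted(scored, key=lambda d: d["score"], reverse=True)[:top_k]

    def overlap_rerank(query: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Example custom rerank_fn: reorder results by keyword overlap with the query."""
        terms = set(query.lower().split())
        return sorted(
            results,
            key=lambda d: len(terms & set(d.get("text", "").lower().split())),
            reverse=True,
        )

    reranker = ReRanker(base_retriever=KeywordOverlapRetriever(), rerank_fn=overlap_rerank)
    reranker.add_documents([
        {"text": "cross-encoder re-ranking improves retrieval precision"},
        {"text": "two-stage retrieval first fetches candidates, then re-ranks them"},
    ])
    print(reranker.search("two-stage re-ranking retrieval", top_k=2, first_stage_k=10))

    # With sentence-transformers installed, a Cross-Encoder checkpoint such as
    # "cross-encoder/ms-marco-MiniLM-L-6-v2" could be passed via rerank_model
    # instead of rerank_fn.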