File size: 5,073 Bytes
2382288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Re-ranking retriever implementation module.
"""

import logging
from typing import List, Dict, Any, Optional, Union, Callable
from .base_retriever import BaseRetriever

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

class ReRanker(BaseRetriever):
    """
    Retriever that re-ranks the results of another retriever.

    Stage one asks ``base_retriever`` for a pool of candidate documents.
    Stage two re-orders that pool, either with a user-supplied function or
    with a CrossEncoder-style model, and returns the best ``top_k``.
    """

    def __init__(
        self,
        base_retriever: BaseRetriever,
        rerank_model: Optional[Union[str, Any]] = None,
        rerank_fn: Optional[Callable] = None,
        rerank_field: str = "text",
        rerank_batch_size: int = 32
    ):
        """
        Initialize the ReRanker.

        Args:
            base_retriever: Retriever used for the first (candidate) stage.
            rerank_model: Re-ranking model (Cross-Encoder). Either a model name,
                which is loaded with ``sentence_transformers.CrossEncoder``, or an
                already-constructed object exposing a compatible ``predict``.
                Ignored when ``rerank_fn`` is given.
            rerank_fn: Custom re-ranking function, called as
                ``rerank_fn(query, documents)``. Takes precedence over
                ``rerank_model``.
            rerank_field: Document key whose value is paired with the query
                for model scoring.
            rerank_batch_size: Batch size passed to the model's ``predict``.

        Raises:
            ImportError: If ``rerank_model`` is a string and
                ``sentence-transformers`` is not installed.
        """
        self.base_retriever = base_retriever
        self.rerank_field = rerank_field
        self.rerank_batch_size = rerank_batch_size
        self.rerank_fn = rerank_fn
        self.rerank_model = None

        # A custom function takes precedence, so a model is only kept when
        # no function was supplied.
        if rerank_fn is None and rerank_model is not None:
            if isinstance(rerank_model, str):
                self.rerank_model = self._load_cross_encoder(rerank_model)
            else:
                # A ready-made model instance does not need sentence-transformers.
                self.rerank_model = rerank_model

    @staticmethod
    def _load_cross_encoder(model_name: str) -> Any:
        """Import ``sentence_transformers.CrossEncoder`` lazily and load ``model_name``."""
        try:
            from sentence_transformers import CrossEncoder
        except ImportError:
            logger.warning("sentence-transformers 패키지가 설치되지 않았습니다. pip install sentence-transformers 명령으로 설치하세요.")
            raise
        logger.info(f"재순위화 모델 '{model_name}' 로드 중...")
        return CrossEncoder(model_name)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """
        Add documents to the underlying base retriever.

        Args:
            documents: Documents to add; forwarded unchanged.
        """
        self.base_retriever.add_documents(documents)

    def search(self, query: str, top_k: int = 5, first_stage_k: int = 30, **kwargs) -> List[Dict[str, Any]]:
        """
        Run a two-stage search: first-stage retrieval, then re-ranking.

        Args:
            query: Search query.
            top_k: Maximum number of results to return after re-ranking.
            first_stage_k: Number of candidates to request from the base retriever.
            **kwargs: Extra arguments forwarded unchanged to ``base_retriever.search``.

        Returns:
            At most ``top_k`` documents, chosen as follows:
            - An empty list if the first stage returns nothing.
            - With ``rerank_fn``: its output, truncated.
            - With a model: copies of the candidates that contain
              ``rerank_field``, each with an added ``"rerank_score"`` float,
              sorted by descending score.
            - Otherwise: the first-stage results in their original order.
        """
        logger.info(f"기본 검색기로 {first_stage_k}개 문서 검색 중...")
        initial_results = self.base_retriever.search(query, top_k=first_stage_k, **kwargs)

        if not initial_results:
            logger.warning("첫 번째 단계 검색 결과가 없습니다.")
            return []

        if len(initial_results) < first_stage_k:
            logger.info(f"요청한 {first_stage_k}개보다 적은 {len(initial_results)}개 결과를 검색했습니다.")

        if self.rerank_fn is not None:
            logger.info("사용자 정의 함수로 재순위화 중...")
            return self.rerank_fn(query, initial_results)[:top_k]

        if self.rerank_model is not None:
            logger.info("CrossEncoder 모델로 재순위화 중...")
            return self._rerank_with_model(query, initial_results)[:top_k]

        logger.info("재순위화 모델/함수가 없어 초기 검색 결과를 그대로 반환합니다.")
        return initial_results[:top_k]

    def _rerank_with_model(self, query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Score ``documents`` against ``query`` with ``self.rerank_model`` and sort them.

        Documents without ``self.rerank_field`` are logged and excluded. Each
        returned document is a shallow copy of its input with an added
        ``"rerank_score"`` float, so the caller's dicts are not modified.
        Equal scores keep their first-stage order because ``sorted`` is stable.
        """
        scorable = []
        for doc in documents:
            if self.rerank_field not in doc:
                logger.warning(f"문서에 필드 '{self.rerank_field}'가 없습니다.")
                continue
            scorable.append(doc)

        # Nothing to score: avoid calling the model with an empty batch.
        if not scorable:
            return []

        text_pairs = [[query, doc[self.rerank_field]] for doc in scorable]
        scores = self.rerank_model.predict(
            text_pairs,
            batch_size=self.rerank_batch_size,
            show_progress_bar=len(text_pairs) > 10,
        )

        # Pair each score with the exact document it was computed for. The
        # scores must not be aligned by position against the unfiltered list,
        # because skipped documents would shift every later score.
        scored = [
            {**doc, "rerank_score": float(score)}
            for doc, score in zip(scorable, scores)
        ]
        return sorted(scored, key=lambda d: d["rerank_score"], reverse=True)