File size: 3,595 Bytes
9c8703c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from abc import ABC, abstractmethod

import torch
from PIL.Image import Image
from typing import Dict, List, Any, Optional, TypedDict, Any
from dataclasses import dataclass, asdict


class Query(TypedDict):
    """Dictionary shape of a single query handled by the pipeline components.

    Keys:
        content: The query text.
        metadata: String-keyed values of any type attached to the query.
    """

    content: str
    metadata: dict[str, Any]


@dataclass
class MergedResult:
    """A single result produced by merging dense and sparse retrieval results.

    Attributes:
        id: Identifier of the result.
        content: Text content of the result.
        metadata: String-keyed values of any type attached to the result.
        sources: Names of the sources the result came from.
        dense_score: Dense score, or ``None`` when not set.
        sparse_score: Sparse score, or ``None`` when not set.
        final_score: Final score; defaults to ``0.0``.
    """

    id: str
    content: str
    metadata: Dict[str, Any]
    sources: List[str]  # records the origin, e.g. ['dense', 'sparse']
    dense_score: Optional[float] = None
    sparse_score: Optional[float] = None
    final_score: float = 0.0

    def to_dict(self) -> dict:
        """Return every field of this result as a plain ``dict``.

        The conversion is delegated to ``dataclasses.asdict``, so the returned
        containers are copies rather than references to this instance's own.
        """
        as_mapping = asdict(self)
        return as_mapping


# INTERFACES


class BaseEmbeddingModel(ABC):
    """Base class for embedding models"""

    def __init__(self, config: Optional[dict[str, Any]] = None):
        self.config = config or {}
        self.device = self.config.get(
            "device", "cuda" if torch.cuda.is_available() else "cpu"
        )

    @abstractmethod
    def encode_text(self, texts: list[str]):
        """Generate embeddings for the given text"""
        ...

    # optional method for image embeddings
    def encode_image(self, images: list[str] | list[Image]):
        """Generate embeddings for the given images"""
        raise NotImplementedError("This model does not support image embeddings.")


class BaseComponent(ABC):
    """Abstract root of every pipeline component.

    A component keeps its configuration in ``self.config`` and exposes its
    work through ``process``, which each subclass must implement.
    """

    def __init__(self, config: Optional[dict] = None):
        """Store the component configuration.

        Args:
            config: Optional settings. A falsy value (``None`` or an empty
                dict) is replaced by a new empty dict.
        """
        self.config = config if config else {}

    @abstractmethod
    def process(self, *args, **kwargs):
        """Run the component; subclasses define the arguments and the result."""


class QueryRewriter(BaseComponent):
    """Abstract interface for components that rewrite a query."""

    def __init__(self, config: Optional[dict] = None):
        """Pass ``config`` through unchanged to ``BaseComponent``."""
        super().__init__(config)

    @abstractmethod
    def process(self, query: Query) -> list[Query]:
        """Rewrite ``query`` and return the resulting list of queries.

        Args:
            query: The query to rewrite.

        Returns:
            A list of ``Query`` items.
        """


class Retriever(BaseComponent):
    """Abstract interface for components that retrieve documents."""

    def __init__(self, config: Optional[dict] = None):
        """Pass ``config`` through unchanged to ``BaseComponent``."""
        super().__init__(config)

    @abstractmethod
    def process(self, query: list[Query], **kwargs) -> list[MergedResult]:
        """Retrieve documents for the given queries.

        Args:
            query: A list of ``Query`` items (note: a list, despite the
                singular parameter name).
            **kwargs: Extra options defined by the concrete retriever.

        Returns:
            A list of ``MergedResult`` items.
        """


class Reranker(BaseComponent):
    """Abstract interface for components that rerank retrieved documents."""

    def __init__(self, config: Optional[dict] = None):
        """Pass ``config`` through unchanged to ``BaseComponent``."""
        super().__init__(config)

    @abstractmethod
    def process(
        self,
        query: Query,
        documents: list[MergedResult],
    ) -> list[MergedResult]:
        """Rerank ``documents`` with respect to ``query``.

        Args:
            query: The query the documents were retrieved for.
            documents: The retrieved documents to rerank.

        Returns:
            A list of ``MergedResult`` items.
        """


class PromptBuilder(BaseComponent):
    """Abstract interface for components that build a prompt."""

    def __init__(self, config: Optional[dict] = None):
        """Pass ``config`` through unchanged to ``BaseComponent``."""
        super().__init__(config)

    @abstractmethod
    def process(
        self,
        query: Query,
        documents: list[MergedResult],
        conversations: Optional[list[dict]] = None,
    ) -> str:
        """Build a prompt from the query and the documents.

        Args:
            query: The query to build the prompt for.
            documents: The documents to build the prompt from.
            conversations: Optional list of dicts; defaults to ``None``.

        Returns:
            The prompt as a string.
        """


class Generator(BaseComponent):
    """Abstract interface for components that generate a response."""

    def __init__(self, config: Optional[dict] = None):
        """Pass ``config`` through unchanged to ``BaseComponent``."""
        super().__init__(config)

    @abstractmethod
    def process(self, prompt: str) -> str:
        """Generate a response for ``prompt``.

        Args:
            prompt: The prompt text.

        Returns:
            The generated response as a string.
        """


class Speaker(BaseComponent):
    """Abstract interface for components that convert text to speech."""

    def __init__(self, config: Optional[dict] = None):
        """Pass ``config`` through unchanged to ``BaseComponent``."""
        super().__init__(config)

    @abstractmethod
    def process(self, text: str) -> None:
        """Convert ``text`` to speech.

        Args:
            text: The text to convert.
        """