Spaces:
Runtime error
Runtime error
File size: 3,595 Bytes
9c8703c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
from abc import ABC, abstractmethod
import torch
from PIL.Image import Image
from typing import Dict, List, Any, Optional, TypedDict, Any
from dataclasses import dataclass, asdict
class Query(TypedDict):
    """Dictionary shape used to represent a query.

    Keys:
        content: The query payload as a string.
        metadata: String-keyed mapping of additional values of any type.
    """

    content: str
    metadata: dict[str, Any]
@dataclass
class MergedResult:
    """Result after merging dense and sparse results.

    Attributes:
        id: Identifier of the result.
        content: Content of the result as a string.
        metadata: String-keyed mapping of additional values of any type.
        sources: Names of the sources this result came from,
            e.g. ``['dense', 'sparse']``.
        dense_score: Dense score, or ``None`` when not set.
        sparse_score: Sparse score, or ``None`` when not set.
        final_score: Final score; defaults to ``0.0``.
    """

    id: str
    content: str
    # Builtin generics are used here to match the other annotations in
    # this module (e.g. ``dict[str, Any]`` on ``Query``).
    metadata: dict[str, Any]
    sources: list[str]  # records the origin(s): ['dense', 'sparse']
    dense_score: Optional[float] = None
    sparse_score: Optional[float] = None
    final_score: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict keyed by field name.

        Delegates to ``dataclasses.asdict``, which recursively copies
        nested dicts and lists, so the returned containers are independent
        of the ones held by this instance.
        """
        return asdict(self)
# INTERFACES
class BaseEmbeddingModel(ABC):
    """Abstract interface for embedding models.

    Subclasses must implement ``encode_text``. ``encode_image`` is optional:
    unless overridden it raises ``NotImplementedError``.
    """

    def __init__(self, config: Optional[dict[str, Any]] = None):
        """Store the configuration and resolve the device setting.

        Args:
            config: Optional settings mapping. A falsy value (``None`` or an
                empty dict) is replaced by a new empty dict. If it has a
                ``"device"`` entry, that value is stored as ``self.device``
                as-is; otherwise ``"cuda"`` is used when
                ``torch.cuda.is_available()`` is true and ``"cpu"`` when not.
        """
        self.config = config if config else {}
        # Probe CUDA up front to obtain the fallback for a missing "device".
        fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = self.config.get("device", fallback_device)

    @abstractmethod
    def encode_text(self, texts: list[str]):
        """Generate embeddings for the given texts.

        Args:
            texts: List of strings to embed.
        """

    # Optional hook: only models with image support need to override it.
    def encode_image(self, images: list[str] | list[Image]):
        """Generate embeddings for the given images.

        Args:
            images: Either a list of strings or a list of ``PIL.Image.Image``
                objects (what the strings denote is not defined here).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("This model does not support image embeddings.")
class BaseComponent(ABC):
    """Abstract parent class shared by all pipeline components."""

    def __init__(self, config: Optional[dict] = None):
        """Keep the configuration on the instance.

        Args:
            config: Optional settings mapping. A falsy value (``None`` or an
                empty dict) is replaced by a new empty dict.
        """
        self.config = config if config else {}

    @abstractmethod
    def process(self, *args, **kwargs):
        """Run the component; subclasses must provide the implementation."""
class QueryRewriter(BaseComponent):
    """Abstract base for components that rewrite a query."""

    def __init__(self, config: Optional[dict] = None):
        """Pass ``config`` through unchanged to ``BaseComponent``."""
        super().__init__(config=config)

    @abstractmethod
    def process(self, query: Query) -> list[Query]:
        """Rewrite the query.

        Args:
            query: The query to rewrite.

        Returns:
            A list of ``Query`` items, as declared by the return annotation.
        """
class Retriever(BaseComponent):
    """Abstract base for components that retrieve documents."""

    def __init__(self, config: Optional[dict] = None):
        """Pass ``config`` through unchanged to ``BaseComponent``."""
        super().__init__(config=config)

    @abstractmethod
    def process(self, query: list[Query], **kwargs) -> list[MergedResult]:
        """Retrieve documents based on the query.

        Args:
            query: Despite the singular name, annotated as a list of
                ``Query`` items.
            **kwargs: Extra keyword options; their meaning is left to the
                implementing subclass.

        Returns:
            A list of ``MergedResult`` items, as declared by the annotation.
        """
class Reranker(BaseComponent):
    """Abstract base for components that rerank retrieved documents."""

    def __init__(self, config: Optional[dict] = None):
        """Pass ``config`` through unchanged to ``BaseComponent``."""
        super().__init__(config=config)

    @abstractmethod
    def process(
        self, query: Query, documents: list[MergedResult]
    ) -> list[MergedResult]:
        """Rerank the retrieved documents based on the query.

        Args:
            query: The query the documents are ranked against.
            documents: The retrieved documents to rerank.

        Returns:
            A list of ``MergedResult`` items, as declared by the annotation.
        """
class PromptBuilder(BaseComponent):
    """Abstract base for components that build a prompt."""

    def __init__(self, config: Optional[dict] = None):
        """Pass ``config`` through unchanged to ``BaseComponent``."""
        super().__init__(config=config)

    @abstractmethod
    def process(
        self,
        query: Query,
        documents: list[MergedResult],
        conversations: Optional[list[dict]] = None,
    ) -> str:
        """Build a prompt based on the query and documents.

        Args:
            query: The query the prompt is built for.
            documents: The documents to build the prompt from.
            conversations: Optional list of dicts, defaulting to ``None``;
                their schema is not defined here.

        Returns:
            The prompt as a string.
        """
class Generator(BaseComponent):
    """Abstract base for components that generate a response."""

    def __init__(self, config: Optional[dict] = None):
        """Pass ``config`` through unchanged to ``BaseComponent``."""
        super().__init__(config=config)

    @abstractmethod
    def process(self, prompt: str) -> str:
        """Generate a response based on the prompt.

        Args:
            prompt: The prompt text.

        Returns:
            The response as a string.
        """
class Speaker(BaseComponent):
    """Abstract base for components that convert text to speech."""

    def __init__(self, config: Optional[dict] = None):
        """Pass ``config`` through unchanged to ``BaseComponent``."""
        super().__init__(config=config)

    @abstractmethod
    def process(self, text: str) -> None:
        """Convert text to speech.

        Args:
            text: The text to convert.

        Returns:
            ``None``, as declared by the return annotation.
        """
|