|
import os
import time
from typing import Dict, List, Optional, Tuple

import requests
from dotenv import load_dotenv
from elasticsearch import Elasticsearch
|
|
|
# Load environment variables (PASSWORD, API_KEY, BASE_URL) from a local .env file
# so the Elasticsearch and embedding-API credentials below resolve.
load_dotenv()
|
|
|
class Retriever:
    """Retrieve relevant documents from Elasticsearch for a RAG pipeline.

    Candidate documents are recalled with a BM25 ``match`` query and then
    ranked by dense-vector cosine similarity (``script_score``), using
    embeddings from the SiliconFlow ``BAAI/bge-m3`` model.

    Environment variables (loaded via dotenv at module import):
        PASSWORD: Elasticsearch ``elastic`` user password.
        API_KEY:  SiliconFlow API key.
        BASE_URL: SiliconFlow API base URL.
    """

    def __init__(self):
        # NOTE(review): verify_certs=False disables TLS certificate
        # verification for the hosted demo endpoint — confirm this is
        # acceptable before reusing against another cluster.
        self.es = Elasticsearch(
            "https://samlax12-elastic.hf.space",
            basic_auth=("elastic", os.getenv("PASSWORD")),
            verify_certs=False
        )
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def get_embedding(self, text: str) -> List[float]:
        """Call the SiliconFlow embedding API and return the text's vector.

        Args:
            text: The text to embed.

        Returns:
            The embedding as a list of floats.

        Raises:
            Exception: If the API returns a non-200 status code.
            requests.exceptions.RequestException: On network failure/timeout.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        response = requests.post(
            f"{self.api_base}/embeddings",
            headers=headers,
            json={
                "model": "BAAI/bge-m3",
                "input": text
            },
            # Fail fast instead of hanging indefinitely if the API stalls.
            timeout=30
        )

        if response.status_code == 200:
            return response.json()["data"][0]["embedding"]
        raise Exception(f"Error getting embedding: {response.text}")

    def get_all_indices(self) -> List[str]:
        """Return all RAG-related index names (those prefixed with 'rag_')."""
        indices = self.es.indices.get_alias().keys()
        return [idx for idx in indices if idx.startswith('rag_')]

    def retrieve(self, query: str, top_k: int = 10, specific_index: Optional[str] = None) -> Tuple[List[Dict], str]:
        """Hybrid retrieval: BM25 recall re-ranked by vector similarity.

        Searches one index (if ``specific_index`` exists) or every ``rag_*``
        index, and merges results by score across indices.

        Args:
            query: The user's search query.
            top_k: Maximum number of results to return (also the per-index
                fetch size).
            specific_index: Optional index name to restrict the search to.

        Returns:
            A tuple of (top results, name of the most relevant index). Each
            result dict has keys: 'id', 'content', 'score', 'metadata',
            'index'.

        Raises:
            Exception: If no usable index is found.
        """
        if specific_index:
            indices = [specific_index] if self.es.indices.exists(index=specific_index) else []
        else:
            indices = self.get_all_indices()

        if not indices:
            raise Exception("没有找到可用的文档索引!")

        query_vector = self.get_embedding(query)

        all_results = []
        for index in indices:
            # The BM25 match query recalls candidates; script_score then
            # REPLACES their score with cosine similarity (+1.0 keeps it
            # non-negative, as required by Elasticsearch).
            script_query = {
                "script_score": {
                    "query": {
                        "match": {
                            "content": query
                        }
                    },
                    "script": {
                        "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
                        "params": {"query_vector": query_vector}
                    }
                }
            }

            response = self.es.search(
                index=index,
                body={
                    "query": script_query,
                    "size": top_k
                }
            )

            for hit in response['hits']['hits']:
                result = {
                    'id': hit['_id'],
                    'content': hit['_source']['content'],
                    'score': hit['_score'],
                    # Tolerate documents indexed without a metadata field
                    # instead of raising KeyError.
                    'metadata': hit['_source'].get('metadata', {}),
                    'index': index
                }
                all_results.append(result)

        # Merge hits from all indices into a single score-ordered ranking.
        all_results.sort(key=lambda x: x['score'], reverse=True)
        top_results = all_results[:top_k]

        if top_results:
            most_relevant_index = top_results[0]['index']
        else:
            most_relevant_index = indices[0] if indices else ""

        return top_results, most_relevant_index