samlax12's picture
Upload 25 files
20f7a0a verified
raw
history blame
3.94 kB
import os
import time
from typing import Dict, List, Optional, Tuple

import requests
from dotenv import load_dotenv
from elasticsearch import Elasticsearch
load_dotenv()
class Retriever:
    """Retrieve documents from Elasticsearch using text matching plus vector scoring.

    The Elasticsearch client is created in ``__init__`` with credentials read
    from the ``PASSWORD`` environment variable; the embedding endpoint is
    configured through ``API_KEY`` and ``BASE_URL``.
    """

    # Endpoint and naming conventions, kept as class attributes so they can be
    # overridden in a subclass without touching the method bodies.
    ES_URL = "https://samlax12-elastic.hf.space"
    EMBEDDING_MODEL = "BAAI/bge-m3"
    INDEX_PREFIX = "rag_"
    # Seconds to wait for the embedding API; without a timeout ``requests``
    # would block indefinitely on a stalled connection.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        """Create the Elasticsearch client and read the embedding API settings."""
        # Uses the same ES configuration as vector_store.py.
        self.es = Elasticsearch(
            self.ES_URL,  # note: https
            basic_auth=("elastic", os.getenv("PASSWORD")),  # same password as vector_store.py
            # NOTE(review): certificate verification is disabled; the original
            # comment says this is acceptable for development only. Confirm
            # before using this against a production cluster.
            verify_certs=False,
        )
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for ``text`` from the embeddings API.

        The original comment identifies the service as SiliconFlow. The request
        is POSTed to ``{BASE_URL}/embeddings`` with the ``BAAI/bge-m3`` model.

        Raises:
            Exception: if the API answers with a status code other than 200.
            requests.exceptions.RequestException: on network failure or when
                the request exceeds ``REQUEST_TIMEOUT`` seconds.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        response = requests.post(
            f"{self.api_base}/embeddings",
            headers=headers,
            json={
                "model": self.EMBEDDING_MODEL,
                "input": text,
            },
            timeout=self.REQUEST_TIMEOUT,
        )
        if response.status_code != 200:
            raise Exception(f"Error getting embedding: {response.text}")
        return response.json()["data"][0]["embedding"]

    def get_all_indices(self) -> List[str]:
        """Return the names of all indices whose name starts with ``rag_``."""
        aliases = self.es.indices.get_alias()
        return [name for name in aliases.keys() if name.startswith(self.INDEX_PREFIX)]

    def retrieve(
        self,
        query: str,
        top_k: int = 10,
        specific_index: Optional[str] = None,
    ) -> Tuple[List[Dict], str]:
        """Search one or all RAG indices and return the best-scoring documents.

        Each index is queried with a ``script_score`` query: a ``match`` on the
        ``content`` field selects the candidate documents, and the script sets
        every candidate's score to ``cosineSimilarity(query_vector, vector) + 1.0``.
        The script does not reference ``_score``, so the text match acts as a
        filter and the final ranking is by vector similarity.

        Args:
            query: The user query text.
            top_k: Maximum number of hits requested per index and returned overall.
            specific_index: When given, only this index is searched (if it
                exists); otherwise every ``rag_``-prefixed index is searched.

        Returns:
            A tuple ``(results, index_name)``. ``results`` is a list of dicts
            with keys ``id``, ``content``, ``score``, ``metadata`` and
            ``index``, sorted by descending score. ``index_name`` is the index
            of the top hit, or the first searched index when nothing matched.

        Raises:
            Exception: if no usable index is found, or if the embedding
                request fails (see ``get_embedding``).
        """
        # Decide which indices to search.
        if specific_index:
            indices = [specific_index] if self.es.indices.exists(index=specific_index) else []
        else:
            indices = self.get_all_indices()
        if not indices:
            raise Exception("没有找到可用的文档索引!")

        # Embed the query once; the same vector is used for every index.
        query_vector = self.get_embedding(query)

        # The request body does not depend on the index, so build it once
        # instead of on every loop iteration.
        search_body = {
            "query": {
                "script_score": {
                    "query": {
                        "match": {
                            "content": query  # text (BM25) match selects candidates
                        }
                    },
                    "script": {
                        "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
                        "params": {"query_vector": query_vector},
                    },
                }
            },
            "size": top_k,
        }

        # Query every selected index and collect the hits in a flat list.
        all_results = []
        for index in indices:
            response = self.es.search(index=index, body=search_body)
            for hit in response['hits']['hits']:
                source = hit['_source']
                all_results.append({
                    'id': hit['_id'],
                    'content': source['content'],
                    'score': hit['_score'],
                    # Documents indexed without metadata must not abort the
                    # whole retrieval; fall back to an empty mapping.
                    'metadata': source.get('metadata', {}),
                    'index': index,
                })

        # Keep the overall top_k across all indices, best first.
        all_results.sort(key=lambda item: item['score'], reverse=True)
        top_results = all_results[:top_k]

        # Report the index of the best hit; `indices` is guaranteed non-empty
        # here, so its first entry is a safe fallback when nothing matched.
        most_relevant_index = top_results[0]['index'] if top_results else indices[0]
        return top_results, most_relevant_index