|
from typing import List, Dict
|
|
import requests
|
|
import time
|
|
from dotenv import load_dotenv
|
|
import os
|
|
|
|
# Runs at import time: python-dotenv loads variables from a .env file (if one
# is found) into os.environ, so that Reranker can read API_KEY and BASE_URL
# through os.getenv.
load_dotenv()
|
|
|
|
class Reranker:
    """Thin client for a ``/rerank`` HTTP endpoint (written against SiliconFlow).

    Connection settings default to the ``API_KEY`` and ``BASE_URL``
    environment variables, which are read once at construction time.
    """

    # Model name sent with every request unless overridden in ``__init__``.
    DEFAULT_MODEL = "BAAI/bge-reranker-v2-m3"

    def __init__(
        self,
        api_key: "str | None" = None,
        api_base: "str | None" = None,
        model: str = DEFAULT_MODEL,
        timeout: "float | None" = 60.0,
    ):
        """Configure the client.

        Args:
            api_key: Bearer token. Falls back to the ``API_KEY`` env var.
            api_base: Base URL that ``/rerank`` is appended to. Falls back to
                the ``BASE_URL`` env var.
            model: Name of the rerank model sent with each request.
            timeout: Seconds passed to ``requests.post(timeout=...)``.
                ``None`` waits indefinitely.
        """
        self.api_key = api_key if api_key is not None else os.getenv("API_KEY")
        self.api_base = api_base if api_base is not None else os.getenv("BASE_URL")
        self.model = model
        self.timeout = timeout

    def rerank(self, query: str, documents: List[Dict], index_name: str, top_k: int = 5) -> List[Dict]:
        """Rerank documents against a query using the SiliconFlow rerank API.

        Args:
            query: Text the documents are scored against.
            documents: Dicts, each with a ``'content'`` key whose value is
                sent to the API for scoring.
            index_name: Copied into every returned document under
                ``'index_name'``.
            top_k: Maximum number of documents requested (sent as ``top_n``).

        Returns:
            Shallow copies of the documents the API selected, in the order the
            API returned them. Each copy gains a ``'rerank_score'`` key
            (the API's ``relevance_score``) and an ``'index_name'`` key.
            An empty ``documents`` list returns ``[]`` without a request.

        Raises:
            ValueError: If no base URL is configured.
            Exception: If the API answers with a non-200 status.
            requests.RequestException: On network errors or timeout.
        """
        if not documents:
            # Nothing to rank; skip the network round-trip entirely.
            return []
        if not self.api_base:
            # Fail clearly instead of requesting the bogus URL "None/rerank".
            raise ValueError("No rerank base URL configured (set BASE_URL or pass api_base)")

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "query": query,
            "documents": [doc['content'] for doc in documents],
            "top_n": top_k,
        }
        # Tolerate a trailing slash in the base URL so we never hit ".../v1//rerank".
        url = f"{self.api_base.rstrip('/')}/rerank"
        # Without a timeout, requests can block forever on an unresponsive server.
        response = requests.post(url, headers=headers, json=payload, timeout=self.timeout)

        if response.status_code != 200:
            raise Exception(f"Error in reranking: {response.text}")

        reranked_docs = []
        for result in response.json()["results"]:
            # Each result points back into the input list by position;
            # copy so the caller's dicts are not mutated.
            ranked = documents[result["index"]].copy()
            ranked['rerank_score'] = result["relevance_score"]
            ranked['index_name'] = index_name
            reranked_docs.append(ranked)
        return reranked_docs