import os
import uuid
from typing import Dict, List

import numpy as np
import requests
import urllib3
from dotenv import load_dotenv
from elasticsearch import Elasticsearch

# Load variables from a local .env file into the process environment before
# VectorStore reads PASSWORD, API_KEY and BASE_URL through os.getenv.
load_dotenv()

# The Elasticsearch client below is created with verify_certs=False; silence
# the InsecureRequestWarning that urllib3 emits for unverified HTTPS requests.
# NOTE(review): this is process-wide and also hides the warning for any other
# unverified request made by this process.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


class VectorStore:
    """Store text chunks and their embedding vectors in Elasticsearch.

    Credentials and endpoints come from the environment: ``PASSWORD`` for the
    ``elastic`` user, ``API_KEY`` and ``BASE_URL`` for the embeddings API.
    """

    # Model name sent to the embeddings endpoint.
    EMBEDDING_MODEL = "BAAI/bge-m3"
    # Size of the ``dense_vector`` field declared in the index mapping; it has
    # to equal the length of the vectors returned for EMBEDDING_MODEL.
    EMBEDDING_DIMS = 1024
    # Seconds to wait for the embeddings endpoint before giving up, so a
    # stalled connection cannot block the caller forever.
    EMBEDDING_TIMEOUT = 60

    def __init__(self):
        """Create the Elasticsearch client and read the embedding API settings."""
        # NOTE(review): verify_certs=False disables TLS certificate checks for
        # the cluster connection; confirm this is intended outside development.
        self.es = Elasticsearch(
            "https://samlax12-elastic.hf.space",
            basic_auth=("elastic", os.getenv("PASSWORD")),
            verify_certs=False,
            request_timeout=30,
            headers={
                "accept": "application/vnd.elasticsearch+json; compatible-with=8"
            },
        )
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for *text* from the SiliconFlow API.

        Args:
            text: The text to embed.

        Returns:
            The first embedding in the response's ``data`` list.

        Raises:
            RuntimeError: If the endpoint answers with a non-200 status.
            requests.exceptions.RequestException: On connection errors or
                when the request exceeds ``EMBEDDING_TIMEOUT`` seconds.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        response = requests.post(
            f"{self.api_base}/embeddings",
            headers=headers,
            json={"model": self.EMBEDDING_MODEL, "input": text},
            timeout=self.EMBEDDING_TIMEOUT,
        )

        if response.status_code != 200:
            raise RuntimeError(f"Error getting embedding: {response.text}")
        return response.json()["data"][0]["embedding"]

    def store(self, documents: List[Dict], index_name: str) -> None:
        """Embed *documents* and bulk-index them into *index_name*.

        The index is created first when it does not exist yet. Bulk errors
        reported by Elasticsearch are printed, not raised.

        Args:
            documents: Dicts with a ``content`` string and an optional
                ``metadata`` dict (``file_name``, ``source``, ``page``,
                ``img_url``); missing metadata fields get default values.
            index_name: Target Elasticsearch index.

        Raises:
            KeyError: If a document has no ``content`` key.
        """
        if not self.es.indices.exists(index=index_name):
            self.create_index(index_name)

        bulk_data = []
        for doc in documents:
            content = doc["content"]
            vector = self.get_embedding(content)
            # Tolerate documents without metadata instead of raising KeyError.
            metadata = doc.get("metadata") or {}

            # Ids must not be derived from the current document count: once
            # delete_document() has removed entries the count shrinks, and
            # count-based ids would overwrite documents that still exist.
            bulk_data.append(
                {
                    "index": {
                        "_index": index_name,
                        "_id": f"doc_{uuid.uuid4().hex}",
                    }
                }
            )
            bulk_data.append(
                {
                    "content": content,
                    "vector": vector,
                    "metadata": {
                        "file_name": metadata.get("file_name", "未知文件"),
                        "source": metadata.get("source", ""),
                        "page": metadata.get("page", ""),
                        "img_url": metadata.get("img_url", ""),
                    },
                }
            )

        if bulk_data:
            # refresh=True makes the new documents visible to searches as
            # soon as the call returns.
            response = self.es.bulk(operations=bulk_data, refresh=True)
            if response.get("errors"):
                print("批量写入时出现错误:", response)

    def get_files_in_index(self, index_name: str) -> List[str]:
        """Return the sorted distinct ``metadata.file_name`` values of an index.

        At most 1000 distinct names are returned (the terms aggregation size).
        Any error is printed and an empty list is returned instead.
        """
        try:
            response = self.es.search(
                index=index_name,
                body={
                    "size": 0,  # only the aggregation is needed, no hits
                    "aggs": {
                        "unique_files": {
                            "terms": {
                                "field": "metadata.file_name",
                                "size": 1000,
                            }
                        }
                    },
                },
            )
            buckets = response["aggregations"]["unique_files"]["buckets"]
            return sorted(bucket["key"] for bucket in buckets)
        except Exception as e:
            print(f"获取文件列表时出错: {str(e)}")
            return []

    def create_index(self, index_name: str):
        """Create *index_name* with the content/vector/metadata mapping.

        WARNING: if the index already exists it is deleted first, so all of
        its documents are lost. ``store`` only calls this for missing indices.
        """
        settings = {
            "mappings": {
                "properties": {
                    "content": {"type": "text"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": self.EMBEDDING_DIMS,
                    },
                    "metadata": {
                        "properties": {
                            "file_name": {
                                "type": "keyword",
                                "ignore_above": 256,
                            },
                            "source": {"type": "keyword"},
                            "page": {"type": "keyword"},
                            "img_url": {
                                "type": "keyword",
                                "ignore_above": 2048,
                            },
                        }
                    },
                }
            }
        }

        if self.es.indices.exists(index=index_name):
            self.es.indices.delete(index=index_name)

        self.es.indices.create(index=index_name, body=settings)

    def delete_index(self, index_id: str) -> bool:
        """Delete the index *index_id*.

        Returns:
            True if the index existed and was deleted; False if it did not
            exist or an error occurred (the error is printed).
        """
        try:
            if self.es.indices.exists(index=index_id):
                self.es.indices.delete(index=index_id)
                return True
            return False
        except Exception as e:
            print(f"删除索引时出错: {str(e)}")
            return False

    def delete_document(self, index_id: str, file_name: str) -> bool:
        """Delete every document in *index_id* whose file name is *file_name*.

        Returns:
            True if the delete-by-query request completed (even when nothing
            matched); False if it raised (the error is printed).
        """
        try:
            self.es.delete_by_query(
                index=index_id,
                body={"query": {"term": {"metadata.file_name": file_name}}},
                refresh=True,
            )
            return True
        except Exception as e:
            print(f"删除文档时出错: {str(e)}")
            return False