# agent/modules/knowledge_base/vector_store.py
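"""Elasticsearch-backed vector store.

Embeds text with SiliconFlow's BAAI/bge-m3 model and stores documents,
vectors, and metadata in an Elasticsearch 8.x index.
"""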
import os
from typing import Dict, List

import requests
import urllib3
from dotenv import load_dotenv
from elasticsearch import Elasticsearch

load_dotenv()

# TLS verification is disabled for the ES client below, so silence the
# resulting InsecureRequestWarning noise.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
class VectorStore:
def __init__(self):
        # Connection settings for Elasticsearch 8.x
self.es = Elasticsearch(
"https://samlax12-elastic.hf.space",
basic_auth=("elastic", os.getenv("PASSWORD")),
verify_certs=False,
request_timeout=30,
            # Ask the server for ES 8-compatible responses
headers={"accept": "application/vnd.elasticsearch+json; compatible-with=8"},
)
self.api_key = os.getenv("API_KEY")
self.api_base = os.getenv("BASE_URL")
def get_embedding(self, text: str) -> List[float]:
"""调用SiliconFlow的embedding API获取向量"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
        response = requests.post(
            f"{self.api_base}/embeddings",
            headers=headers,
            json={
                "model": "BAAI/bge-m3",
                "input": text
            },
            timeout=30  # avoid hanging indefinitely on a stalled request
        )
if response.status_code == 200:
return response.json()["data"][0]["embedding"]
else:
raise Exception(f"Error getting embedding: {response.text}")
def store(self, documents: List[Dict], index_name: str) -> None:
"""将文档存储到 Elasticsearch"""
        # Create the index if it does not already exist
if not self.es.indices.exists(index=index_name):
self.create_index(index_name)
        # Derive the next document ID from the current document count; count is
        # never negative, so count - 1 is already at least -1
        try:
            response = self.es.count(index=index_name)
            last_id = response['count'] - 1  # new IDs continue after the last document
        except Exception as e:
            print(f"Failed to get document count, assuming -1: {str(e)}")
            last_id = -1
        # Bulk-index the documents
bulk_data = []
for i, doc in enumerate(documents, start=last_id + 1):
            # Embed the document content
vector = self.get_embedding(doc['content'])
            # Action line for the bulk request
bulk_data.append({
"index": {
"_index": index_name,
"_id": f"doc_{i}"
}
})
            # Document body, including the img_url metadata field
doc_data = {
"content": doc['content'],
"vector": vector,
"metadata": {
"file_name": doc['metadata'].get('file_name', '未知文件'),
"source": doc['metadata'].get('source', ''),
"page": doc['metadata'].get('page', ''),
"img_url": doc['metadata'].get('img_url', '') # 添加img_url字段
}
}
bulk_data.append(doc_data)
        # Write everything in a single bulk request
if bulk_data:
response = self.es.bulk(operations=bulk_data, refresh=True)
if response.get('errors'):
print("批量写入时出现错误:", response)
def get_files_in_index(self, index_name: str) -> List[str]:
"""获取索引中的所有文件名"""
try:
response = self.es.search(
index=index_name,
body={
"size": 0,
"aggs": {
"unique_files": {
"terms": {
"field": "metadata.file_name",
"size": 1000
}
}
}
}
)
files = [bucket['key'] for bucket in response['aggregations']['unique_files']['buckets']]
return sorted(files)
except Exception as e:
print(f"获取文件列表时出错: {str(e)}")
return []
def create_index(self, index_name: str):
"""创建 Elasticsearch 索引"""
settings = {
"mappings": {
"properties": {
"content": {"type": "text"},
"vector": {
"type": "dense_vector",
"dims": 1024
},
"metadata": {
"properties": {
"file_name": {
"type": "keyword",
"ignore_above": 256
},
"source": {
"type": "keyword"
},
"page": {
"type": "keyword"
},
"img_url": { # 新增图片URL字段
"type": "keyword",
"ignore_above": 2048
}
}
}
}
}
}
        # Drop any existing index with the same name first
if self.es.indices.exists(index=index_name):
self.es.indices.delete(index=index_name)
self.es.indices.create(index=index_name, body=settings)
def delete_index(self, index_id: str) -> bool:
"""删除一个索引"""
try:
if self.es.indices.exists(index=index_id):
self.es.indices.delete(index=index_id)
return True
return False
except Exception as e:
print(f"删除索引时出错: {str(e)}")
return False
def delete_document(self, index_id: str, file_name: str) -> bool:
"""根据文件名删除文档"""
try:
            self.es.delete_by_query(
index=index_id,
body={
"query": {
"term": {
"metadata.file_name": file_name
}
}
},
refresh=True
)
return True
except Exception as e:
print(f"删除文档时出错: {str(e)}")
return False
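

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the module's API).
    # Assumes PASSWORD, API_KEY and BASE_URL are set in the environment or a
    # .env file; the index name "demo_docs" is a hypothetical example.
    store = VectorStore()
    store.store(
        [
            {
                "content": "Elasticsearch can store dense vectors for semantic search.",
                "metadata": {
                    "file_name": "example.txt",
                    "source": "local",
                    "page": "1",
                    "img_url": "",
                },
            }
        ],
        index_name="demo_docs",
    )
    print(store.get_files_in_index("demo_docs"))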