File size: 6,720 Bytes
20f7a0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1625bb7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from typing import List, Dict
import requests
import numpy as np
from elasticsearch import Elasticsearch
import urllib3
from dotenv import load_dotenv
import os

load_dotenv()  # pull PASSWORD / API_KEY / BASE_URL into os.environ from a local .env

# The ES endpoint is reached with verify_certs=False (see VectorStore.__init__),
# so suppress the per-request InsecureRequestWarning noise urllib3 would emit.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

class VectorStore:
    """Vector store backed by Elasticsearch 8.x, with embeddings fetched
    from a SiliconFlow-compatible `/embeddings` HTTP endpoint.

    Documents are stored with their text, a 1024-dim dense vector, and a
    small metadata object (file_name / source / page / img_url).
    """

    def __init__(self):
        # ES 8.x connection; the host uses a self-signed certificate, hence
        # verify_certs=False (warnings are silenced at module import time).
        self.es = Elasticsearch(
            "https://samlax12-elastic.hf.space",
            basic_auth=("elastic", os.getenv("PASSWORD")),
            verify_certs=False,
            request_timeout=30,
            # Pin the 8.x compatibility content type to avoid system-index warnings.
            headers={"accept": "application/vnd.elasticsearch+json; compatible-with=8"},
        )
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for *text* via the SiliconFlow API.

        Raises:
            Exception: if the API answers with a non-200 status.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        # Explicit timeout: without it a stalled API call would hang the
        # whole indexing loop in store() indefinitely.
        response = requests.post(
            f"{self.api_base}/embeddings",
            headers=headers,
            json={
                "model": "BAAI/bge-m3",
                "input": text,
            },
            timeout=30,
        )

        if response.status_code == 200:
            return response.json()["data"][0]["embedding"]
        raise Exception(f"Error getting embedding: {response.text}")

    def store(self, documents: List[Dict], index_name: str) -> None:
        """Embed and bulk-index *documents* into *index_name*.

        Each document dict must have a 'content' string and a 'metadata'
        dict; missing metadata keys fall back to defaults.
        """
        # Create the index on first use.
        if not self.es.indices.exists(index=index_name):
            self.create_index(index_name)

        # Derive the next document ID from the current doc count.
        # NOTE(review): count-based IDs can collide with surviving documents
        # after delete_document() removes some — consider tracking a max ID
        # or letting ES auto-generate IDs.
        try:
            response = self.es.count(index=index_name)
            # Empty index -> count 0 -> last_id -1, so numbering starts at 0.
            last_id = max(response['count'] - 1, -1)
        except Exception as e:
            print(f"获取文档数量时出错,假设为-1: {str(e)}")
            last_id = -1

        # Build the bulk payload: alternating action / source lines.
        bulk_data = []
        for i, doc in enumerate(documents, start=last_id + 1):
            # One embedding API call per document.
            vector = self.get_embedding(doc['content'])

            bulk_data.append({
                "index": {
                    "_index": index_name,
                    "_id": f"doc_{i}",
                }
            })

            bulk_data.append({
                "content": doc['content'],
                "vector": vector,
                "metadata": {
                    "file_name": doc['metadata'].get('file_name', '未知文件'),
                    "source": doc['metadata'].get('source', ''),
                    "page": doc['metadata'].get('page', ''),
                    "img_url": doc['metadata'].get('img_url', ''),
                },
            })

        # Single bulk write; refresh=True makes the docs searchable immediately.
        if bulk_data:
            response = self.es.bulk(operations=bulk_data, refresh=True)
            if response.get('errors'):
                print("批量写入时出现错误:", response)

    def get_files_in_index(self, index_name: str) -> List[str]:
        """Return the sorted distinct file names stored in *index_name*.

        Returns an empty list on any Elasticsearch error (best-effort).
        """
        try:
            response = self.es.search(
                index=index_name,
                body={
                    "size": 0,  # aggregation only, no hits needed
                    "aggs": {
                        "unique_files": {
                            "terms": {
                                "field": "metadata.file_name",
                                # Caps the number of distinct files returned.
                                "size": 1000,
                            }
                        }
                    },
                },
            )

            files = [bucket['key'] for bucket in response['aggregations']['unique_files']['buckets']]
            return sorted(files)
        except Exception as e:
            print(f"获取文件列表时出错: {str(e)}")
            return []

    def create_index(self, index_name: str):
        """(Re)create *index_name* with the content/vector/metadata mapping.

        An existing index with the same name is deleted first, so calling
        this on a populated index wipes its data.
        """
        settings = {
            "mappings": {
                "properties": {
                    "content": {"type": "text"},
                    "vector": {
                        # Must match the BAAI/bge-m3 embedding dimensionality.
                        "type": "dense_vector",
                        "dims": 1024,
                    },
                    "metadata": {
                        "properties": {
                            "file_name": {
                                "type": "keyword",
                                "ignore_above": 256,
                            },
                            "source": {
                                "type": "keyword",
                            },
                            "page": {
                                "type": "keyword",
                            },
                            "img_url": {
                                # Image URLs can be long; keep a generous cap.
                                "type": "keyword",
                                "ignore_above": 2048,
                            },
                        }
                    },
                }
            }
        }

        # Drop any pre-existing index so the mapping is applied cleanly.
        if self.es.indices.exists(index=index_name):
            self.es.indices.delete(index=index_name)

        self.es.indices.create(index=index_name, body=settings)

    def delete_index(self, index_id: str) -> bool:
        """Delete an entire index. Returns True iff it existed and was deleted."""
        try:
            if self.es.indices.exists(index=index_id):
                self.es.indices.delete(index=index_id)
                return True
            return False
        except Exception as e:
            print(f"删除索引时出错: {str(e)}")
            return False

    def delete_document(self, index_id: str, file_name: str) -> bool:
        """Delete every document in *index_id* whose metadata.file_name
        exactly matches *file_name*. Returns False on any ES error.
        """
        try:
            self.es.delete_by_query(
                index=index_id,
                body={
                    "query": {
                        "term": {
                            "metadata.file_name": file_name,
                        }
                    }
                },
                # Make the deletion visible to searches immediately.
                refresh=True,
            )
            return True
        except Exception as e:
            print(f"删除文档时出错: {str(e)}")
            return False