Spaces:
Sleeping
Sleeping
Update src/rag.py
Browse files- src/rag.py +94 -95
src/rag.py
CHANGED
@@ -66,147 +66,146 @@ def get_embedding(text: str) -> list[float]:
|
|
66 |
embedding = embedding_model.encode(text).tolist()
|
67 |
return embedding
|
68 |
|
69 |
-
def find_similar_documents_hybrid_search(
|
70 |
-
query_vector:
|
71 |
search_query: str,
|
72 |
limit: int = 10,
|
73 |
candidates: int = 20,
|
74 |
-
vector_search_index: str = "embedding_search",
|
75 |
atlas_search_index: str = "header_text"
|
76 |
-
) ->
|
77 |
"""
|
78 |
-
|
|
|
79 |
"""
|
80 |
all_results = []
|
81 |
collection = load_mongo_collection()
|
82 |
-
|
83 |
-
|
84 |
try:
|
85 |
vector_pipeline = [
|
86 |
-
{
|
87 |
-
"
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
'_id': 1,
|
98 |
-
'header' : 1,
|
99 |
-
'content': 1,
|
100 |
-
"vector_score": {"$meta": "vectorSearchScore"}
|
101 |
-
}
|
102 |
-
}
|
103 |
]
|
104 |
-
|
105 |
vector_results = list(collection.aggregate(vector_pipeline))
|
106 |
-
safe_log_info(f"Vector search
|
107 |
for doc in vector_results:
|
108 |
-
|
109 |
-
doc['combined_score'] = doc.get('vector_score', 0) * 0.6
|
110 |
return vector_results
|
111 |
except Exception as e:
|
112 |
-
|
113 |
-
|
114 |
|
115 |
-
|
116 |
-
|
117 |
if not search_query or not search_query.strip():
|
118 |
return []
|
119 |
-
|
120 |
try:
|
121 |
text_pipeline = [
|
122 |
-
{
|
123 |
-
"
|
124 |
-
|
125 |
-
"
|
126 |
-
|
127 |
-
{
|
128 |
-
"text": {
|
129 |
-
"query": search_query,
|
130 |
-
"path": ["header", "content"]
|
131 |
-
}
|
132 |
-
}
|
133 |
-
]
|
134 |
-
}
|
135 |
-
}
|
136 |
-
},
|
137 |
-
{
|
138 |
-
"$project": {
|
139 |
-
'_id': 1,
|
140 |
-
'header': 1,
|
141 |
-
'content': 1,
|
142 |
-
"text_score": {"$meta": "searchScore"}
|
143 |
}
|
144 |
-
}
|
|
|
|
|
|
|
|
|
145 |
]
|
146 |
-
|
147 |
text_results = list(collection.aggregate(text_pipeline))
|
148 |
-
safe_log_info(f"Text search
|
149 |
for doc in text_results:
|
150 |
-
|
151 |
-
doc['combined_score'] = doc.get('text_score', 0) * 0.4
|
152 |
return text_results
|
153 |
except Exception as e:
|
154 |
-
safe_log_warning(f"Text search
|
155 |
return []
|
156 |
|
157 |
try:
|
158 |
-
#
|
159 |
start_time = time.time()
|
160 |
with ThreadPoolExecutor(max_workers=2) as executor:
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
for future in as_completed(
|
166 |
try:
|
167 |
results = future.result()
|
168 |
all_results.extend(results)
|
169 |
except Exception as e:
|
170 |
-
safe_log_error(f"
|
171 |
|
172 |
search_time = time.time() - start_time
|
173 |
-
safe_log_info(f"
|
174 |
-
|
175 |
-
# 3. Merge và deduplicate results
|
176 |
-
seen_ids = set()
|
177 |
-
merged_results = []
|
178 |
|
|
|
|
|
179 |
for doc in all_results:
|
180 |
-
doc_id =
|
181 |
-
if doc_id not in
|
182 |
-
|
183 |
-
# Clean up the document for final result
|
184 |
-
final_doc = {
|
185 |
-
'_id': doc['_id'],
|
186 |
-
'content': doc.get('content', ''),
|
187 |
-
# 'uploader_username': doc.get('uploader_username', ''), # Removed
|
188 |
-
'header': doc.get('header', ''),
|
189 |
-
'score': doc.get('combined_score', 0)
|
190 |
-
}
|
191 |
-
merged_results.append(final_doc)
|
192 |
else:
|
193 |
-
#
|
194 |
-
|
195 |
-
if str(existing_doc['_id']) == doc_id:
|
196 |
-
existing_doc['score'] += doc.get('combined_score', 0) * 0.5
|
197 |
-
break
|
198 |
|
199 |
-
#
|
200 |
-
merged_results
|
|
|
|
|
|
|
201 |
|
202 |
-
# Return top results
|
203 |
final_results = merged_results[:limit]
|
204 |
-
safe_log_info(f"
|
205 |
|
206 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
|
208 |
except Exception as e:
|
209 |
-
safe_log_error(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
212 |
"""
|
|
|
66 |
embedding = embedding_model.encode(text).tolist()
|
67 |
return embedding
|
68 |
|
69 |
+
def find_similar_documents_hybrid_search(
    query_vector: List[float],
    search_query: str,
    limit: int = 10,
    candidates: int = 20,
    vector_search_index: str = "embedding_search",
    atlas_search_index: str = "header_text"
) -> List[Dict[str, Any]]:
    """Run a hybrid MongoDB Atlas search combining vector and text search in parallel.

    Vector results are weighted 0.6 and text results 0.4; documents found by
    both searches have their weighted scores summed. If the hybrid path raises,
    falls back to a plain Atlas text search.

    Args:
        query_vector: Embedding of the query (same dimensionality as the
            documents' ``embedding`` field).
        search_query: Raw query string for Atlas text search. If empty/blank,
            the text leg is skipped.
        limit: Maximum number of documents returned.
        candidates: ``numCandidates`` for the ANN vector search stage.
        vector_search_index: Name of the Atlas vector search index.
        atlas_search_index: Name of the Atlas text search index.

    Returns:
        Up to ``limit`` dicts with keys ``_id``, ``header``, ``content``,
        ``uuid`` and ``score`` (combined weighted score), sorted by score
        descending. On total failure, an empty list.
    """
    all_results = []
    collection = load_mongo_collection()

    def perform_vector_search() -> list:
        """Vector-search leg; returns [] on any error (best effort)."""
        try:
            vector_pipeline = [
                {"$vectorSearch": {
                    "index": vector_search_index,
                    "path": "embedding",
                    "queryVector": query_vector,
                    "limit": limit,
                    "numCandidates": candidates
                }},
                {"$project": {
                    '_id': 1, 'header': 1, 'content': 1, 'uuid': 1,
                    "vector_score": {"$meta": "vectorSearchScore"}
                }}
            ]
            vector_results = list(collection.aggregate(vector_pipeline))
            safe_log_info(f"Vector search trả về {len(vector_results)} kết quả")
            for doc in vector_results:
                # Weight vector score at 0.6 (text search carries the other 0.4).
                doc['combined_score'] = doc.get('vector_score', 0) * 0.6
            return vector_results
        except Exception as e:
            safe_log_warning(f"Vector search thất bại: {e}")
            return []

    def perform_text_search() -> list:
        """Text-search leg; returns [] for blank queries or on any error."""
        if not search_query or not search_query.strip():
            return []
        try:
            text_pipeline = [
                {"$search": {
                    "index": atlas_search_index,
                    # Single "text" operator (no compound needed for one clause).
                    "text": {
                        "query": search_query,
                        # NOTE(review): fallback path below also searches
                        # "keywords" — confirm whether it belongs here too.
                        "path": ["header", "content"]
                    }
                }},
                {"$project": {
                    '_id': 1, 'header': 1, 'content': 1, 'uuid': 1, 'keywords': 1,
                    "text_score": {"$meta": "searchScore"}
                }}
            ]
            text_results = list(collection.aggregate(text_pipeline))
            safe_log_info(f"Text search trả về {len(text_results)} kết quả")
            for doc in text_results:
                # Weight text score at 0.4 (vector search carries the other 0.6).
                doc['combined_score'] = doc.get('text_score', 0) * 0.4
            return text_results
        except Exception as e:
            safe_log_warning(f"Text search thất bại: {e}")
            return []

    try:
        # 1. Run both searches concurrently (each leg is I/O-bound).
        start_time = time.time()
        with ThreadPoolExecutor(max_workers=2) as executor:
            future_to_search = {
                executor.submit(perform_vector_search): "vector",
                executor.submit(perform_text_search): "text"
            }
            for future in as_completed(future_to_search):
                try:
                    results = future.result()
                    all_results.extend(results)
                except Exception as e:
                    safe_log_error(f"Lỗi trong quá trình tìm kiếm song song: {e}")

        search_time = time.time() - start_time
        safe_log_info(f"Tìm kiếm song song hoàn tất trong {search_time:.3f}s")

        # 2. Merge and deduplicate by _id; documents found by both legs
        #    accumulate both weighted scores.
        merged_map = {}
        for doc in all_results:
            doc_id = doc['_id']
            if doc_id not in merged_map:
                merged_map[doc_id] = doc
            else:
                merged_map[doc_id]['combined_score'] += doc['combined_score']

        merged_results = list(merged_map.values())

        # 3. Rank by combined score, best first.
        merged_results.sort(key=lambda x: x.get('combined_score', 0), reverse=True)

        final_results = merged_results[:limit]
        safe_log_info(f"Tìm kiếm hybrid trả về: {len(final_results)} tài liệu")

        # Project to the stable output shape callers expect.
        return [{
            '_id': r['_id'],
            'header': r.get('header', ''),
            'content': r.get('content', ''),
            'uuid': r.get('uuid', ''),
            'score': r.get('combined_score', 0)
        } for r in final_results]

    except Exception as e:
        safe_log_error(f"Lỗi nghiêm trọng trong hàm hybrid search: {e}", exc_info=True)

        # Fallback: text-only search when the hybrid path blew up.
        safe_log_warning("Thực hiện fallback: chỉ tìm kiếm bằng Text Search.")
        try:
            fallback_pipeline = [
                {"$search": {
                    "index": atlas_search_index,
                    "text": {
                        "query": search_query,
                        "path": ["header", "content", "keywords"]
                    }
                }},
                {"$project": {
                    '_id': 1, 'header': 1, 'content': 1, 'uuid': 1,
                    'score': {"$meta": "searchScore"}
                }},
                {"$limit": limit}
            ]
            fallback_results = list(collection.aggregate(fallback_pipeline))
            safe_log_info(f"Fallback search trả về {len(fallback_results)} kết quả.")
            return fallback_results
        except Exception as fallback_e:
            safe_log_error(f"Fallback search cũng thất bại: {fallback_e}", exc_info=True)
            return []  # Empty list if even the fallback fails.
209 |
|
210 |
def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
211 |
"""
|