hoangchihien3011 committed on
Commit
3da6db4
·
verified ·
1 Parent(s): 4550eaf

Update src/rag.py

Browse files
Files changed (1) hide show
  1. src/rag.py +94 -95
src/rag.py CHANGED
@@ -66,147 +66,146 @@ def get_embedding(text: str) -> list[float]:
66
  embedding = embedding_model.encode(text).tolist()
67
  return embedding
68
 
69
- def find_similar_documents_hybrid_search(
70
- query_vector: list[float],
71
  search_query: str,
72
  limit: int = 10,
73
  candidates: int = 20,
74
- vector_search_index: str = "embedding_search",
75
  atlas_search_index: str = "header_text"
76
- ) -> list[dict]:
77
  """
78
- Hybrid search combining vector and text search with parallel execution.
 
79
  """
80
  all_results = []
81
  collection = load_mongo_collection()
82
- def perform_vector_search():
83
- """Perform vector search in parallel."""
84
  try:
85
  vector_pipeline = [
86
- {
87
- "$vectorSearch": {
88
- "index": vector_search_index,
89
- "path": "embedding",
90
- "queryVector": query_vector,
91
- "limit": limit,
92
- "numCandidates": candidates
93
- }
94
- },
95
- {
96
- "$project": {
97
- '_id': 1,
98
- 'header' : 1,
99
- 'content': 1,
100
- "vector_score": {"$meta": "vectorSearchScore"}
101
- }
102
- }
103
  ]
104
-
105
  vector_results = list(collection.aggregate(vector_pipeline))
106
- safe_log_info(f"Vector search returned {len(vector_results)} results")
107
  for doc in vector_results:
108
- doc['search_type'] = 'vector'
109
- doc['combined_score'] = doc.get('vector_score', 0) * 0.6 # Weight vector score
110
  return vector_results
111
  except Exception as e:
112
- safe_log_warning(f"Vector search failed: {e}")
113
- return []
114
 
115
- def perform_text_search():
116
- """Perform text search in parallel."""
117
  if not search_query or not search_query.strip():
118
  return []
119
-
120
  try:
121
  text_pipeline = [
122
- {
123
- "$search": {
124
- "index": atlas_search_index,
125
- "compound": {
126
- "must": [
127
- {
128
- "text": {
129
- "query": search_query,
130
- "path": ["header", "content"]
131
- }
132
- }
133
- ]
134
- }
135
- }
136
- },
137
- {
138
- "$project": {
139
- '_id': 1,
140
- 'header': 1,
141
- 'content': 1,
142
- "text_score": {"$meta": "searchScore"}
143
  }
144
- }
 
 
 
 
145
  ]
146
-
147
  text_results = list(collection.aggregate(text_pipeline))
148
- safe_log_info(f"Text search returned {len(text_results)} results")
149
  for doc in text_results:
150
- doc['search_type'] = 'text'
151
- doc['combined_score'] = doc.get('text_score', 0) * 0.4 # Weight text score
152
  return text_results
153
  except Exception as e:
154
- safe_log_warning(f"Text search failed: {e}")
155
  return []
156
 
157
  try:
158
- # Run both searches in parallel
159
  start_time = time.time()
160
  with ThreadPoolExecutor(max_workers=2) as executor:
161
- vector_future = executor.submit(perform_vector_search)
162
- text_future = executor.submit(perform_text_search)
163
-
164
- # Collect results as they complete
165
- for future in as_completed([vector_future, text_future]):
166
  try:
167
  results = future.result()
168
  all_results.extend(results)
169
  except Exception as e:
170
- safe_log_error(f"Error in parallel search: {e}")
171
 
172
  search_time = time.time() - start_time
173
- safe_log_info(f"Parallel search completed in {search_time:.3f}s")
174
-
175
- # 3. Merge và deduplicate results
176
- seen_ids = set()
177
- merged_results = []
178
 
 
 
179
  for doc in all_results:
180
- doc_id = str(doc['_id'])
181
- if doc_id not in seen_ids:
182
- seen_ids.add(doc_id)
183
- # Clean up the document for final result
184
- final_doc = {
185
- '_id': doc['_id'],
186
- 'content': doc.get('content', ''),
187
- # 'uploader_username': doc.get('uploader_username', ''), # Removed
188
- 'header': doc.get('header', ''),
189
- 'score': doc.get('combined_score', 0)
190
- }
191
- merged_results.append(final_doc)
192
  else:
193
- # If document already exists, boost its score
194
- for existing_doc in merged_results:
195
- if str(existing_doc['_id']) == doc_id:
196
- existing_doc['score'] += doc.get('combined_score', 0) * 0.5
197
- break
198
 
199
- # Sort by combined score
200
- merged_results.sort(key=lambda x: x.get('score', 0), reverse=True)
 
 
 
201
 
202
- # Return top results
203
  final_results = merged_results[:limit]
204
- safe_log_info(f"Hybrid search final results: {len(final_results)} documents")
205
 
206
- return final_results
 
 
 
 
 
 
207
 
208
  except Exception as e:
209
- safe_log_error(f"Error in hybrid search: {e}", exc_info=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
  def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
212
  """
 
66
  embedding = embedding_model.encode(text).tolist()
67
  return embedding
68
 
69
+ def find_similar_documents_hybrid_search(
70
+ query_vector: List[float],
71
  search_query: str,
72
  limit: int = 10,
73
  candidates: int = 20,
74
+ vector_search_index: str = "embedding_search",
75
  atlas_search_index: str = "header_text"
76
+ ) -> List[Dict[str, Any]]:
77
  """
78
+ Thực hiện tìm kiếm hybrid kết hợp vector search text search, chạy song song.
79
+ Bao gồm cơ chế fallback nếu tìm kiếm hybrid thất bại.
80
  """
81
  all_results = []
82
  collection = load_mongo_collection()
83
+ # Hàm con cho vector search
84
+ def perform_vector_search() -> list:
85
  try:
86
  vector_pipeline = [
87
+ {"$vectorSearch": {
88
+ "index": vector_search_index,
89
+ "path": "embedding",
90
+ "queryVector": query_vector,
91
+ "limit": limit,
92
+ "numCandidates": candidates
93
+ }},
94
+ {"$project": {
95
+ '_id': 1, 'header': 1, 'content': 1, 'uuid': 1,
96
+ "vector_score": {"$meta": "vectorSearchScore"}
97
+ }}
 
 
 
 
 
 
98
  ]
 
99
  vector_results = list(collection.aggregate(vector_pipeline))
100
+ safe_log_info(f"Vector search trả về {len(vector_results)} kết quả")
101
  for doc in vector_results:
102
+ # Gán trọng số 0.7 cho điểm vector
103
+ doc['combined_score'] = doc.get('vector_score', 0) * 0.6
104
  return vector_results
105
  except Exception as e:
106
+ safe_log_warning(f"Vector search thất bại: {e}")
107
+ return []
108
 
109
+ # Hàm con cho text search
110
+ def perform_text_search() -> list:
111
  if not search_query or not search_query.strip():
112
  return []
 
113
  try:
114
  text_pipeline = [
115
+ {"$search": {
116
+ "index": atlas_search_index,
117
+ "text": { # Đơn giản hóa từ compound sang text nếu chỉ có một điều kiện
118
+ "query": search_query,
119
+ "path": ["header", "content"] # Thêm keywords vào path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  }
121
+ }},
122
+ {"$project": {
123
+ '_id': 1, 'header': 1, 'content': 1, 'uuid': 1, 'keywords': 1,
124
+ "text_score": {"$meta": "searchScore"}
125
+ }}
126
  ]
 
127
  text_results = list(collection.aggregate(text_pipeline))
128
+ safe_log_info(f"Text search trả về {len(text_results)} kết quả")
129
  for doc in text_results:
130
+ # Gán trọng số 0.3 cho điểm text search
131
+ doc['combined_score'] = doc.get('text_score', 0) * 0.4
132
  return text_results
133
  except Exception as e:
134
+ safe_log_warning(f"Text search thất bại: {e}")
135
  return []
136
 
137
  try:
138
+ # 1. Chạy song song hai truy vấn
139
  start_time = time.time()
140
  with ThreadPoolExecutor(max_workers=2) as executor:
141
+ future_to_search = {
142
+ executor.submit(perform_vector_search): "vector",
143
+ executor.submit(perform_text_search): "text"
144
+ }
145
+ for future in as_completed(future_to_search):
146
  try:
147
  results = future.result()
148
  all_results.extend(results)
149
  except Exception as e:
150
+ safe_log_error(f"Lỗi trong quá trình tìm kiếm song song: {e}")
151
 
152
  search_time = time.time() - start_time
153
+ safe_log_info(f"Tìm kiếm song song hoàn tất trong {search_time:.3f}s")
 
 
 
 
154
 
155
+ # 2. Hợp nhất và loại bỏ trùng lặp (Tối ưu hóa)
156
+ merged_map = {}
157
  for doc in all_results:
158
+ doc_id = doc['_id']
159
+ if doc_id not in merged_map:
160
+ merged_map[doc_id] = doc
 
 
 
 
 
 
 
 
 
161
  else:
162
+ # Nếu tài liệu đã tồn tại, cộng dồn điểm số
163
+ merged_map[doc_id]['combined_score'] += doc['combined_score']
 
 
 
164
 
165
+ # Chuyển map thành list
166
+ merged_results = list(merged_map.values())
167
+
168
+ # 3. Sắp xếp theo điểm số tổng hợp
169
+ merged_results.sort(key=lambda x: x.get('combined_score', 0), reverse=True)
170
 
 
171
  final_results = merged_results[:limit]
172
+ safe_log_info(f"Tìm kiếm hybrid trả về: {len(final_results)} tài liệu")
173
 
174
+ return [{
175
+ '_id': r['_id'],
176
+ 'header': r.get('header', ''),
177
+ 'content': r.get('content', ''),
178
+ 'uuid': r.get('uuid', ''),
179
+ 'score': r.get('combined_score', 0)
180
+ } for r in final_results]
181
 
182
  except Exception as e:
183
+ safe_log_error(f"Lỗi nghiêm trọng trong hàm hybrid search: {e}", exc_info=True)
184
+
185
+ # ----- PHẦN FALLBACK ĐÃ SỬA -----
186
+ safe_log_warning("Thực hiện fallback: chỉ tìm kiếm bằng Text Search.")
187
+ try:
188
+ # Thực hiện lại một truy vấn text search đơn giản
189
+ fallback_pipeline = [
190
+ {"$search": {
191
+ "index": atlas_search_index,
192
+ "text": {
193
+ "query": search_query,
194
+ "path": ["header", "content", "keywords"]
195
+ }
196
+ }},
197
+ {"$project": {
198
+ '_id': 1, 'header': 1, 'content': 1, 'uuid': 1,
199
+ 'score': {"$meta": "searchScore"}
200
+ }},
201
+ {"$limit": limit}
202
+ ]
203
+ fallback_results = list(collection.aggregate(fallback_pipeline))
204
+ safe_log_info(f"Fallback search trả về {len(fallback_results)} kết quả.")
205
+ return fallback_results
206
+ except Exception as fallback_e:
207
+ safe_log_error(f"Fallback search cũng thất bại: {fallback_e}", exc_info=True)
208
+ return [] # Trả về list rỗng nếu cả fallback cũng lỗi
209
 
210
  def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
211
  """