Phoenix21 committed
Commit 936bed3 · verified · 1 Parent(s): 341f0a0

Modified to do web search, increase content, and make the pipeline refusal-proof

Files changed (1): pipeline.py (+171, −30)
pipeline.py CHANGED
@@ -2,13 +2,16 @@ import os
 import getpass
 import spacy
 import pandas as pd
+import numpy as np
 from typing import Optional, List, Dict, Any
 import subprocess
+
 from langchain.llms.base import LLM
 from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
+
 from smolagents import DuckDuckGoSearchTool, ManagedAgent
 from pydantic import BaseModel, Field, ValidationError, validator
 from mistralai import Mistral
@@ -27,7 +30,7 @@ mistral_api_key = os.environ.get("MISTRAL_API_KEY")
 client = Mistral(api_key=mistral_api_key)
 
 # Setup ChatGoogleGenerativeAI for Gemini
-# Ensure GOOGLE_API_KEY is set in your environment variables.
+# Ensure GEMINI_API_KEY is set in your environment variables.
 gemini_llm = ChatGoogleGenerativeAI(
     model="gemini-1.5-pro",
     temperature=0.5,
@@ -36,16 +39,9 @@ gemini_llm = ChatGoogleGenerativeAI(
     # Additional parameters or safety_settings can be added here if needed
 )
 
-# Initialize ManagedAgent for web search using Gemini
-# pydantic_agent = ManagedAgent(
-#     llm=ChatGoogleGenerativeAI(
-#         model="gemini-1.5-pro",
-#         temperature=0.5,
-#         max_retries=2,
-#         google_api_key=os.environ.get("GEMINI_API_KEY"),
-#     ),
-#     tools=[DuckDuckGoSearchTool()]
-# )
+################################################################################
+# Pydantic Models
+################################################################################
 
 class QueryInput(BaseModel):
     query: str = Field(..., min_length=1, description="The input query string")
@@ -63,6 +59,10 @@ class ModerationResult(BaseModel):
     categories: Dict[str, bool] = Field(default_factory=dict, description="Detected content categories")
     original_text: str = Field(..., description="The original input text")
 
+################################################################################
+# spaCy Setup
+################################################################################
+
 def install_spacy_model():
     try:
         spacy.load("en_core_web_sm")
@@ -75,6 +75,10 @@ def install_spacy_model():
 install_spacy_model()
 nlp = spacy.load("en_core_web_sm")
 
+################################################################################
+# Utility Functions
+################################################################################
+
 def sanitize_message(message: Any) -> str:
     """Sanitize message input to ensure it's a valid string."""
     try:
@@ -92,16 +96,19 @@ def sanitize_message(message: Any) -> str:
         raise RuntimeError(f"Error in sanitize function: {str(e)}")
 
 def extract_main_topic(query: str) -> str:
+    """Extracts a main topic (named entity or noun) from the user query."""
     try:
         query_input = QueryInput(query=query)
         doc = nlp(query_input.query)
         main_topic = None
 
+        # Attempt to find an entity
         for ent in doc.ents:
             if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:
                 main_topic = ent.text
                 break
 
+        # If no named entity, fall back to nouns or proper nouns
         if not main_topic:
             for token in doc:
                 if token.pos_ in ["NOUN", "PROPN"]:
@@ -114,6 +121,7 @@ def extract_main_topic(query: str) -> str:
         return "this topic"
 
 def moderate_text(query: str) -> ModerationResult:
+    """Uses Mistral's moderation to determine if the content is safe."""
     try:
         query_input = QueryInput(query=query)
 
@@ -132,6 +140,7 @@ def moderate_text(query: str) -> ModerationResult:
             "dangerous": response.results[0].categories.get("dangerous_and_criminal_content", False),
             "selfharm": response.results[0].categories.get("selfharm", False)
         }
+        # If any flagged category is True, then not safe
         is_safe = not any(categories.values())
 
         return ModerationResult(
@@ -145,13 +154,16 @@
         raise RuntimeError(f"Moderation failed: {str(e)}")
 
 def classify_query(query: str) -> str:
+    """Classify user query into known categories using the classification chain."""
     try:
         query_input = QueryInput(query=query)
 
+        # Quick pattern-based approach for 'Wellness'
         wellness_keywords = ["box breathing", "meditation", "yoga", "mindfulness", "breathing exercises"]
         if any(keyword in query_input.query.lower() for keyword in wellness_keywords):
             return "Wellness"
 
+        # Use chain for everything else
         class_result = classification_chain.invoke({"query": query_input.query})
         classification = class_result.get("text", "").strip()
 
@@ -161,7 +173,14 @@ def classify_query(query: str) -> str:
     except Exception as e:
         raise RuntimeError(f"Classification failed: {str(e)}")
 
+################################################################################
+# Vector Store Building/Loading
+################################################################################
+
 def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
+    """
+    Builds or loads a FAISS vector store for CSV documents containing 'Question' and 'Answers'.
+    """
     try:
         if os.path.exists(store_dir):
             print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
@@ -173,18 +192,22 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
         df = pd.read_csv(csv_path)
         df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
         df.columns = df.columns.str.strip()
+
+        # Fix possible column name variations
         if "Answer" in df.columns:
             df.rename(columns={"Answer": "Answers"}, inplace=True)
         if "Question" not in df.columns and "Question " in df.columns:
             df.rename(columns={"Question ": "Question"}, inplace=True)
         if "Question" not in df.columns or "Answers" not in df.columns:
             raise ValueError("CSV must have 'Question' and 'Answers' columns.")
+
         docs = []
         for _, row in df.iterrows():
             q = str(row["Question"])
             ans = str(row["Answers"])
             doc = Document(page_content=ans, metadata={"question": q})
             docs.append(doc)
+
         embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
         vectorstore = FAISS.from_documents(docs, embedding=embeddings)
         vectorstore.save_local(store_dir)
@@ -194,11 +217,11 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
         raise RuntimeError(f"Error building/loading vector store: {str(e)}")
 
 def build_rag_chain(vectorstore: FAISS) -> RetrievalQA:
-    """Build RAG chain using the Gemini LLM directly without a custom class."""
+    """Build RAG chain using the Gemini LLM."""
     try:
         retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
         chain = RetrievalQA.from_chain_type(
-            llm=gemini_llm,  # Directly use the ChatGoogleGenerativeAI instance
+            llm=gemini_llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True
@@ -206,18 +229,107 @@ def build_rag_chain(vectorstore: FAISS) -> RetrievalQA:
         return chain
     except Exception as e:
         raise RuntimeError(f"Error building RAG chain: {str(e)}")
+
+################################################################################
+# Web Search Caching: Separate FAISS Vector Store
+################################################################################
+
+# Directory for storing cached web search results
+web_search_store_dir = "faiss_websearch_store"
+
+def build_or_load_websearch_store(store_dir: str) -> FAISS:
+    """
+    Builds or loads a FAISS vector store for caching web search results.
+    Each Document will have page_content as the search result text,
+    and metadata={"question": <user_query>}.
+    """
+    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
+    if os.path.exists(store_dir):
+        print(f"DEBUG: Found existing WebSearch FAISS store at '{store_dir}'. Loading...")
+        return FAISS.load_local(store_dir, embeddings)
+    else:
+        print(f"DEBUG: Creating a new, empty WebSearch FAISS store at '{store_dir}'...")
+        # Start empty
+        empty_store = FAISS.from_texts([""], embeddings, metadatas=[{"question": "placeholder"}])
+        # Remove the placeholder doc so we don't retrieve it
+        empty_store.index.reset()
+        empty_store.docstore._dict = {}
+        empty_store.save_local(store_dir)
+        return empty_store
+
+# Initialize the web search vector store
+web_search_vectorstore = build_or_load_websearch_store(web_search_store_dir)
+websearch_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
+
+def compute_cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
+    """Compute cosine similarity between two embedding vectors."""
+    a = np.array(vec_a, dtype=float)
+    b = np.array(vec_b, dtype=float)
+    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10))
+
+def get_cached_websearch(query: str, threshold: float = 0.8) -> Optional[str]:
+    """
+    Attempts to retrieve a cached web search result for a given query.
+    If the top retrieved document has a cosine similarity >= threshold,
+    returns that document's page_content. Otherwise, returns None.
+    """
+    # Retrieve the top doc from the store
+    retriever = web_search_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 1})
+    results = retriever.get_relevant_documents(query)
+    if not results:
+        return None
+
+    # Compare similarity with the top doc
+    top_doc = results[0]
+    query_vec = websearch_embeddings.embed_query(query)
+    doc_vec = websearch_embeddings.embed_query(top_doc.page_content)
+    similarity = compute_cosine_similarity(query_vec, doc_vec)
+
+    if similarity >= threshold:
+        print(f"DEBUG: Using cached web search (similarity={similarity:.2f} >= {threshold})")
+        return top_doc.page_content
 
-def do_web_search(query: str) -> str:
+    print(f"DEBUG: Cached doc similarity={similarity:.2f} < {threshold}, not reusing.")
+    return None
+
+def store_websearch_result(query: str, web_search_text: str):
+    """
+    Embeds and stores the web search result text in the web search vector store,
+    keyed by the question in metadata. Then saves the store locally.
+    """
+    if not web_search_text.strip():
+        return  # Don't store empty results
+    doc = Document(page_content=web_search_text, metadata={"question": query})
+    web_search_vectorstore.add_documents([doc], embedding=websearch_embeddings)
+    web_search_vectorstore.save_local(web_search_store_dir)
+
+def do_cached_web_search(query: str) -> str:
+    """Perform a DuckDuckGo web search, but with caching via FAISS vector store."""
+    # 1) Check cache
+    cached_result = get_cached_websearch(query)
+    if cached_result:
+        return cached_result
+
+    # 2) If no suitable cached answer, do a new search
     try:
+        print("DEBUG: Performing a new web search...")
         search_tool = DuckDuckGoSearchTool()
         search_agent = ManagedAgent(llm=gemini_llm, tools=[search_tool])
-        search_result = search_agent.run(f"Search for information about: {query}")
-        return str(search_result).strip()
+        new_search_result = search_agent.run(f"Search for information about: {query}")
+
+        # 3) Store in cache for future reuse
+        store_websearch_result(query, new_search_result)
+        return str(new_search_result).strip()
     except Exception as e:
         print(f"Web search failed: {e}")
         return ""
 
+################################################################################
+# Response Merging
+################################################################################
+
 def merge_responses(csv_answer: str, web_answer: str) -> str:
+    """Merge CSV-based RAG result with web search results."""
     try:
         if not csv_answer and not web_answer:
             return "I apologize, but I couldn't find any relevant information."
@@ -233,7 +345,18 @@ def merge_responses(csv_answer: str, web_answer: str) -> str:
         print(f"Error merging responses: {e}")
         return csv_answer or web_answer or "I apologize, but I couldn't process the information properly."
 
+################################################################################
+# Main Pipeline
+################################################################################
+
 def run_pipeline(query: str) -> str:
+    """
+    Pipeline logic to:
+    1) Sanitize & moderate the query
+    2) Classify the query (OutOfScope, Wellness, Brand, etc.)
+    3) If safe & in scope, do RAG + ALWAYS do a cached web search
+    4) Merge responses and tailor final output
+    """
     try:
         print(query)
         sanitized_query = sanitize_message(query)
@@ -242,34 +365,44 @@
         topic = extract_main_topic(query_input.query)
         moderation_result = moderate_text(query_input.query)
 
+        # Check for unsafe content
         if not moderation_result.is_safe:
             return "Sorry, this query contains harmful or inappropriate content."
-
+
+        # Classify
         classification = classify_query(moderation_result.original_text)
 
+        # If out-of-scope, refuse
         if classification == "OutOfScope":
             refusal_text = refusal_chain.run({"topic": topic})
             return tailor_chain.run({"response": refusal_text}).strip()
-
+
+        # Otherwise, do a RAG query and also do a web search (cached)
         if classification == "Wellness":
+            # RAG from wellness store
             rag_result = wellness_rag_chain({"query": moderation_result.original_text})
-            if isinstance(rag_result, dict) and "result" in rag_result:
-                csv_answer = str(rag_result["result"]).strip()
-            else:
-                csv_answer = str(rag_result).strip()
-            web_answer = "" if csv_answer else do_web_search(moderation_result.original_text)
+            csv_answer = rag_result.get("result", "").strip() if isinstance(rag_result, dict) else str(rag_result).strip()
+
+            # Always do a (cached) web search
+            web_answer = do_cached_web_search(moderation_result.original_text)
+
+            # Merge CSV & Web
             final_merged = merge_responses(csv_answer, web_answer)
             return tailor_chain.run({"response": final_merged}).strip()
 
         if classification == "Brand":
+            # RAG from brand store
             rag_result = brand_rag_chain({"query": moderation_result.original_text})
-            if isinstance(rag_result, dict) and "result" in rag_result:
-                csv_answer = str(rag_result["result"]).strip()
-            else:
-                csv_answer = str(rag_result).strip()
-            final_merged = merge_responses(csv_answer, "")
+            csv_answer = rag_result.get("result", "").strip() if isinstance(rag_result, dict) else str(rag_result).strip()
+
+            # Always do a (cached) web search
+            web_answer = do_cached_web_search(moderation_result.original_text)
+
+            # Merge CSV & Web
+            final_merged = merge_responses(csv_answer, web_answer)
             return tailor_chain.run({"response": final_merged}).strip()
 
+        # If it doesn't fall under known categories, return refusal by default.
         refusal_text = refusal_chain.run({"topic": topic})
         return tailor_chain.run({"response": refusal_text}).strip()
 
@@ -279,27 +412,35 @@
         raise RuntimeError(f"Error in run_pipeline: {str(e)}")
 
 def run_with_chain(query: str) -> str:
+    """Convenience function to run the main pipeline and handle errors gracefully."""
     try:
         return run_pipeline(query)
     except Exception as e:
         print(f"Error in run_with_chain: {str(e)}")
         return "I apologize, but I encountered an error processing your request. Please try again."
 
-# Initialize chains and vectorstores
+################################################################################
+# Chain & Vectorstore Initialization
+################################################################################
+
+# Load the classification/refusal/tailor/cleaner chains
 classification_chain = get_classification_chain()
 refusal_chain = get_refusal_chain()
 tailor_chain = get_tailor_chain()
 cleaner_chain = get_cleaner_chain()
 
+# CSV file paths and store directories for RAG
 wellness_csv = "AIChatbot.csv"
 brand_csv = "BrandAI.csv"
 wellness_store_dir = "faiss_wellness_store"
 brand_store_dir = "faiss_brand_store"
 
+# Build or load the vector stores
 wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
 brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
 
+# Build RAG chains
 wellness_rag_chain = build_rag_chain(wellness_vectorstore)
 brand_rag_chain = build_rag_chain(brand_vectorstore)
 
-print("Pipeline initialized successfully!")
+print("Pipeline initialized successfully! Ready to handle queries with caching.")
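Editor's note: the new build_or_load_websearch_store empties the raw FAISS index (index.reset()) and the docstore dict, but leaves index_to_docstore_id still holding the placeholder's entry. In the langchain 0.0.x-era API this file imports from, add_documents offsets new ids by the length of that mapping, so the stale entry can leave the FAISS index and the docstore pointing at different ids. Below is a minimal sketch of a fully cleared empty store; FakeEmbeddings stands in for the sentence-transformers model so the snippet runs offline, and the helper name make_empty_faiss_store is mine, not part of the commit.

from langchain.docstore.document import Document
from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import FAISS

def make_empty_faiss_store(embeddings) -> FAISS:
    # Seed with a placeholder so FAISS learns the embedding dimension...
    store = FAISS.from_texts([""], embeddings, metadatas=[{"question": "placeholder"}])
    # ...then wipe all three internal structures, not just two.
    store.index.reset()               # raw FAISS vectors
    store.docstore._dict = {}         # stored Documents
    store.index_to_docstore_id = {}   # id mapping that offsets later adds
    return store

store = make_empty_faiss_store(FakeEmbeddings(size=16))  # stand-in embeddings
store.add_documents([Document(page_content="cached result", metadata={"question": "demo"})])
print(store.similarity_search("demo", k=1)[0].page_content)  # -> "cached result"

Clearing all three structures keeps similarity_search consistent with whatever add_documents inserts later.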
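The cache-reuse rule added in this commit is a plain cosine threshold over embeddings. Here is a self-contained toy run of that rule with invented 3-dimensional vectors (the pipeline derives real ones from HuggingFaceEmbeddings.embed_query). Note that get_cached_websearch embeds the cached result text rather than the stored question metadata, so the threshold effectively measures query-to-answer similarity.

import numpy as np

def compute_cosine_similarity(vec_a, vec_b):
    a = np.array(vec_a, dtype=float)
    b = np.array(vec_b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10))

THRESHOLD = 0.8  # same default as get_cached_websearch

query_vec = [0.9, 0.1, 0.0]   # embedding of the incoming query (made up)
near_dup = [0.8, 0.2, 0.1]    # cached text for a near-duplicate query
unrelated = [0.0, 0.1, 0.9]   # cached text about something else

for name, vec in [("near-duplicate", near_dup), ("unrelated", unrelated)]:
    sim = compute_cosine_similarity(query_vec, vec)
    verdict = "reuse cache" if sim >= THRESHOLD else "do a fresh web search"
    print(f"{name}: similarity={sim:.2f} -> {verdict}")
# near-duplicate: similarity=0.98 -> reuse cache
# unrelated: similarity=0.01 -> do a fresh web search

Anything at or above the threshold reuses the cached text; anything below triggers a fresh DuckDuckGo search followed by a cache write via store_websearch_result.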