import os
import sys
import spacy
import pandas as pd
import numpy as np
from typing import Optional, List, Dict, Any
import subprocess

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA

from smolagents import DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel, CodeAgent, HfApiModel
from pydantic import BaseModel, Field, ValidationError, validator
from mistralai import Mistral

# Import Google Gemini model
from langchain_google_genai import ChatGoogleGenerativeAI

from classification_chain import get_classification_chain
from cleaner_chain import get_cleaner_chain
from refusal_chain import get_refusal_chain
from tailor_chain import get_tailor_chain
from prompts import classification_prompt, refusal_prompt, tailor_prompt


# Note: LangSmith reads its configuration from environment variables; export
# these in the environment (not just as Python variables) for tracing to work.
LANGSMITH_TRACING = True
LANGSMITH_ENDPOINT = "https://api.smith.langchain.com"
LANGSMITH_API_KEY = os.environ.get("LANGSMITH_API_KEY")
LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
# Initialize Mistral API client
mistral_api_key = os.environ.get("MISTRAL_API_KEY")
client = Mistral(api_key=mistral_api_key)

# Setup ChatGoogleGenerativeAI for Gemini
# Ensure GEMINI_API_KEY is set in your environment variables.
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",
    temperature=0.5,
    max_retries=2,
    google_api_key=os.environ.get("GEMINI_API_KEY"),
    # Additional parameters or safety_settings can be added here if needed
)

# web_gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))

################################################################################
# Pydantic Models
################################################################################

class QueryInput(BaseModel):
    query: str = Field(..., min_length=1, description="The input query string")
    
    @validator('query')
    def check_query_is_string(cls, v):
        if not isinstance(v, str):
            raise ValueError("Query must be a valid string")
        if v.strip() == "":
            raise ValueError("Query cannot be empty or just whitespace")
        return v.strip()

class ModerationResult(BaseModel):
    is_safe: bool = Field(..., description="Whether the content is safe")
    categories: Dict[str, bool] = Field(default_factory=dict, description="Detected content categories")
    original_text: str = Field(..., description="The original input text")

################################################################################
# SPACy Setup
################################################################################

def install_spacy_model():
    try:
        spacy.load("en_core_web_sm")
        print("spaCy model 'en_core_web_sm' is already installed.")
    except OSError:
        print("Downloading spaCy model 'en_core_web_sm'...")
        subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
        print("spaCy model 'en_core_web_sm' downloaded successfully.")

install_spacy_model()
nlp = spacy.load("en_core_web_sm")

################################################################################
# Utility Functions
################################################################################

def sanitize_message(message: Any) -> str:
    """Sanitize message input to ensure it's a valid string."""
    try:
        if hasattr(message, 'content'):
            return str(message.content).strip()
        if isinstance(message, dict) and 'content' in message:
            return str(message['content']).strip()
        if isinstance(message, list) and len(message) > 0:
            if isinstance(message[0], dict) and 'content' in message[0]:
                return str(message[0]['content']).strip()
            if hasattr(message[0], 'content'):
                return str(message[0].content).strip()
        return str(message).strip()
    except Exception as e:
        raise RuntimeError(f"Error in sanitize function: {str(e)}")

def extract_main_topic(query: str) -> str:
    """Extracts a main topic (named entity or noun) from the user query."""
    try:
        query_input = QueryInput(query=query)
        doc = nlp(query_input.query)
        main_topic = None
        
        # Attempt to find an entity
        for ent in doc.ents:
            if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:
                main_topic = ent.text
                break
        
        # If no named entity, fall back to nouns or proper nouns
        if not main_topic:
            for token in doc:
                if token.pos_ in ["NOUN", "PROPN"]:
                    main_topic = token.text
                    break
        
        return main_topic if main_topic else "this topic"
    except Exception as e:
        print(f"Error extracting main topic: {e}")
        return "this topic"

def moderate_text(query: str) -> ModerationResult:
    """Uses Mistral's moderation to determine if the content is safe."""
    try:
        query_input = QueryInput(query=query)
        
        response = client.classifiers.moderate_chat(
            model="mistral-moderation-latest",
            inputs=[{"role": "user", "content": query_input.query}]
        )
        
        is_safe = True
        categories = {}
        
        if hasattr(response, 'results') and response.results:
            categories = {
                "violence": response.results[0].categories.get("violence_and_threats", False),
                "hate": response.results[0].categories.get("hate_and_discrimination", False),
                "dangerous": response.results[0].categories.get("dangerous_and_criminal_content", False),
                "selfharm": response.results[0].categories.get("selfharm", False)
            }
            # If any flagged category is True, then not safe
            is_safe = not any(categories.values())
        
        return ModerationResult(
            is_safe=is_safe,
            categories=categories,
            original_text=query_input.query
        )
    except ValidationError as e:
        raise ValueError(f"Input validation failed: {str(e)}")
    except Exception as e:
        raise RuntimeError(f"Moderation failed: {str(e)}")

def classify_query(query: str) -> str:
    """Classify user query into known categories using your classification chain."""
    try:
        query_input = QueryInput(query=query)
        
        # Quick pattern-based shortcut for 'Wellness'. The keyword list is
        # currently empty, so this shortcut is effectively disabled; repopulate
        # it (e.g. "box breathing", "meditation", "yoga") to re-enable it.
        wellness_keywords = []
        if any(keyword in query_input.query.lower() for keyword in wellness_keywords):
            return "Wellness"

        # Use the classification chain for everything else
        class_result = classification_chain.invoke({"query": query_input.query})
        print(f"DEBUG: classification result: {class_result}")
        # (older chains returned a dict: class_result.get("text", "").strip())
        classification = class_result

        return classification if classification != "" else "OutOfScope"
    except ValidationError as e:
        raise ValueError(f"Classification input validation failed: {str(e)}")
    except Exception as e:
        raise RuntimeError(f"Classification failed: {str(e)}")

################################################################################
# Vector Store Building/Loading
################################################################################

def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    try:
        if os.path.exists(store_dir):
            print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
            vectorstore = FAISS.load_local(store_dir, embeddings)
            return vectorstore
        else:
            print(f"DEBUG: Building new store from CSV: {csv_path}")
            df = pd.read_csv(csv_path)
            df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
            df.columns = df.columns.str.strip()
            if "Answer" in df.columns:
                df.rename(columns={"Answer": "Answers"}, inplace=True)
            if "Question" not in df.columns and "Question " in df.columns:
                df.rename(columns={"Question ": "Question"}, inplace=True)
            if "Question" not in df.columns or "Answers" not in df.columns:
                raise ValueError("CSV must have 'Question' and 'Answers' columns.")
            docs = []
            for _, row in df.iterrows():
                q = str(row["Question"])
                ans = str(row["Answers"])
                doc = Document(page_content=ans, metadata={"question": q})
                docs.append(doc)
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
            vectorstore = FAISS.from_documents(docs, embedding=embeddings)
            vectorstore.save_local(store_dir)
            return vectorstore
        
    except Exception as e:
        raise RuntimeError(f"Error building/loading vector store: {str(e)}")

def build_rag_chain(vectorstore: FAISS) -> RetrievalQA:
    """Build RAG chain using the Gemini LLM directly without a custom class."""
    try:
        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        chain = RetrievalQA.from_chain_type(
            llm=gemini_llm,  # Directly use the ChatGoogleGenerativeAI instance
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True
        )
        return chain
    except Exception as e:
        raise RuntimeError(f"Error building RAG chain: {str(e)}")

################################################################################
# Web Search Caching: Separate FAISS Vector Store
################################################################################

# Directory for storing cached web search results
web_search_store_dir = "faiss_websearch_store"

def build_or_load_websearch_store(store_dir: str) -> FAISS:
    """
    Builds or loads a FAISS vector store for caching web search results.
    Each Document will have page_content as the search result text,
    and metadata={"question": <user_query>}.
    """
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
    if os.path.exists(store_dir):
        print(f"DEBUG: Found existing WebSearch FAISS store at '{store_dir}'. Loading...")
        return FAISS.load_local(store_dir, embeddings)
    else:
        print(f"DEBUG: Creating a new, empty WebSearch FAISS store at '{store_dir}'...")
        # FAISS.from_texts needs at least one text, so create the store with a
        # placeholder document, then clear the index, docstore, and id mapping
        # so the placeholder can never be retrieved.
        empty_store = FAISS.from_texts([""], embeddings, metadatas=[{"question": "placeholder"}])
        empty_store.index.reset()
        empty_store.docstore._dict = {}
        empty_store.index_to_docstore_id = {}
        empty_store.save_local(store_dir)
        return empty_store

# Initialize the web search vector store
web_search_vectorstore = build_or_load_websearch_store(web_search_store_dir)
websearch_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")

def compute_cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Compute cosine similarity between two embedding vectors."""
    a = np.array(vec_a, dtype=float)
    b = np.array(vec_b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10))
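
# Quick sanity check (illustrative): parallel vectors score ~1.0 and
# orthogonal vectors ~0.0:
#   compute_cosine_similarity([1.0, 0.0], [1.0, 0.0])  # ~1.0
#   compute_cosine_similarity([1.0, 0.0], [0.0, 1.0])  # ~0.0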

def get_cached_websearch(query: str, threshold: float = 0.8) -> Optional[str]:
    """
    Attempts to retrieve a cached web search result for a given query.
    If the top retrieved document has a cosine similarity >= threshold,
    returns that document's page_content. Otherwise, returns None.
    """
    # Retrieve the top doc from the store
    retriever = web_search_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 1})
    results = retriever.get_relevant_documents(query)
    if not results:
        return None
    
    # Compare similarity with the top doc
    top_doc = results[0]
    query_vec = websearch_embeddings.embed_query(query)
    doc_vec = websearch_embeddings.embed_query(top_doc.page_content)
    similarity = compute_cosine_similarity(query_vec, doc_vec)
    
    if similarity >= threshold:
        print(f"DEBUG: Using cached web search (similarity={similarity:.2f} >= {threshold})")
        return top_doc.page_content
    
    print(f"DEBUG: Cached doc similarity={similarity:.2f} < {threshold}, not reusing.")
    return None

def store_websearch_result(query: str, web_search_text: str):
    """
    Embeds and stores the web search result text in the web search vector store,
    keyed by the question in metadata. Then saves the store locally.
    """
    if not web_search_text.strip():
        return  # Don't store empty results
    doc = Document(page_content=web_search_text, metadata={"question": query})
    web_search_vectorstore.add_documents([doc])  # the store embeds with its own embedding fn
    web_search_vectorstore.save_local(web_search_store_dir)

def do_cached_web_search(query: str) -> str:
    """Perform a DuckDuckGo web search, but with caching via FAISS vector store."""
    # 1) Check cache
    cached_result = get_cached_websearch(query)
    if cached_result:
        return cached_result
    
    # 2) If no suitable cached answer, do a new search
    try:
        print("DEBUG: Performing a new web search...")
        # model = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
        model = HfApiModel()
        search_tool = DuckDuckGoSearchTool()
        web_agent = CodeAgent(
            tools=[search_tool],
            model=model
        )

        managed_web_agent = ManagedAgent(
            agent=web_agent,
            name="web_search",
            description="Runs a web search for you. Provide your query as an argument."
        )

        manager_agent = CodeAgent(
            tools=[],  # If you have additional tools for the manager, add them here
            model=model,
            managed_agents=[managed_web_agent]
        )

        # Agent output may not be a plain string; normalize before caching.
        new_search_result = str(manager_agent.run(f"Search for information about: {query}")).strip()

        # 3) Store in cache for future reuse
        store_websearch_result(query, new_search_result)
        return new_search_result
    except Exception as e:
        print(f"Web search failed: {e}")
        return ""

################################################################################
# Response Merging
################################################################################

def merge_responses(csv_answer: str, web_answer: str) -> str:
    """Merge CSV-based RAG result with web search results."""
    try:
        if not csv_answer and not web_answer:
            return "I apologize, but I couldn't find any relevant information."
        
        if not web_answer:
            return csv_answer
            
        if not csv_answer:
            return web_answer
            
        return f"{csv_answer}\n\nAdditional information from web search:\n{web_answer}"
    except Exception as e:
        print(f"Error merging responses: {e}")
        return csv_answer or web_answer or "I apologize, but I couldn't process the information properly."

################################################################################
# Main Pipeline
################################################################################

def run_pipeline(query: str) -> str:
    """
    Pipeline logic to:
      1) Sanitize & moderate the query
      2) Classify the query (OutOfScope, Wellness, Brand, etc.)
      3) If safe & in scope, do RAG + ALWAYS do a cached web search
      4) Merge responses and tailor final output
    """
    try:
        print(f"DEBUG: incoming query: {query}")
        sanitized_query = sanitize_message(query)
        query_input = QueryInput(query=sanitized_query)
        
        topic = extract_main_topic(query_input.query)
        moderation_result = moderate_text(query_input.query)
        
        # Check for unsafe content
        if not moderation_result.is_safe:
            return "Sorry, this query contains harmful or inappropriate content."
        
        # Classify
        classification = classify_query(moderation_result.original_text)
        
        # If out-of-scope, refuse
        if classification == "OutOfScope":
            refusal_text = refusal_chain.invoke({"topic": topic, "query": query})
            return tailor_chain.run({"response": refusal_text}).strip()
        
        # Otherwise, do a RAG query and also do a web search (cached)
        if classification == "Wellness":
            # RAG from wellness store
            rag_result = wellness_rag_chain({"query": moderation_result.original_text})
            csv_answer = rag_result.get("result", "").strip() if isinstance(rag_result, dict) else str(rag_result).strip()

            # Always do a (cached) web search
            web_answer = do_cached_web_search(moderation_result.original_text)

            # Merge CSV & Web
            final_merged = merge_responses(csv_answer, web_answer)
            return tailor_chain.run({"response": final_merged}).strip()

        if classification == "Brand":
            # RAG from brand store
            rag_result = brand_rag_chain({"query": moderation_result.original_text})
            csv_answer = rag_result.get("result", "").strip() if isinstance(rag_result, dict) else str(rag_result).strip()

            # Always do a (cached) web search
            web_answer = do_cached_web_search(moderation_result.original_text)

            # Merge CSV & Web
            final_merged = merge_responses(csv_answer, web_answer)
            return tailor_chain.run({"response": final_merged}).strip()
        
        # If it doesn't fall under known categories, return refusal by default.
        refusal_text = refusal_chain.invoke({"topic": topic, "query": query})
        return tailor_chain.run({"response": refusal_text}).strip()
        
    except ValidationError as e:
        raise ValueError(f"Input validation failed: {str(e)}")
    except Exception as e:
        raise RuntimeError(f"Error in run_pipeline: {str(e)}")

def run_with_chain(query: str) -> str:
    """Convenience function to run the main pipeline and handle errors gracefully."""
    try:
        return run_pipeline(query)
    except Exception as e:
        print(f"Error in run_with_chain: {str(e)}")
        return "I apologize, but I encountered an error processing your request. Please try again."

################################################################################
# Chain & Vectorstore Initialization
################################################################################

# Load your classification/refusal/tailor/cleaner chains
classification_chain = get_classification_chain()
refusal_chain = get_refusal_chain()
tailor_chain = get_tailor_chain()
cleaner_chain = get_cleaner_chain()

# CSV file paths and store directories for RAG
wellness_csv = "AIChatbot.csv"
brand_csv = "BrandAI.csv"
wellness_store_dir = "faiss_wellness_store"
brand_store_dir = "faiss_brand_store"

# Build or load the vector stores
wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)

# Build RAG chains
wellness_rag_chain = build_rag_chain(wellness_vectorstore)
brand_rag_chain = build_rag_chain(brand_vectorstore)

print("Pipeline initialized successfully! Ready to handle querie with caching.")