import os
import getpass
import spacy
import pandas as pd
import numpy as np
from typing import Optional, List, Dict, Any
import subprocess

from langchain.llms.base import LLM
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA

from smolagents import DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel, CodeAgent, HfApiModel
from pydantic import BaseModel, Field, ValidationError, validator
from mistralai import Mistral

# Import Google Gemini model
from langchain_google_genai import ChatGoogleGenerativeAI

from classification_chain import get_classification_chain
from cleaner_chain import get_cleaner_chain
from refusal_chain import get_refusal_chain
from tailor_chain import get_tailor_chain
from prompts import classification_prompt, refusal_prompt, tailor_prompt
LANGSMITH_TRACING = True
LANGSMITH_ENDPOINT = "https://api.smith.langchain.com"
LANGSMITH_API_KEY = os.environ.get("LANGSMITH_API_KEY")
LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")

# Initialize Mistral API client
mistral_api_key = os.environ.get("MISTRAL_API_KEY")
client = Mistral(api_key=mistral_api_key)

# Setup ChatGoogleGenerativeAI for Gemini
# Ensure GEMINI_API_KEY is set in your environment variables.
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",
    temperature=0.5,
    max_retries=2,
    google_api_key=os.environ.get("GEMINI_API_KEY"),
    # Additional parameters or safety_settings can be added here if needed
)

# web_gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))

################################################################################
# Pydantic Models
################################################################################
class QueryInput(BaseModel):
    query: str = Field(..., min_length=1, description="The input query string")

    @validator("query")
    def check_query_is_string(cls, v):
        if not isinstance(v, str):
            raise ValueError("Query must be a valid string")
        if v.strip() == "":
            raise ValueError("Query cannot be empty or just whitespace")
        return v.strip()
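
# Illustrative behavior (hypothetical values): QueryInput(query="  hello  ").query becomes
# "hello" after the validator strips whitespace, while QueryInput(query="   ") raises a
# ValidationError because the stripped query is empty.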

class ModerationResult(BaseModel):
    is_safe: bool = Field(..., description="Whether the content is safe")
    categories: Dict[str, bool] = Field(default_factory=dict, description="Detected content categories")
    original_text: str = Field(..., description="The original input text")
################################################################################
# spaCy Setup
################################################################################
def install_spacy_model():
    try:
        spacy.load("en_core_web_sm")
        print("spaCy model 'en_core_web_sm' is already installed.")
    except OSError:
        print("Downloading spaCy model 'en_core_web_sm'...")
        subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
        print("spaCy model 'en_core_web_sm' downloaded successfully.")

install_spacy_model()
nlp = spacy.load("en_core_web_sm")

################################################################################
# Utility Functions
################################################################################
def sanitize_message(message: Any) -> str:
    """Sanitize message input to ensure it's a valid string."""
    try:
        if hasattr(message, 'content'):
            return str(message.content).strip()
        if isinstance(message, dict) and 'content' in message:
            return str(message['content']).strip()
        if isinstance(message, list) and len(message) > 0:
            if isinstance(message[0], dict) and 'content' in message[0]:
                return str(message[0]['content']).strip()
            if hasattr(message[0], 'content'):
                return str(message[0].content).strip()
        return str(message).strip()
    except Exception as e:
        raise RuntimeError(f"Error in sanitize function: {str(e)}")

def extract_main_topic(query: str) -> str:
    """Extracts a main topic (named entity or noun) from the user query."""
    try:
        query_input = QueryInput(query=query)
        doc = nlp(query_input.query)
        main_topic = None
        # Attempt to find a named entity first
        for ent in doc.ents:
            if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:
                main_topic = ent.text
                break
        # If no named entity, fall back to nouns or proper nouns
        if not main_topic:
            for token in doc:
                if token.pos_ in ["NOUN", "PROPN"]:
                    main_topic = token.text
                    break
        return main_topic if main_topic else "this topic"
    except Exception as e:
        print(f"Error extracting main topic: {e}")
        return "this topic"

def moderate_text(query: str) -> ModerationResult:
    """Uses Mistral's moderation to determine if the content is safe."""
    try:
        query_input = QueryInput(query=query)
        response = client.classifiers.moderate_chat(
            model="mistral-moderation-latest",
            inputs=[{"role": "user", "content": query_input.query}]
        )
        is_safe = True
        categories = {}
        if hasattr(response, 'results') and response.results:
            categories = {
                "violence": response.results[0].categories.get("violence_and_threats", False),
                "hate": response.results[0].categories.get("hate_and_discrimination", False),
                "dangerous": response.results[0].categories.get("dangerous_and_criminal_content", False),
                "selfharm": response.results[0].categories.get("selfharm", False)
            }
            # If any flagged category is True, the content is not safe
            is_safe = not any(categories.values())
        return ModerationResult(
            is_safe=is_safe,
            categories=categories,
            original_text=query_input.query
        )
    except ValidationError as e:
        raise ValueError(f"Input validation failed: {str(e)}")
    except Exception as e:
        raise RuntimeError(f"Moderation failed: {str(e)}")

def classify_query(query: str) -> str:
    """Classify user query into known categories using your classification chain."""
    try:
        query_input = QueryInput(query=query)
        # Quick pattern-based approach for 'Wellness'
        wellness_keywords = ["box breathing", "meditation", "yoga", "mindfulness", "breathing exercises"]
        if any(keyword in query_input.query.lower() for keyword in wellness_keywords):
            return "Wellness"
        # Use the classification chain for everything else
        class_result = classification_chain.invoke({"query": query_input.query})
        classification = class_result.get("text", "").strip()
        return classification if classification != "" else "OutOfScope"
    except ValidationError as e:
        raise ValueError(f"Classification input validation failed: {str(e)}")
    except Exception as e:
        raise RuntimeError(f"Classification failed: {str(e)}")

################################################################################
# Vector Store Building/Loading
################################################################################
def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    """
    Builds or loads a FAISS vector store for CSV documents containing 'Question' and 'Answers'.
    """
    try:
        if os.path.exists(store_dir):
            print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
            vectorstore = FAISS.load_local(store_dir, embeddings)
            return vectorstore
        else:
            print(f"DEBUG: Building new store from CSV: {csv_path}")
            df = pd.read_csv(csv_path)
            df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
            df.columns = df.columns.str.strip()
            # Fix possible column name variations
            if "Answer" in df.columns:
                df.rename(columns={"Answer": "Answers"}, inplace=True)
            if "Question" not in df.columns and "Question " in df.columns:
                df.rename(columns={"Question ": "Question"}, inplace=True)
            if "Question" not in df.columns or "Answers" not in df.columns:
                raise ValueError("CSV must have 'Question' and 'Answers' columns.")
            docs = []
            for _, row in df.iterrows():
                q = str(row["Question"])
                ans = str(row["Answers"])
                doc = Document(page_content=ans, metadata={"question": q})
                docs.append(doc)
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
            vectorstore = FAISS.from_documents(docs, embedding=embeddings)
            vectorstore.save_local(store_dir)
            return vectorstore
    except Exception as e:
        raise RuntimeError(f"Error building/loading vector store: {str(e)}")

def build_rag_chain(vectorstore: FAISS) -> RetrievalQA:
    """Build RAG chain using the Gemini LLM."""
    try:
        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        chain = RetrievalQA.from_chain_type(
            llm=gemini_llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True
        )
        return chain
    except Exception as e:
        raise RuntimeError(f"Error building RAG chain: {str(e)}")

################################################################################
# Web Search Caching: Separate FAISS Vector Store
################################################################################
# Directory for storing cached web search results
web_search_store_dir = "faiss_websearch_store"

def build_or_load_websearch_store(store_dir: str) -> FAISS:
    """
    Builds or loads a FAISS vector store for caching web search results.
    Each Document will have page_content as the search result text,
    and metadata={"question": <user_query>}.
    """
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
    if os.path.exists(store_dir):
        print(f"DEBUG: Found existing WebSearch FAISS store at '{store_dir}'. Loading...")
        return FAISS.load_local(store_dir, embeddings)
    else:
        print(f"DEBUG: Creating a new, empty WebSearch FAISS store at '{store_dir}'...")
        # Start with a placeholder document, then empty the store
        empty_store = FAISS.from_texts([""], embeddings, metadatas=[{"question": "placeholder"}])
        # Remove the placeholder doc so we don't retrieve it
        empty_store.index.reset()
        empty_store.docstore._dict = {}
        empty_store.index_to_docstore_id = {}  # keep the index-to-id mapping aligned with the emptied index
        empty_store.save_local(store_dir)
        return empty_store

# Initialize the web search vector store
web_search_vectorstore = build_or_load_websearch_store(web_search_store_dir)
websearch_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")

def compute_cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Compute cosine similarity between two embedding vectors."""
    a = np.array(vec_a, dtype=float)
    b = np.array(vec_b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10))
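
# Quick sanity check (illustrative values only): identical vectors score ~1.0 and
# orthogonal vectors ~0.0, e.g.
#   compute_cosine_similarity([1.0, 0.0], [1.0, 0.0])  # ~1.0
#   compute_cosine_similarity([1.0, 0.0], [0.0, 1.0])  # ~0.0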

def get_cached_websearch(query: str, threshold: float = 0.8) -> Optional[str]:
    """
    Attempts to retrieve a cached web search result for a given query.
    If the top retrieved document has a cosine similarity >= threshold,
    returns that document's page_content. Otherwise, returns None.
    """
    # Retrieve the top doc from the store
    retriever = web_search_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 1})
    results = retriever.get_relevant_documents(query)
    if not results:
        return None
    # Compare similarity with the top doc
    top_doc = results[0]
    query_vec = websearch_embeddings.embed_query(query)
    doc_vec = websearch_embeddings.embed_query(top_doc.page_content)
    similarity = compute_cosine_similarity(query_vec, doc_vec)
    if similarity >= threshold:
        print(f"DEBUG: Using cached web search (similarity={similarity:.2f} >= {threshold})")
        return top_doc.page_content
    print(f"DEBUG: Cached doc similarity={similarity:.2f} < {threshold}, not reusing.")
    return None

def store_websearch_result(query: str, web_search_text: str):
    """
    Embeds and stores the web search result text in the web search vector store,
    keyed by the question in metadata. Then saves the store locally.
    """
    if not web_search_text.strip():
        return  # Don't store empty results
    doc = Document(page_content=web_search_text, metadata={"question": query})
    web_search_vectorstore.add_documents([doc])  # the store's own embedding function is used
    web_search_vectorstore.save_local(web_search_store_dir)

def do_cached_web_search(query: str) -> str:
    """Perform a DuckDuckGo web search, but with caching via a FAISS vector store."""
    # 1) Check the cache
    cached_result = get_cached_websearch(query)
    if cached_result:
        return cached_result
    # 2) If no suitable cached answer, do a new search
    try:
        print("DEBUG: Performing a new web search...")
        # model = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
        model = HfApiModel()
        search_tool = DuckDuckGoSearchTool()
        web_agent = CodeAgent(
            tools=[search_tool],
            model=model
        )
        managed_web_agent = ManagedAgent(
            agent=web_agent,
            name="web_search",
            description="Runs a web search for you. Provide your query as an argument."
        )
        manager_agent = CodeAgent(
            tools=[],  # If you have additional tools for the manager, add them here
            model=model,
            managed_agents=[managed_web_agent]
        )
        new_search_result = manager_agent.run(f"Search for information about: {query}")
        # 3) Store in the cache for future reuse (cast to str in case the agent returns a rich object)
        store_websearch_result(query, str(new_search_result))
        return str(new_search_result).strip()
    except Exception as e:
        print(f"Web search failed: {e}")
        return ""

################################################################################
# Response Merging
################################################################################
def merge_responses(csv_answer: str, web_answer: str) -> str:
    """Merge CSV-based RAG result with web search results."""
    try:
        if not csv_answer and not web_answer:
            return "I apologize, but I couldn't find any relevant information."
        if not web_answer:
            return csv_answer
        if not csv_answer:
            return web_answer
        return f"{csv_answer}\n\nAdditional information from web search:\n{web_answer}"
    except Exception as e:
        print(f"Error merging responses: {e}")
        return csv_answer or web_answer or "I apologize, but I couldn't process the information properly."
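
# Illustrative merge (hypothetical inputs): merge_responses("Box breathing uses a 4-4-4-4 count.",
# "Recent articles suggest practicing twice a day.") returns the CSV answer followed by
# "Additional information from web search:" and the web text; if either side is empty,
# the other is returned unchanged.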

################################################################################
# Main Pipeline
################################################################################
def run_pipeline(query: str) -> str:
    """
    Pipeline logic to:
      1) Sanitize & moderate the query
      2) Classify the query (OutOfScope, Wellness, Brand, etc.)
      3) If safe & in scope, do RAG + ALWAYS do a cached web search
      4) Merge responses and tailor the final output
    """
    try:
        print(query)
        sanitized_query = sanitize_message(query)
        query_input = QueryInput(query=sanitized_query)
        topic = extract_main_topic(query_input.query)
        moderation_result = moderate_text(query_input.query)
        # Check for unsafe content
        if not moderation_result.is_safe:
            return "Sorry, this query contains harmful or inappropriate content."
        # Classify
        classification = classify_query(moderation_result.original_text)
        # If out-of-scope, refuse
        if classification == "OutOfScope":
            refusal_text = refusal_chain.run({"topic": topic})
            return tailor_chain.run({"response": refusal_text}).strip()
        # Otherwise, do a RAG query and also do a (cached) web search
        if classification == "Wellness":
            # RAG from the wellness store
            rag_result = wellness_rag_chain({"query": moderation_result.original_text})
            csv_answer = rag_result.get("result", "").strip() if isinstance(rag_result, dict) else str(rag_result).strip()
            # Always do a (cached) web search
            web_answer = do_cached_web_search(moderation_result.original_text)
            # Merge CSV & Web
            final_merged = merge_responses(csv_answer, web_answer)
            return tailor_chain.run({"response": final_merged}).strip()
        if classification == "Brand":
            # RAG from the brand store
            rag_result = brand_rag_chain({"query": moderation_result.original_text})
            csv_answer = rag_result.get("result", "").strip() if isinstance(rag_result, dict) else str(rag_result).strip()
            # Always do a (cached) web search
            web_answer = do_cached_web_search(moderation_result.original_text)
            # Merge CSV & Web
            final_merged = merge_responses(csv_answer, web_answer)
            return tailor_chain.run({"response": final_merged}).strip()
        # If it doesn't fall under known categories, return a refusal by default.
        refusal_text = refusal_chain.run({"topic": topic})
        return tailor_chain.run({"response": refusal_text}).strip()
    except ValidationError as e:
        raise ValueError(f"Input validation failed: {str(e)}")
    except Exception as e:
        raise RuntimeError(f"Error in run_pipeline: {str(e)}")

def run_with_chain(query: str) -> str:
    """Convenience function to run the main pipeline and handle errors gracefully."""
    try:
        return run_pipeline(query)
    except Exception as e:
        print(f"Error in run_with_chain: {str(e)}")
        return "I apologize, but I encountered an error processing your request. Please try again."

################################################################################
# Chain & Vectorstore Initialization
################################################################################
# Load your classification/refusal/tailor/cleaner chains
classification_chain = get_classification_chain()
refusal_chain = get_refusal_chain()
tailor_chain = get_tailor_chain()
cleaner_chain = get_cleaner_chain()

# CSV file paths and store directories for RAG
wellness_csv = "AIChatbot.csv"
brand_csv = "BrandAI.csv"
wellness_store_dir = "faiss_wellness_store"
brand_store_dir = "faiss_brand_store"

# Build or load the vector stores
wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)

# Build RAG chains
wellness_rag_chain = build_rag_chain(wellness_vectorstore)
brand_rag_chain = build_rag_chain(brand_vectorstore)

print("Pipeline initialized successfully! Ready to handle queries with caching.")