Spaces:
Sleeping
Sleeping
Update pipeline.py
Browse files- pipeline.py +26 -25
pipeline.py
CHANGED
|
@@ -4,14 +4,13 @@ import spacy
|
|
| 4 |
import pandas as pd
|
| 5 |
from typing import Optional
|
| 6 |
import subprocess
|
| 7 |
-
import asyncio # Needed for managing async tasks
|
| 8 |
from langchain.llms.base import LLM
|
| 9 |
from langchain.docstore.document import Document
|
| 10 |
from langchain.embeddings import HuggingFaceEmbeddings
|
| 11 |
from langchain.vectorstores import FAISS
|
| 12 |
from langchain.chains import RetrievalQA
|
| 13 |
from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
|
| 14 |
-
from
|
| 15 |
from mistralai import Mistral
|
| 16 |
from langchain.prompts import PromptTemplate
|
| 17 |
|
|
@@ -26,9 +25,6 @@ from prompts import classification_prompt, refusal_prompt, tailor_prompt
|
|
| 26 |
mistral_api_key = os.environ.get("MISTRAL_API_KEY")
|
| 27 |
client = Mistral(api_key=mistral_api_key)
|
| 28 |
|
| 29 |
-
# Initialize Pydantic AI Agent (for text validation)
|
| 30 |
-
pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
|
| 31 |
-
|
| 32 |
# Load spaCy model for NER and download it if not already installed
|
| 33 |
def install_spacy_model():
|
| 34 |
try:
|
|
@@ -67,19 +63,31 @@ def classify_query(query: str) -> str:
|
|
| 67 |
classification = class_result.get("text", "").strip()
|
| 68 |
return classification if classification != "OutOfScope" else "OutOfScope"
|
| 69 |
|
| 70 |
-
#
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
try:
|
| 73 |
-
#
|
| 74 |
-
|
| 75 |
-
|
|
|
|
| 76 |
print(f"Error validating text: {e}")
|
| 77 |
return "Invalid text format."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
# Call the Mistral moderation API
|
| 80 |
response = client.classifiers.moderate_chat(
|
| 81 |
model="mistral-moderation-latest",
|
| 82 |
-
inputs=[{"role": "user", "content":
|
| 83 |
)
|
| 84 |
|
| 85 |
# Assuming the response is an object of type 'ClassificationResponse',
|
|
@@ -93,7 +101,7 @@ def moderate_text(query: str) -> str:
|
|
| 93 |
categories.get("selfharm", False):
|
| 94 |
return "OutOfScope"
|
| 95 |
|
| 96 |
-
return
|
| 97 |
|
| 98 |
|
| 99 |
# Function to build or load the vector store from CSV data
|
|
@@ -147,7 +155,7 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
|
|
| 147 |
return rag_chain
|
| 148 |
|
| 149 |
# Function to perform web search using DuckDuckGo
|
| 150 |
-
|
| 151 |
search_tool = DuckDuckGoSearchTool()
|
| 152 |
web_agent = CodeAgent(tools=[search_tool], model=pydantic_agent)
|
| 153 |
managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
|
|
@@ -158,13 +166,13 @@ async def do_web_search(query: str) -> str:
|
|
| 158 |
return response
|
| 159 |
|
| 160 |
# Function to combine web and knowledge base responses
|
| 161 |
-
|
| 162 |
# Merge both answers with a cohesive response
|
| 163 |
final_answer = f"Knowledge Base Answer: {kb_answer}\n\nWeb Search Result: {web_answer}"
|
| 164 |
return final_answer.strip()
|
| 165 |
|
| 166 |
# Orchestrate the entire workflow
|
| 167 |
-
|
| 168 |
# Moderate the query for harmful content (sync)
|
| 169 |
moderated_query = moderate_text(query)
|
| 170 |
if moderated_query == "OutOfScope":
|
|
@@ -183,15 +191,15 @@ async def run_async_pipeline(query: str) -> str:
|
|
| 183 |
csv_answer = rag_result["result"].strip()
|
| 184 |
web_answer = "" # Empty if we found an answer from the knowledge base
|
| 185 |
if not csv_answer:
|
| 186 |
-
web_answer =
|
| 187 |
-
final_merged =
|
| 188 |
final_answer = tailor_chain.run({"response": final_merged})
|
| 189 |
return final_answer.strip()
|
| 190 |
|
| 191 |
if classification == "Brand":
|
| 192 |
rag_result = brand_rag_chain({"query": moderated_query})
|
| 193 |
csv_answer = rag_result["result"].strip()
|
| 194 |
-
final_merged =
|
| 195 |
final_answer = tailor_chain.run({"response": final_merged})
|
| 196 |
return final_answer.strip()
|
| 197 |
|
|
@@ -199,13 +207,6 @@ async def run_async_pipeline(query: str) -> str:
|
|
| 199 |
final_refusal = tailor_chain.run({"response": refusal_text})
|
| 200 |
return final_refusal.strip()
|
| 201 |
|
| 202 |
-
# Run the pipeline with the event loop
|
| 203 |
-
import asyncio
|
| 204 |
-
|
| 205 |
-
def run_with_chain(query: str) -> str:
|
| 206 |
-
# Use asyncio.run to run the async pipeline, which ensures a fresh event loop
|
| 207 |
-
return asyncio.run(run_async_pipeline(query))
|
| 208 |
-
|
| 209 |
# Initialize chains here
|
| 210 |
classification_chain = get_classification_chain()
|
| 211 |
refusal_chain = get_refusal_chain()
|
|
|
|
| 4 |
import pandas as pd
|
| 5 |
from typing import Optional
|
| 6 |
import subprocess
|
|
|
|
| 7 |
from langchain.llms.base import LLM
|
| 8 |
from langchain.docstore.document import Document
|
| 9 |
from langchain.embeddings import HuggingFaceEmbeddings
|
| 10 |
from langchain.vectorstores import FAISS
|
| 11 |
from langchain.chains import RetrievalQA
|
| 12 |
from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
|
| 13 |
+
from pydantic import BaseModel, ValidationError # Import Pydantic for text validation
|
| 14 |
from mistralai import Mistral
|
| 15 |
from langchain.prompts import PromptTemplate
|
| 16 |
|
|
|
|
| 25 |
mistral_api_key = os.environ.get("MISTRAL_API_KEY")
|
| 26 |
client = Mistral(api_key=mistral_api_key)
|
| 27 |
|
|
|
|
|
|
|
|
|
|
| 28 |
# Load spaCy model for NER and download it if not already installed
|
| 29 |
def install_spacy_model():
|
| 30 |
try:
|
|
|
|
| 63 |
classification = class_result.get("text", "").strip()
|
| 64 |
return classification if classification != "OutOfScope" else "OutOfScope"
|
| 65 |
|
| 66 |
+
# Pydantic model for text validation
|
| 67 |
+
class TextInputModel(BaseModel):
|
| 68 |
+
text: str
|
| 69 |
+
|
| 70 |
+
# Function to validate the text input using Pydantic
|
| 71 |
+
def validate_text(query: str) -> str:
|
| 72 |
try:
|
| 73 |
+
# Attempt to validate the query as a text input
|
| 74 |
+
TextInputModel(text=query)
|
| 75 |
+
return query
|
| 76 |
+
except ValidationError as e:
|
| 77 |
print(f"Error validating text: {e}")
|
| 78 |
return "Invalid text format."
|
| 79 |
+
|
| 80 |
+
# Function to moderate text using Mistral moderation API (synchronous version)
|
| 81 |
+
def moderate_text(query: str) -> str:
|
| 82 |
+
# Validate the text using Pydantic
|
| 83 |
+
validated_text = validate_text(query)
|
| 84 |
+
if validated_text == "Invalid text format.":
|
| 85 |
+
return validated_text
|
| 86 |
|
| 87 |
# Call the Mistral moderation API
|
| 88 |
response = client.classifiers.moderate_chat(
|
| 89 |
model="mistral-moderation-latest",
|
| 90 |
+
inputs=[{"role": "user", "content": validated_text}]
|
| 91 |
)
|
| 92 |
|
| 93 |
# Assuming the response is an object of type 'ClassificationResponse',
|
|
|
|
| 101 |
categories.get("selfharm", False):
|
| 102 |
return "OutOfScope"
|
| 103 |
|
| 104 |
+
return validated_text
|
| 105 |
|
| 106 |
|
| 107 |
# Function to build or load the vector store from CSV data
|
|
|
|
| 155 |
return rag_chain
|
| 156 |
|
| 157 |
# Function to perform web search using DuckDuckGo
|
| 158 |
+
def do_web_search(query: str) -> str:
|
| 159 |
search_tool = DuckDuckGoSearchTool()
|
| 160 |
web_agent = CodeAgent(tools=[search_tool], model=pydantic_agent)
|
| 161 |
managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
|
|
|
|
| 166 |
return response
|
| 167 |
|
| 168 |
# Function to combine web and knowledge base responses
|
| 169 |
+
def merge_responses(kb_answer: str, web_answer: str) -> str:
|
| 170 |
# Merge both answers with a cohesive response
|
| 171 |
final_answer = f"Knowledge Base Answer: {kb_answer}\n\nWeb Search Result: {web_answer}"
|
| 172 |
return final_answer.strip()
|
| 173 |
|
| 174 |
# Orchestrate the entire workflow
|
| 175 |
+
def run_pipeline(query: str) -> str:
|
| 176 |
# Moderate the query for harmful content (sync)
|
| 177 |
moderated_query = moderate_text(query)
|
| 178 |
if moderated_query == "OutOfScope":
|
|
|
|
| 191 |
csv_answer = rag_result["result"].strip()
|
| 192 |
web_answer = "" # Empty if we found an answer from the knowledge base
|
| 193 |
if not csv_answer:
|
| 194 |
+
web_answer = do_web_search(moderated_query)
|
| 195 |
+
final_merged = merge_responses(csv_answer, web_answer)
|
| 196 |
final_answer = tailor_chain.run({"response": final_merged})
|
| 197 |
return final_answer.strip()
|
| 198 |
|
| 199 |
if classification == "Brand":
|
| 200 |
rag_result = brand_rag_chain({"query": moderated_query})
|
| 201 |
csv_answer = rag_result["result"].strip()
|
| 202 |
+
final_merged = merge_responses(csv_answer, "")
|
| 203 |
final_answer = tailor_chain.run({"response": final_merged})
|
| 204 |
return final_answer.strip()
|
| 205 |
|
|
|
|
| 207 |
final_refusal = tailor_chain.run({"response": refusal_text})
|
| 208 |
return final_refusal.strip()
|
| 209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
# Initialize chains here
|
| 211 |
classification_chain = get_classification_chain()
|
| 212 |
refusal_chain = get_refusal_chain()
|