Update pipeline.py
pipeline.py  CHANGED  (+28 -24)
@@ -10,7 +10,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel, ValidationError, validator
 from mistralai import Mistral
 from langchain.prompts import PromptTemplate

@@ -25,6 +25,9 @@ from prompts import classification_prompt, refusal_prompt, tailor_prompt
 mistral_api_key = os.environ.get("MISTRAL_API_KEY")
 client = Mistral(api_key=mistral_api_key)

+# Initialize Pydantic AI Agent (for text validation)
+pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
+
 # Load spaCy model for NER and download it if not already installed
 def install_spacy_model():
     try:
@@ -53,6 +56,17 @@ def extract_main_topic(query: str) -> str:
             break
     return main_topic if main_topic else "this topic"

+# Pydantic model to handle string input validation
+class QueryInput(BaseModel):
+    query: str
+
+    # Validator to ensure the query is always a string
+    @validator('query')
+    def check_query_is_string(cls, v):
+        if not isinstance(v, str):
+            raise ValueError("Query must be a valid string.")
+        return v
+
 # Function to classify query based on wellness topics
 def classify_query(query: str) -> str:
     wellness_keywords = ["box breathing", "meditation", "yoga", "mindfulness", "breathing exercises"]
@@ -63,45 +77,31 @@ def classify_query(query: str) -> str:
     classification = class_result.get("text", "").strip()
     return classification if classification != "OutOfScope" else "OutOfScope"

-#
-
-    text: str
-
-# Function to validate the text input using Pydantic
-def validate_text(query: str) -> str:
+# Function to moderate text using Mistral moderation API (sync version)
+def moderate_text(query: str) -> str:
     try:
-        #
-
-        return query
+        # Use Pydantic to validate text input
+        query_input = QueryInput(query=query)  # This will validate that the query is a string
     except ValidationError as e:
         print(f"Error validating text: {e}")
         return "Invalid text format."
-
-# Function to moderate text using Mistral moderation API (synchronous version)
-def moderate_text(query: str) -> str:
-    # Validate the text using Pydantic
-    validated_text = validate_text(query)
-    if validated_text == "Invalid text format.":
-        return validated_text

     # Call the Mistral moderation API
     response = client.classifiers.moderate_chat(
         model="mistral-moderation-latest",
-        inputs=[{"role": "user", "content":
+        inputs=[{"role": "user", "content": query}]
     )

-    #
-    # check if it has a 'results' attribute, and then access its categories
+    # Check if harmful categories are present in the response
     if hasattr(response, 'results') and response.results:
         categories = response.results[0].categories
-        # Check if harmful categories are present
         if categories.get("violence_and_threats", False) or \
            categories.get("hate_and_discrimination", False) or \
            categories.get("dangerous_and_criminal_content", False) or \
            categories.get("selfharm", False):
             return "OutOfScope"

-    return
+    return query


 # Function to build or load the vector store from CSV data
@@ -173,7 +173,7 @@ def merge_responses(kb_answer: str, web_answer: str) -> str:

 # Orchestrate the entire workflow
 def run_pipeline(query: str) -> str:
-    # Moderate the query for harmful content
+    # Moderate the query for harmful content
     moderated_query = moderate_text(query)
     if moderated_query == "OutOfScope":
         return "Sorry, this query contains harmful or inappropriate content."
@@ -207,7 +207,7 @@ def run_pipeline(query: str) -> str:
     final_refusal = tailor_chain.run({"response": refusal_text})
     return final_refusal.strip()

-# Initialize chains
+# Initialize chains
 classification_chain = get_classification_chain()
 refusal_chain = get_refusal_chain()
 tailor_chain = get_tailor_chain()
@@ -224,3 +224,7 @@ brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
 gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
 wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
 brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
+
+# Function to wrap up and run the chain
+def run_with_chain(query: str) -> str:
+    return run_pipeline(query)
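The input-validation pattern this commit introduces can be exercised on its own. The sketch below is a minimal example, not part of the commit; it assumes pydantic v1-style validators (as imported in the diff) and mirrors how the reworked moderate_text validates before doing anything else:

from pydantic import BaseModel, ValidationError, validator

# Mirrors the QueryInput model added in this commit
class QueryInput(BaseModel):
    query: str

    # Reject anything that is not a plain string
    @validator('query')
    def check_query_is_string(cls, v):
        if not isinstance(v, str):
            raise ValueError("Query must be a valid string.")
        return v

def validate_or_reject(query) -> str:
    # Same shape as the new moderate_text: validate first, bail out on error
    try:
        QueryInput(query=query)
        return query
    except ValidationError as e:
        print(f"Error validating text: {e}")
        return "Invalid text format."

print(validate_or_reject("Can you explain box breathing?"))  # returns the query unchanged
print(validate_or_reject(None))                              # returns "Invalid text format."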
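After these changes, run_with_chain is the module's single entry point and simply forwards to run_pipeline. A hypothetical caller would use it roughly as below; the Space's UI layer is not shown in this diff, so the wiring here is an assumption, not part of the commit:

# Hypothetical usage sketch; nothing here is part of the commit itself.
# Assumes pipeline.py is importable and the required API keys
# (MISTRAL_API_KEY, GEMINI_API_KEY) are set in the environment.
from pipeline import run_with_chain

if __name__ == "__main__":
    answer = run_with_chain("Can you walk me through a short box breathing exercise?")
    print(answer)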