Phoenix21 committed on
Commit 756269e · verified · 1 Parent(s): 263ad5f

Update pipeline.py

Files changed (1)
  1. pipeline.py +28 -24
pipeline.py CHANGED
@@ -10,7 +10,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
  from langchain.vectorstores import FAISS
  from langchain.chains import RetrievalQA
  from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
- from pydantic import BaseModel, ValidationError  # Import Pydantic for text validation
+ from pydantic import BaseModel, ValidationError, validator
  from mistralai import Mistral
  from langchain.prompts import PromptTemplate
@@ -25,6 +25,9 @@ from prompts import classification_prompt, refusal_prompt, tailor_prompt
  mistral_api_key = os.environ.get("MISTRAL_API_KEY")
  client = Mistral(api_key=mistral_api_key)

+ # Initialize Pydantic AI Agent (for text validation)
+ pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
+
  # Load spaCy model for NER and download it if not already installed
  def install_spacy_model():
      try:
@@ -53,6 +56,17 @@ def extract_main_topic(query: str) -> str:
              break
      return main_topic if main_topic else "this topic"

+ # Pydantic model to handle string input validation
+ class QueryInput(BaseModel):
+     query: str
+
+     # Validator to ensure the query is always a string
+     @validator('query')
+     def check_query_is_string(cls, v):
+         if not isinstance(v, str):
+             raise ValueError("Query must be a valid string.")
+         return v
+
  # Function to classify query based on wellness topics
  def classify_query(query: str) -> str:
      wellness_keywords = ["box breathing", "meditation", "yoga", "mindfulness", "breathing exercises"]
@@ -63,45 +77,31 @@ def classify_query(query: str) -> str:
      classification = class_result.get("text", "").strip()
      return classification if classification != "OutOfScope" else "OutOfScope"

- # Pydantic model for text validation
- class TextInputModel(BaseModel):
-     text: str
-
- # Function to validate the text input using Pydantic
- def validate_text(query: str) -> str:
+ # Function to moderate text using Mistral moderation API (sync version)
+ def moderate_text(query: str) -> str:
      try:
-         # Attempt to validate the query as a text input
-         TextInputModel(text=query)
-         return query
+         # Use Pydantic to validate text input
+         query_input = QueryInput(query=query)  # This will validate that the query is a string
      except ValidationError as e:
          print(f"Error validating text: {e}")
          return "Invalid text format."
-
- # Function to moderate text using Mistral moderation API (synchronous version)
- def moderate_text(query: str) -> str:
-     # Validate the text using Pydantic
-     validated_text = validate_text(query)
-     if validated_text == "Invalid text format.":
-         return validated_text

      # Call the Mistral moderation API
      response = client.classifiers.moderate_chat(
          model="mistral-moderation-latest",
-         inputs=[{"role": "user", "content": validated_text}]
+         inputs=[{"role": "user", "content": query}]
      )

-     # Assuming the response is an object of type 'ClassificationResponse',
-     # check if it has a 'results' attribute, and then access its categories
+     # Check if harmful categories are present in the response
      if hasattr(response, 'results') and response.results:
          categories = response.results[0].categories
-         # Check if harmful categories are present
          if categories.get("violence_and_threats", False) or \
             categories.get("hate_and_discrimination", False) or \
             categories.get("dangerous_and_criminal_content", False) or \
             categories.get("selfharm", False):
              return "OutOfScope"

-     return validated_text
+     return query


  # Function to build or load the vector store from CSV data
@@ -173,7 +173,7 @@ def merge_responses(kb_answer: str, web_answer: str) -> str:

  # Orchestrate the entire workflow
  def run_pipeline(query: str) -> str:
-     # Moderate the query for harmful content (sync)
+     # Moderate the query for harmful content
      moderated_query = moderate_text(query)
      if moderated_query == "OutOfScope":
          return "Sorry, this query contains harmful or inappropriate content."
@@ -207,7 +207,7 @@ def run_pipeline(query: str) -> str:
      final_refusal = tailor_chain.run({"response": refusal_text})
      return final_refusal.strip()

- # Initialize chains here
+ # Initialize chains
  classification_chain = get_classification_chain()
  refusal_chain = get_refusal_chain()
  tailor_chain = get_tailor_chain()
@@ -224,3 +224,7 @@ brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
  gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
  wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
  brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
+
+ # Function to wrap up and run the chain
+ def run_with_chain(query: str) -> str:
+     return run_pipeline(query)
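
A note on the new validation path (not part of the commit itself): moderate_text now constructs a QueryInput model directly instead of going through the removed TextInputModel/validate_text pair. Below is a minimal sketch of how that model behaves on its own, assuming the Pydantic v1-style validator decorator this commit imports; the sample queries are illustrative.

# Minimal sketch: exercising QueryInput in isolation.
# Assumes the Pydantic v1-style `validator` imported in this commit.
from pydantic import BaseModel, ValidationError, validator

class QueryInput(BaseModel):
    query: str

    # Same check as in the commit: reject anything that is not a string
    @validator('query')
    def check_query_is_string(cls, v):
        if not isinstance(v, str):
            raise ValueError("Query must be a valid string.")
        return v

try:
    QueryInput(query="What is box breathing?")  # passes validation
    QueryInput(query=None)                      # raises ValidationError
except ValidationError as exc:
    print(f"Error validating text: {exc}")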
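
The new pydantic_agent line references an Agent class whose import is not visible in the hunks shown here. One plausible source, assumed purely for illustration, is the pydantic-ai package:

# Assumption: Agent comes from pydantic-ai; the import does not appear in the
# hunks above, so treat this as a sketch rather than the module's actual setup.
from pydantic_ai import Agent

pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)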
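
Finally, the new run_with_chain wrapper simply delegates to run_pipeline, so the updated module can be exercised end to end roughly like this (the query string is illustrative and assumes the required API keys are set in the environment):

# Illustrative usage only: drive the updated pipeline through the new wrapper.
if __name__ == "__main__":
    print(run_with_chain("Can you walk me through box breathing?"))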