Phoenix21 commited on
Commit
a5cb58f
·
verified ·
1 Parent(s): 0aef3aa

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +5 -4
pipeline.py CHANGED
@@ -68,16 +68,16 @@ def classify_query(query: str) -> str:
68
  return classification if classification != "OutOfScope" else "OutOfScope"
69
 
70
  # Function to moderate text using Mistral moderation API (async version)
71
- async def moderate_text(query: str) -> str:
72
  try:
73
  # Use Pydantic AI to validate the text
74
- await pydantic_agent.run(query) # Use async run for Pydantic validation
75
  except Exception as e:
76
  print(f"Error validating text: {e}")
77
  return "Invalid text format."
78
 
79
  # Call the Mistral moderation API
80
- response = await client.classifiers.moderate_chat(
81
  model="mistral-moderation-latest",
82
  inputs=[{"role": "user", "content": query}]
83
  )
@@ -201,7 +201,8 @@ async def run_async_pipeline(query: str) -> str:
201
 
202
  # Run the pipeline with the event loop
203
  def run_with_chain(query: str) -> str:
204
- return asyncio.run(run_async_pipeline(query))
 
205
 
206
  # Initialize chains here
207
  classification_chain = get_classification_chain()
 
68
  return classification if classification != "OutOfScope" else "OutOfScope"
69
 
70
  # Function to moderate text using Mistral moderation API (async version)
71
+ def moderate_text(query: str) -> str:
72
  try:
73
  # Use Pydantic AI to validate the text
74
+ pydantic_agent.run_sync(query) # Use sync run for Pydantic validation
75
  except Exception as e:
76
  print(f"Error validating text: {e}")
77
  return "Invalid text format."
78
 
79
  # Call the Mistral moderation API
80
+ response = client.classifiers.moderate_chat(
81
  model="mistral-moderation-latest",
82
  inputs=[{"role": "user", "content": query}]
83
  )
 
201
 
202
  # Run the pipeline with the event loop
203
  def run_with_chain(query: str) -> str:
204
+ loop = asyncio.get_event_loop()
205
+ return loop.run_until_complete(run_async_pipeline(query))
206
 
207
  # Initialize chains here
208
  classification_chain = get_classification_chain()