Spaces:
Sleeping
Sleeping
Commit
·
9d7d2e0
1
Parent(s):
6fba44c
Add strongly-typed models for function-calling arguments to enhance type safety and validation
Browse files- mai_dx/main.py +45 -2
mai_dx/main.py
CHANGED
@@ -35,10 +35,10 @@ import sys
|
|
35 |
import time
|
36 |
from dataclasses import dataclass, field
|
37 |
from enum import Enum
|
38 |
-
from typing import Any, Dict, List, Union, Literal
|
39 |
|
40 |
from loguru import logger
|
41 |
-
from pydantic import BaseModel, Field
|
42 |
from swarms import Agent, Conversation
|
43 |
from dotenv import load_dotenv
|
44 |
|
@@ -241,6 +241,36 @@ class Action(BaseModel):
|
|
241 |
)
|
242 |
|
243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
# --- Main Orchestrator Class ---
|
245 |
|
246 |
|
@@ -1660,6 +1690,12 @@ CURRENT STATE:
|
|
1660 |
if hasattr(hypothesis_response, '__dict__') or isinstance(hypothesis_response, dict):
|
1661 |
structured_data = self._extract_function_call_output(hypothesis_response)
|
1662 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1663 |
# Check if we got structured differential data
|
1664 |
if "differential_diagnoses" in structured_data:
|
1665 |
# Update case state with structured data
|
@@ -2524,6 +2560,13 @@ Please try again and ensure you call the function correctly.
|
|
2524 |
# Try to extract function call output
|
2525 |
action_dict = self._extract_function_call_output(response)
|
2526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2527 |
# Check if we got a valid response (not a fallback)
|
2528 |
if not action_dict.get("reasoning", "").startswith("Fallback action due to function call parsing error"):
|
2529 |
logger.debug(f"Consensus function call successful on attempt {attempt + 1}")
|
|
|
35 |
import time
|
36 |
from dataclasses import dataclass, field
|
37 |
from enum import Enum
|
38 |
+
from typing import Any, Dict, List, Union, Literal, Optional
|
39 |
|
40 |
from loguru import logger
|
41 |
+
from pydantic import BaseModel, Field, ValidationError
|
42 |
from swarms import Agent, Conversation
|
43 |
from dotenv import load_dotenv
|
44 |
|
|
|
241 |
)
|
242 |
|
243 |
|
244 |
+
# ------------------------------------------------------------------
|
245 |
+
# Strongly-typed models for function-calling arguments (type safety)
|
246 |
+
# ------------------------------------------------------------------
|
247 |
+
|
248 |
+
|
249 |
+
class ConsensusArguments(BaseModel):
|
250 |
+
"""Typed model for the `make_consensus_decision` function call."""
|
251 |
+
|
252 |
+
action_type: Literal["ask", "test", "diagnose"]
|
253 |
+
content: Union[str, List[str]]
|
254 |
+
reasoning: str
|
255 |
+
|
256 |
+
|
257 |
+
class DifferentialDiagnosisItem(BaseModel):
|
258 |
+
"""Single differential diagnosis item returned by Dr. Hypothesis."""
|
259 |
+
|
260 |
+
diagnosis: str
|
261 |
+
probability: float
|
262 |
+
rationale: str
|
263 |
+
|
264 |
+
|
265 |
+
class HypothesisArguments(BaseModel):
|
266 |
+
"""Typed model for the `update_differential_diagnosis` function call."""
|
267 |
+
|
268 |
+
summary: str
|
269 |
+
differential_diagnoses: List[DifferentialDiagnosisItem]
|
270 |
+
key_evidence: str
|
271 |
+
contradictory_evidence: Optional[str] = None
|
272 |
+
|
273 |
+
|
274 |
# --- Main Orchestrator Class ---
|
275 |
|
276 |
|
|
|
1690 |
if hasattr(hypothesis_response, '__dict__') or isinstance(hypothesis_response, dict):
|
1691 |
structured_data = self._extract_function_call_output(hypothesis_response)
|
1692 |
|
1693 |
+
# Validate the structured data using the HypothesisArguments schema
|
1694 |
+
try:
|
1695 |
+
_ = HypothesisArguments(**structured_data)
|
1696 |
+
except ValidationError as e:
|
1697 |
+
logger.warning(f"HypothesisArguments validation failed: {e}")
|
1698 |
+
|
1699 |
# Check if we got structured differential data
|
1700 |
if "differential_diagnoses" in structured_data:
|
1701 |
# Update case state with structured data
|
|
|
2560 |
# Try to extract function call output
|
2561 |
action_dict = self._extract_function_call_output(response)
|
2562 |
|
2563 |
+
# Validate and enforce schema using ConsensusArguments for type safety
|
2564 |
+
try:
|
2565 |
+
validated_args = ConsensusArguments(**action_dict)
|
2566 |
+
action_dict = validated_args.dict()
|
2567 |
+
except ValidationError as e:
|
2568 |
+
logger.warning(f"ConsensusArguments validation failed: {e}")
|
2569 |
+
|
2570 |
# Check if we got a valid response (not a fallback)
|
2571 |
if not action_dict.get("reasoning", "").startswith("Fallback action due to function call parsing error"):
|
2572 |
logger.debug(f"Consensus function call successful on attempt {attempt + 1}")
|