Spaces:
Sleeping
Sleeping
Commit
·
9d7d2e0
1
Parent(s):
6fba44c
Add strongly-typed models for function-calling arguments to enhance type safety and validation
Browse files- mai_dx/main.py +45 -2
mai_dx/main.py
CHANGED
|
@@ -35,10 +35,10 @@ import sys
|
|
| 35 |
import time
|
| 36 |
from dataclasses import dataclass, field
|
| 37 |
from enum import Enum
|
| 38 |
-
from typing import Any, Dict, List, Union, Literal
|
| 39 |
|
| 40 |
from loguru import logger
|
| 41 |
-
from pydantic import BaseModel, Field
|
| 42 |
from swarms import Agent, Conversation
|
| 43 |
from dotenv import load_dotenv
|
| 44 |
|
|
@@ -241,6 +241,36 @@ class Action(BaseModel):
|
|
| 241 |
)
|
| 242 |
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
# --- Main Orchestrator Class ---
|
| 245 |
|
| 246 |
|
|
@@ -1660,6 +1690,12 @@ CURRENT STATE:
|
|
| 1660 |
if hasattr(hypothesis_response, '__dict__') or isinstance(hypothesis_response, dict):
|
| 1661 |
structured_data = self._extract_function_call_output(hypothesis_response)
|
| 1662 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1663 |
# Check if we got structured differential data
|
| 1664 |
if "differential_diagnoses" in structured_data:
|
| 1665 |
# Update case state with structured data
|
|
@@ -2524,6 +2560,13 @@ Please try again and ensure you call the function correctly.
|
|
| 2524 |
# Try to extract function call output
|
| 2525 |
action_dict = self._extract_function_call_output(response)
|
| 2526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2527 |
# Check if we got a valid response (not a fallback)
|
| 2528 |
if not action_dict.get("reasoning", "").startswith("Fallback action due to function call parsing error"):
|
| 2529 |
logger.debug(f"Consensus function call successful on attempt {attempt + 1}")
|
|
|
|
| 35 |
import time
|
| 36 |
from dataclasses import dataclass, field
|
| 37 |
from enum import Enum
|
| 38 |
+
from typing import Any, Dict, List, Union, Literal, Optional
|
| 39 |
|
| 40 |
from loguru import logger
|
| 41 |
+
from pydantic import BaseModel, Field, ValidationError
|
| 42 |
from swarms import Agent, Conversation
|
| 43 |
from dotenv import load_dotenv
|
| 44 |
|
|
|
|
| 241 |
)
|
| 242 |
|
| 243 |
|
| 244 |
+
# ------------------------------------------------------------------
|
| 245 |
+
# Strongly-typed models for function-calling arguments (type safety)
|
| 246 |
+
# ------------------------------------------------------------------
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class ConsensusArguments(BaseModel):
|
| 250 |
+
"""Typed model for the `make_consensus_decision` function call."""
|
| 251 |
+
|
| 252 |
+
action_type: Literal["ask", "test", "diagnose"]
|
| 253 |
+
content: Union[str, List[str]]
|
| 254 |
+
reasoning: str
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class DifferentialDiagnosisItem(BaseModel):
|
| 258 |
+
"""Single differential diagnosis item returned by Dr. Hypothesis."""
|
| 259 |
+
|
| 260 |
+
diagnosis: str
|
| 261 |
+
probability: float
|
| 262 |
+
rationale: str
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class HypothesisArguments(BaseModel):
|
| 266 |
+
"""Typed model for the `update_differential_diagnosis` function call."""
|
| 267 |
+
|
| 268 |
+
summary: str
|
| 269 |
+
differential_diagnoses: List[DifferentialDiagnosisItem]
|
| 270 |
+
key_evidence: str
|
| 271 |
+
contradictory_evidence: Optional[str] = None
|
| 272 |
+
|
| 273 |
+
|
| 274 |
# --- Main Orchestrator Class ---
|
| 275 |
|
| 276 |
|
|
|
|
| 1690 |
if hasattr(hypothesis_response, '__dict__') or isinstance(hypothesis_response, dict):
|
| 1691 |
structured_data = self._extract_function_call_output(hypothesis_response)
|
| 1692 |
|
| 1693 |
+
# Validate the structured data using the HypothesisArguments schema
|
| 1694 |
+
try:
|
| 1695 |
+
_ = HypothesisArguments(**structured_data)
|
| 1696 |
+
except ValidationError as e:
|
| 1697 |
+
logger.warning(f"HypothesisArguments validation failed: {e}")
|
| 1698 |
+
|
| 1699 |
# Check if we got structured differential data
|
| 1700 |
if "differential_diagnoses" in structured_data:
|
| 1701 |
# Update case state with structured data
|
|
|
|
| 2560 |
# Try to extract function call output
|
| 2561 |
action_dict = self._extract_function_call_output(response)
|
| 2562 |
|
| 2563 |
+
# Validate and enforce schema using ConsensusArguments for type safety
|
| 2564 |
+
try:
|
| 2565 |
+
validated_args = ConsensusArguments(**action_dict)
|
| 2566 |
+
action_dict = validated_args.dict()
|
| 2567 |
+
except ValidationError as e:
|
| 2568 |
+
logger.warning(f"ConsensusArguments validation failed: {e}")
|
| 2569 |
+
|
| 2570 |
# Check if we got a valid response (not a fallback)
|
| 2571 |
if not action_dict.get("reasoning", "").startswith("Fallback action due to function call parsing error"):
|
| 2572 |
logger.debug(f"Consensus function call successful on attempt {attempt + 1}")
|