harshalmore31 commited on
Commit
9d7d2e0
·
1 Parent(s): 6fba44c

Add strongly-typed models for function-calling arguments to enhance type safety and validation

Browse files
Files changed (1) hide show
  1. mai_dx/main.py +45 -2
mai_dx/main.py CHANGED
@@ -35,10 +35,10 @@ import sys
35
  import time
36
  from dataclasses import dataclass, field
37
  from enum import Enum
38
- from typing import Any, Dict, List, Union, Literal
39
 
40
  from loguru import logger
41
- from pydantic import BaseModel, Field
42
  from swarms import Agent, Conversation
43
  from dotenv import load_dotenv
44
 
@@ -241,6 +241,36 @@ class Action(BaseModel):
241
  )
242
 
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  # --- Main Orchestrator Class ---
245
 
246
 
@@ -1660,6 +1690,12 @@ CURRENT STATE:
1660
  if hasattr(hypothesis_response, '__dict__') or isinstance(hypothesis_response, dict):
1661
  structured_data = self._extract_function_call_output(hypothesis_response)
1662
 
 
 
 
 
 
 
1663
  # Check if we got structured differential data
1664
  if "differential_diagnoses" in structured_data:
1665
  # Update case state with structured data
@@ -2524,6 +2560,13 @@ Please try again and ensure you call the function correctly.
2524
  # Try to extract function call output
2525
  action_dict = self._extract_function_call_output(response)
2526
 
 
 
 
 
 
 
 
2527
  # Check if we got a valid response (not a fallback)
2528
  if not action_dict.get("reasoning", "").startswith("Fallback action due to function call parsing error"):
2529
  logger.debug(f"Consensus function call successful on attempt {attempt + 1}")
 
35
  import time
36
  from dataclasses import dataclass, field
37
  from enum import Enum
38
+ from typing import Any, Dict, List, Union, Literal, Optional
39
 
40
  from loguru import logger
41
+ from pydantic import BaseModel, Field, ValidationError
42
  from swarms import Agent, Conversation
43
  from dotenv import load_dotenv
44
 
 
241
  )
242
 
243
 
244
+ # ------------------------------------------------------------------
245
+ # Strongly-typed models for function-calling arguments (type safety)
246
+ # ------------------------------------------------------------------
247
+
248
+
249
+ class ConsensusArguments(BaseModel):
250
+ """Typed model for the `make_consensus_decision` function call."""
251
+
252
+ action_type: Literal["ask", "test", "diagnose"]
253
+ content: Union[str, List[str]]
254
+ reasoning: str
255
+
256
+
257
+ class DifferentialDiagnosisItem(BaseModel):
258
+ """Single differential diagnosis item returned by Dr. Hypothesis."""
259
+
260
+ diagnosis: str
261
+ probability: float
262
+ rationale: str
263
+
264
+
265
+ class HypothesisArguments(BaseModel):
266
+ """Typed model for the `update_differential_diagnosis` function call."""
267
+
268
+ summary: str
269
+ differential_diagnoses: List[DifferentialDiagnosisItem]
270
+ key_evidence: str
271
+ contradictory_evidence: Optional[str] = None
272
+
273
+
274
  # --- Main Orchestrator Class ---
275
 
276
 
 
1690
  if hasattr(hypothesis_response, '__dict__') or isinstance(hypothesis_response, dict):
1691
  structured_data = self._extract_function_call_output(hypothesis_response)
1692
 
1693
+ # Validate the structured data using the HypothesisArguments schema
1694
+ try:
1695
+ _ = HypothesisArguments(**structured_data)
1696
+ except ValidationError as e:
1697
+ logger.warning(f"HypothesisArguments validation failed: {e}")
1698
+
1699
  # Check if we got structured differential data
1700
  if "differential_diagnoses" in structured_data:
1701
  # Update case state with structured data
 
2560
  # Try to extract function call output
2561
  action_dict = self._extract_function_call_output(response)
2562
 
2563
+ # Validate and enforce schema using ConsensusArguments for type safety
2564
+ try:
2565
+ validated_args = ConsensusArguments(**action_dict)
2566
+ action_dict = validated_args.dict()
2567
+ except ValidationError as e:
2568
+ logger.warning(f"ConsensusArguments validation failed: {e}")
2569
+
2570
  # Check if we got a valid response (not a fallback)
2571
  if not action_dict.get("reasoning", "").startswith("Fallback action due to function call parsing error"):
2572
  logger.debug(f"Consensus function call successful on attempt {attempt + 1}")