Spaces:
Sleeping
Sleeping
Merge pull request #7 from harshalmore31/main
Browse filesFix. the context issue, and loop issue in decision making handling the max_token properly as per each roles for an proper panel discussion
- DOCS.md +1 -0
- README.md +1 -0
- mai_dx/main.py +1202 -215
DOCS.md
CHANGED
@@ -33,6 +33,7 @@
|
|
33 |
- **Robust Error Handling**: Comprehensive exception management and fallback mechanisms
|
34 |
- **Beautiful Logging**: Structured logging with Loguru for debugging and monitoring
|
35 |
- **Type Safety**: Full Pydantic models and type hints throughout
|
|
|
36 |
|
37 |
## π Architecture
|
38 |
|
|
|
33 |
- **Robust Error Handling**: Comprehensive exception management and fallback mechanisms
|
34 |
- **Beautiful Logging**: Structured logging with Loguru for debugging and monitoring
|
35 |
- **Type Safety**: Full Pydantic models and type hints throughout
|
36 |
+
- **Token-Optimized Prompts**: Compact, role-specific directives minimize token usage while preserving diagnostic quality
|
37 |
|
38 |
## π Architecture
|
39 |
|
README.md
CHANGED
@@ -15,6 +15,7 @@ MAI-DxO (MAI Diagnostic Orchestrator) is a sophisticated AI-powered diagnostic s
|
|
15 |
- **Cost Tracking**: Real-time budget monitoring with costs for 25+ medical tests.
|
16 |
- **Clinical Evaluation**: 5-point accuracy scoring with detailed feedback.
|
17 |
- **Model Agnostic**: Works with GPT, Gemini, Claude, and other leading LLMs.
|
|
|
18 |
|
19 |
## π Quick Start
|
20 |
|
|
|
15 |
- **Cost Tracking**: Real-time budget monitoring with costs for 25+ medical tests.
|
16 |
- **Clinical Evaluation**: 5-point accuracy scoring with detailed feedback.
|
17 |
- **Model Agnostic**: Works with GPT, Gemini, Claude, and other leading LLMs.
|
18 |
+
- **Token-Optimized Prompts**: Ultra-compact role prompts reduce token usage and latency without sacrificing reasoning quality.
|
19 |
|
20 |
## π Quick Start
|
21 |
|
mai_dx/main.py
CHANGED
@@ -18,7 +18,7 @@ Key Features:
|
|
18 |
|
19 |
Example Usage:
|
20 |
# Standard MAI-DxO usage
|
21 |
-
orchestrator = MaiDxOrchestrator(model_name="gpt-
|
22 |
result = orchestrator.run(initial_case_info, full_case_details, ground_truth)
|
23 |
|
24 |
# Budget-constrained variant
|
@@ -33,13 +33,16 @@ import os
|
|
33 |
import json
|
34 |
import sys
|
35 |
import time
|
36 |
-
from dataclasses import dataclass
|
37 |
from enum import Enum
|
38 |
-
from typing import Any, Dict, List, Union, Literal
|
39 |
|
40 |
from loguru import logger
|
41 |
-
from pydantic import BaseModel, Field
|
42 |
from swarms import Agent, Conversation
|
|
|
|
|
|
|
43 |
|
44 |
# Configure Loguru with beautiful formatting and features
|
45 |
logger.remove() # Remove default handler
|
@@ -91,6 +94,125 @@ class AgentRole(Enum):
|
|
91 |
JUDGE = "Judge"
|
92 |
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
@dataclass
|
95 |
class DiagnosisResult:
|
96 |
"""Stores the final result of a diagnostic session."""
|
@@ -119,6 +241,36 @@ class Action(BaseModel):
|
|
119 |
)
|
120 |
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
# --- Main Orchestrator Class ---
|
123 |
|
124 |
|
@@ -127,19 +279,22 @@ class MaiDxOrchestrator:
|
|
127 |
Implements the MAI Diagnostic Orchestrator (MAI-DxO) framework.
|
128 |
This class orchestrates a virtual panel of AI agents to perform sequential medical diagnosis,
|
129 |
evaluates the final diagnosis, and tracks costs.
|
|
|
|
|
130 |
"""
|
131 |
|
132 |
def __init__(
|
133 |
self,
|
134 |
-
model_name: str = "gpt-
|
135 |
max_iterations: int = 10,
|
136 |
initial_budget: int = 10000,
|
137 |
mode: str = "no_budget", # "instant", "question_only", "budgeted", "no_budget", "ensemble"
|
138 |
physician_visit_cost: int = 300,
|
139 |
enable_budget_tracking: bool = False,
|
|
|
140 |
):
|
141 |
"""
|
142 |
-
Initializes the MAI-DxO system.
|
143 |
|
144 |
Args:
|
145 |
model_name (str): The language model to be used by all agents.
|
@@ -148,6 +303,7 @@ class MaiDxOrchestrator:
|
|
148 |
mode (str): The operational mode of MAI-DxO.
|
149 |
physician_visit_cost (int): Cost per physician visit.
|
150 |
enable_budget_tracking (bool): Whether to enable budget tracking.
|
|
|
151 |
"""
|
152 |
self.model_name = model_name
|
153 |
self.max_iterations = max_iterations
|
@@ -156,11 +312,20 @@ class MaiDxOrchestrator:
|
|
156 |
self.physician_visit_cost = physician_visit_cost
|
157 |
self.enable_budget_tracking = enable_budget_tracking
|
158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
self.cumulative_cost = 0
|
160 |
self.differential_diagnosis = "Not yet formulated."
|
161 |
self.conversation = Conversation(
|
162 |
time_enabled=True, autosave=False, save_enabled=False
|
163 |
)
|
|
|
|
|
|
|
164 |
|
165 |
# Enhanced cost model based on the paper's methodology
|
166 |
self.test_cost_db = {
|
@@ -198,27 +363,290 @@ class MaiDxOrchestrator:
|
|
198 |
f"π₯ MAI Diagnostic Orchestrator initialized successfully in '{mode}' mode with budget ${initial_budget:,}"
|
199 |
)
|
200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
def _init_agents(self) -> None:
|
202 |
"""Initializes all required agents with their specific roles and prompts."""
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
logger.info(
|
217 |
f"π₯ {len(self.agents)} virtual physician agents initialized and ready for consultation"
|
218 |
)
|
219 |
|
220 |
-
def
|
221 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
prompts = {
|
223 |
AgentRole.HYPOTHESIS: (
|
224 |
"""
|
@@ -525,9 +953,25 @@ class MaiDxOrchestrator:
|
|
525 |
}
|
526 |
return prompts[role]
|
527 |
|
528 |
-
def _parse_json_response(self, response: str) -> Dict[str, Any]:
|
529 |
-
"""Safely parses a JSON string
|
530 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
531 |
# Extract the actual response content from the agent response
|
532 |
if isinstance(response, str):
|
533 |
# Handle markdown-formatted JSON
|
@@ -576,20 +1020,54 @@ class MaiDxOrchestrator:
|
|
576 |
# Try to extract JSON from text that might contain other content
|
577 |
import re
|
578 |
|
579 |
-
# Look for JSON pattern in the text
|
580 |
-
json_pattern = r
|
581 |
-
matches = re.findall(
|
582 |
-
json_pattern, response, re.DOTALL
|
583 |
-
)
|
584 |
|
585 |
for match in matches:
|
586 |
try:
|
587 |
-
|
|
|
|
|
|
|
588 |
except json.JSONDecodeError:
|
589 |
continue
|
590 |
|
591 |
# Direct parsing attempt as fallback
|
592 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
593 |
|
594 |
except (
|
595 |
json.JSONDecodeError,
|
@@ -600,14 +1078,75 @@ class MaiDxOrchestrator:
|
|
600 |
logger.debug(
|
601 |
f"Response content: {response[:500]}..."
|
602 |
) # Log first 500 chars
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
611 |
|
612 |
def _estimate_cost(self, tests: Union[List[str], str]) -> int:
|
613 |
"""Estimates the cost of diagnostic tests."""
|
@@ -696,191 +1235,137 @@ class MaiDxOrchestrator:
|
|
696 |
|
697 |
return cost
|
698 |
|
699 |
-
def _run_panel_deliberation(self) -> Action:
|
700 |
-
"""Orchestrates one round of debate among the virtual panel
|
701 |
logger.info(
|
702 |
"π©Ί Virtual medical panel deliberation commenced - analyzing patient case"
|
703 |
)
|
704 |
logger.debug(
|
705 |
"Panel members: Dr. Hypothesis, Dr. Test-Chooser, Dr. Challenger, Dr. Stewardship, Dr. Checklist"
|
706 |
)
|
707 |
-
panel_conversation = Conversation(
|
708 |
-
time_enabled=True, autosave=False, save_enabled=False
|
709 |
-
)
|
710 |
|
711 |
-
#
|
712 |
-
|
|
|
|
|
|
|
713 |
budget_status = (
|
714 |
"EXCEEDED"
|
715 |
if remaining_budget < 0
|
716 |
else f"${remaining_budget:,}"
|
717 |
)
|
718 |
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
"""
|
736 |
-
panel_conversation.add("System", panel_context)
|
737 |
|
738 |
# Check mode-specific constraints
|
739 |
if self.mode == "instant":
|
740 |
# For instant mode, skip deliberation and go straight to diagnosis
|
741 |
action_dict = {
|
742 |
"action_type": "diagnose",
|
743 |
-
"content": (
|
744 |
-
self.differential_diagnosis.split("\n")[0]
|
745 |
-
if "\n" in self.differential_diagnosis
|
746 |
-
else self.differential_diagnosis
|
747 |
-
),
|
748 |
"reasoning": (
|
749 |
"Instant diagnosis mode - providing immediate assessment based on initial presentation"
|
750 |
),
|
751 |
}
|
752 |
return Action(**action_dict)
|
753 |
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
|
|
|
|
|
|
|
|
758 |
|
759 |
-
#
|
|
|
|
|
|
|
|
|
760 |
try:
|
761 |
# Dr. Hypothesis - Differential diagnosis and probability assessment
|
762 |
-
logger.info(
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
panel_conversation.get_str()
|
767 |
-
)
|
768 |
-
self.differential_diagnosis = (
|
769 |
-
hypothesis # Update main state
|
770 |
-
)
|
771 |
-
panel_conversation.add(
|
772 |
-
self.agents[AgentRole.HYPOTHESIS].agent_name,
|
773 |
-
hypothesis,
|
774 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
775 |
|
776 |
# Dr. Test-Chooser - Information value optimization
|
777 |
-
logger.info(
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
|
782 |
-
|
783 |
-
panel_conversation.add(
|
784 |
-
self.agents[AgentRole.TEST_CHOOSER].agent_name,
|
785 |
-
test_choices,
|
786 |
)
|
787 |
|
788 |
# Dr. Challenger - Bias identification and alternative hypotheses
|
789 |
-
logger.info(
|
790 |
-
|
791 |
-
|
792 |
-
|
793 |
-
panel_conversation.get_str()
|
794 |
-
)
|
795 |
-
panel_conversation.add(
|
796 |
-
self.agents[AgentRole.CHALLENGER].agent_name,
|
797 |
-
challenges,
|
798 |
)
|
799 |
|
800 |
# Dr. Stewardship - Cost-effectiveness analysis
|
801 |
-
logger.info(
|
802 |
-
|
803 |
-
)
|
804 |
-
stewardship_context = panel_conversation.get_str()
|
805 |
if self.enable_budget_tracking:
|
806 |
-
|
807 |
-
|
808 |
-
|
809 |
-
)
|
810 |
-
panel_conversation.add(
|
811 |
-
self.agents[AgentRole.STEWARDSHIP].agent_name,
|
812 |
-
stewardship_rec,
|
813 |
)
|
814 |
|
815 |
# Dr. Checklist - Quality assurance
|
816 |
-
logger.info(
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
panel_conversation.get_str()
|
821 |
-
)
|
822 |
-
panel_conversation.add(
|
823 |
-
self.agents[AgentRole.CHECKLIST].agent_name,
|
824 |
-
checklist_rep,
|
825 |
-
)
|
826 |
-
|
827 |
-
# Consensus Coordinator - Final decision synthesis
|
828 |
-
logger.info(
|
829 |
-
"π€ Consensus Coordinator synthesizing panel decision..."
|
830 |
)
|
831 |
-
consensus_context = panel_conversation.get_str()
|
832 |
|
|
|
|
|
|
|
|
|
|
|
|
|
833 |
# Add mode-specific constraints to consensus
|
834 |
if self.mode == "budgeted" and remaining_budget <= 0:
|
835 |
-
|
836 |
-
|
837 |
-
consensus_response = self.agents[AgentRole.CONSENSUS].run(
|
838 |
-
consensus_context
|
839 |
-
)
|
840 |
-
logger.debug(
|
841 |
-
f"Raw consensus response: {consensus_response}"
|
842 |
-
)
|
843 |
-
|
844 |
-
# Extract the actual text content from agent response
|
845 |
-
if hasattr(consensus_response, "content"):
|
846 |
-
response_text = consensus_response.content
|
847 |
-
elif isinstance(consensus_response, str):
|
848 |
-
response_text = consensus_response
|
849 |
-
else:
|
850 |
-
response_text = str(consensus_response)
|
851 |
|
852 |
-
|
|
|
853 |
|
854 |
# Validate action based on mode constraints
|
855 |
action = Action(**action_dict)
|
856 |
-
|
857 |
-
|
858 |
-
|
859 |
-
):
|
860 |
-
logger.warning(
|
861 |
-
"Test ordering attempted in question-only mode, converting to ask action"
|
862 |
-
)
|
863 |
-
action.action_type = "ask"
|
864 |
-
action.content = "Can you provide more details about the patient's symptoms and history?"
|
865 |
-
action.reasoning = (
|
866 |
-
"Mode constraint: question-only mode active"
|
867 |
-
)
|
868 |
-
|
869 |
-
if (
|
870 |
-
self.mode == "budgeted"
|
871 |
-
and action.action_type == "test"
|
872 |
-
and remaining_budget <= 0
|
873 |
-
):
|
874 |
-
logger.warning(
|
875 |
-
"Test ordering attempted with insufficient budget, converting to diagnose action"
|
876 |
-
)
|
877 |
-
action.action_type = "diagnose"
|
878 |
-
action.content = (
|
879 |
-
self.differential_diagnosis.split("\n")[0]
|
880 |
-
if "\n" in self.differential_diagnosis
|
881 |
-
else self.differential_diagnosis
|
882 |
-
)
|
883 |
-
action.reasoning = "Budget constraint: insufficient funds for additional testing"
|
884 |
|
885 |
return action
|
886 |
|
@@ -892,6 +1377,145 @@ class MaiDxOrchestrator:
|
|
892 |
content="Could you please provide more information about the patient's current condition?",
|
893 |
reasoning=f"Fallback due to panel deliberation error: {str(e)}",
|
894 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
895 |
|
896 |
def _interact_with_gatekeeper(
|
897 |
self, action: Action, full_case_details: str
|
@@ -917,7 +1541,7 @@ class MaiDxOrchestrator:
|
|
917 |
{request}
|
918 |
"""
|
919 |
|
920 |
-
response =
|
921 |
return response
|
922 |
|
923 |
def _judge_diagnosis(
|
@@ -929,19 +1553,82 @@ class MaiDxOrchestrator:
|
|
929 |
Please evaluate the following diagnosis.
|
930 |
Ground Truth: "{ground_truth}"
|
931 |
Candidate Diagnosis: "{candidate_diagnosis}"
|
|
|
|
|
|
|
|
|
932 |
"""
|
933 |
-
response =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
934 |
|
935 |
-
#
|
936 |
try:
|
937 |
-
score
|
938 |
-
|
939 |
-
|
940 |
-
|
941 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
942 |
score = 0.0
|
943 |
-
reasoning = "Could not parse judge's response
|
944 |
|
|
|
945 |
return {"score": score, "reasoning": reasoning}
|
946 |
|
947 |
def run(
|
@@ -951,7 +1638,7 @@ class MaiDxOrchestrator:
|
|
951 |
ground_truth_diagnosis: str,
|
952 |
) -> DiagnosisResult:
|
953 |
"""
|
954 |
-
Executes the full sequential diagnostic process.
|
955 |
|
956 |
Args:
|
957 |
initial_case_info (str): The initial abstract of the case.
|
@@ -962,13 +1649,22 @@ class MaiDxOrchestrator:
|
|
962 |
DiagnosisResult: An object containing the final diagnosis, evaluation, cost, and history.
|
963 |
"""
|
964 |
start_time = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
965 |
self.conversation.add(
|
966 |
"Gatekeeper",
|
967 |
f"Initial Case Information: {initial_case_info}",
|
968 |
)
|
|
|
969 |
|
970 |
-
# Add initial physician visit cost
|
971 |
-
self.cumulative_cost += self.physician_visit_cost
|
972 |
logger.info(
|
973 |
f"Initial physician visit cost: ${self.physician_visit_cost}"
|
974 |
)
|
@@ -978,16 +1674,18 @@ class MaiDxOrchestrator:
|
|
978 |
|
979 |
for i in range(self.max_iterations):
|
980 |
iteration_count = i + 1
|
|
|
|
|
981 |
logger.info(
|
982 |
f"--- Starting Diagnostic Loop {iteration_count}/{self.max_iterations} ---"
|
983 |
)
|
984 |
logger.info(
|
985 |
-
f"Current cost: ${
|
986 |
)
|
987 |
|
988 |
try:
|
989 |
-
# Panel deliberates to decide on the next action
|
990 |
-
action = self._run_panel_deliberation()
|
991 |
logger.info(
|
992 |
f"βοΈ Panel decision: {action.action_type.upper()} -> {action.content}"
|
993 |
)
|
@@ -995,6 +1693,9 @@ class MaiDxOrchestrator:
|
|
995 |
f"π Medical reasoning: {action.reasoning}"
|
996 |
)
|
997 |
|
|
|
|
|
|
|
998 |
if action.action_type == "diagnose":
|
999 |
final_diagnosis = action.content
|
1000 |
logger.info(
|
@@ -1002,7 +1703,7 @@ class MaiDxOrchestrator:
|
|
1002 |
)
|
1003 |
break
|
1004 |
|
1005 |
-
# Handle mode-specific constraints
|
1006 |
if (
|
1007 |
self.mode == "question_only"
|
1008 |
and action.action_type == "test"
|
@@ -1021,7 +1722,7 @@ class MaiDxOrchestrator:
|
|
1021 |
action.content
|
1022 |
)
|
1023 |
if (
|
1024 |
-
|
1025 |
> self.initial_budget
|
1026 |
):
|
1027 |
logger.warning(
|
@@ -1034,16 +1735,21 @@ class MaiDxOrchestrator:
|
|
1034 |
action, full_case_details
|
1035 |
)
|
1036 |
self.conversation.add("Gatekeeper", response)
|
|
|
1037 |
|
1038 |
-
# Update costs based on action type
|
1039 |
if action.action_type == "test":
|
1040 |
test_cost = self._estimate_cost(action.content)
|
1041 |
-
|
|
|
|
|
|
|
1042 |
logger.info(f"Tests ordered: {action.content}")
|
1043 |
logger.info(
|
1044 |
-
f"Test cost: ${test_cost:,} | Cumulative cost: ${
|
1045 |
)
|
1046 |
elif action.action_type == "ask":
|
|
|
1047 |
# Questions are part of the same visit until tests are ordered
|
1048 |
logger.info(f"Questions asked: {action.content}")
|
1049 |
logger.info(
|
@@ -1053,17 +1759,13 @@ class MaiDxOrchestrator:
|
|
1053 |
# Check budget constraints for budgeted mode
|
1054 |
if (
|
1055 |
self.mode == "budgeted"
|
1056 |
-
and
|
1057 |
):
|
1058 |
logger.warning(
|
1059 |
"Budget limit reached. Forcing final diagnosis."
|
1060 |
)
|
1061 |
-
# Use current
|
1062 |
-
final_diagnosis = (
|
1063 |
-
self.differential_diagnosis.split("\n")[0]
|
1064 |
-
if "\n" in self.differential_diagnosis
|
1065 |
-
else "Diagnosis not reached within budget constraints."
|
1066 |
-
)
|
1067 |
break
|
1068 |
|
1069 |
except Exception as e:
|
@@ -1075,11 +1777,9 @@ class MaiDxOrchestrator:
|
|
1075 |
|
1076 |
else:
|
1077 |
# Max iterations reached without diagnosis
|
1078 |
-
final_diagnosis = (
|
1079 |
-
|
1080 |
-
|
1081 |
-
else "Diagnosis not reached within maximum iterations."
|
1082 |
-
)
|
1083 |
logger.warning(
|
1084 |
f"Max iterations ({self.max_iterations}) reached. Using best available diagnosis."
|
1085 |
)
|
@@ -1115,7 +1815,7 @@ class MaiDxOrchestrator:
|
|
1115 |
ground_truth=ground_truth_diagnosis,
|
1116 |
accuracy_score=judgement["score"],
|
1117 |
accuracy_reasoning=judgement["reasoning"],
|
1118 |
-
total_cost=
|
1119 |
iterations=iteration_count,
|
1120 |
conversation_history=self.conversation.get_str(),
|
1121 |
)
|
@@ -1124,7 +1824,7 @@ class MaiDxOrchestrator:
|
|
1124 |
logger.info(f" Final diagnosis: {final_diagnosis}")
|
1125 |
logger.info(f" Ground truth: {ground_truth_diagnosis}")
|
1126 |
logger.info(f" Accuracy score: {judgement['score']}/5.0")
|
1127 |
-
logger.info(f" Total cost: ${
|
1128 |
logger.info(f" Iterations: {iteration_count}")
|
1129 |
|
1130 |
return result
|
@@ -1271,7 +1971,10 @@ class MaiDxOrchestrator:
|
|
1271 |
print_on=True, # Enable printing for aggregator agent
|
1272 |
)
|
1273 |
|
1274 |
-
|
|
|
|
|
|
|
1275 |
|
1276 |
except Exception as e:
|
1277 |
logger.error(f"Error in ensemble aggregation: {e}")
|
@@ -1309,7 +2012,7 @@ class MaiDxOrchestrator:
|
|
1309 |
"mode": "budgeted",
|
1310 |
"max_iterations": 10,
|
1311 |
"enable_budget_tracking": True,
|
1312 |
-
"initial_budget": kwargs.get("budget", 5000),
|
1313 |
},
|
1314 |
"no_budget": {
|
1315 |
"mode": "no_budget",
|
@@ -1330,9 +2033,290 @@ class MaiDxOrchestrator:
|
|
1330 |
|
1331 |
config = variant_configs[variant]
|
1332 |
config.update(kwargs) # Allow overrides
|
|
|
|
|
|
|
1333 |
|
1334 |
return cls(**config)
|
1335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1336 |
|
1337 |
def run_mai_dxo_demo(
|
1338 |
case_info: str = None,
|
@@ -1387,11 +2371,14 @@ def run_mai_dxo_demo(
|
|
1387 |
orchestrator = MaiDxOrchestrator.create_variant(
|
1388 |
variant,
|
1389 |
budget=3000,
|
1390 |
-
model_name="
|
|
|
1391 |
)
|
1392 |
else:
|
1393 |
orchestrator = MaiDxOrchestrator.create_variant(
|
1394 |
-
variant,
|
|
|
|
|
1395 |
)
|
1396 |
|
1397 |
result = orchestrator.run(
|
@@ -1466,14 +2453,14 @@ def run_mai_dxo_demo(
|
|
1466 |
# orchestrator = MaiDxOrchestrator.create_variant(
|
1467 |
# variant_name,
|
1468 |
# budget=3000,
|
1469 |
-
# model_name="gpt-4.1",
|
1470 |
-
# max_iterations=
|
1471 |
# )
|
1472 |
# else:
|
1473 |
# orchestrator = MaiDxOrchestrator.create_variant(
|
1474 |
# variant_name,
|
1475 |
-
# model_name="gpt-4.1",
|
1476 |
-
# max_iterations=
|
1477 |
# )
|
1478 |
|
1479 |
# # Run the diagnostic process
|
@@ -1504,7 +2491,7 @@ def run_mai_dxo_demo(
|
|
1504 |
|
1505 |
# ensemble_orchestrator = MaiDxOrchestrator.create_variant(
|
1506 |
# "ensemble",
|
1507 |
-
# model_name="gpt-4.1",
|
1508 |
# max_iterations=3, # Shorter iterations for ensemble
|
1509 |
# )
|
1510 |
|
|
|
18 |
|
19 |
Example Usage:
|
20 |
# Standard MAI-DxO usage
|
21 |
+
orchestrator = MaiDxOrchestrator(model_name="gpt-4o")
|
22 |
result = orchestrator.run(initial_case_info, full_case_details, ground_truth)
|
23 |
|
24 |
# Budget-constrained variant
|
|
|
33 |
import json
|
34 |
import sys
|
35 |
import time
|
36 |
+
from dataclasses import dataclass, field
|
37 |
from enum import Enum
|
38 |
+
from typing import Any, Dict, List, Union, Literal, Optional
|
39 |
|
40 |
from loguru import logger
|
41 |
+
from pydantic import BaseModel, Field, ValidationError
|
42 |
from swarms import Agent, Conversation
|
43 |
+
from dotenv import load_dotenv
|
44 |
+
|
45 |
+
load_dotenv()
|
46 |
|
47 |
# Configure Loguru with beautiful formatting and features
|
48 |
logger.remove() # Remove default handler
|
|
|
94 |
JUDGE = "Judge"
|
95 |
|
96 |
|
97 |
+
@dataclass
|
98 |
+
class CaseState:
|
99 |
+
"""Structured state management for diagnostic process - addresses Category 2.1"""
|
100 |
+
initial_vignette: str
|
101 |
+
evidence_log: List[str] = field(default_factory=list)
|
102 |
+
differential_diagnosis: Dict[str, float] = field(default_factory=dict)
|
103 |
+
tests_performed: List[str] = field(default_factory=list)
|
104 |
+
questions_asked: List[str] = field(default_factory=list)
|
105 |
+
cumulative_cost: int = 0
|
106 |
+
iteration: int = 0
|
107 |
+
last_actions: List['Action'] = field(default_factory=list) # For stagnation detection
|
108 |
+
|
109 |
+
def add_evidence(self, evidence: str):
|
110 |
+
"""Add new evidence to the case"""
|
111 |
+
self.evidence_log.append(f"[Turn {self.iteration}] {evidence}")
|
112 |
+
|
113 |
+
def update_differential(self, diagnosis_dict: Dict[str, float]):
|
114 |
+
"""Update differential diagnosis probabilities"""
|
115 |
+
self.differential_diagnosis.update(diagnosis_dict)
|
116 |
+
|
117 |
+
def add_test(self, test_name: str):
|
118 |
+
"""Record a test that was performed"""
|
119 |
+
self.tests_performed.append(test_name)
|
120 |
+
|
121 |
+
def add_question(self, question: str):
|
122 |
+
"""Record a question that was asked"""
|
123 |
+
self.questions_asked.append(question)
|
124 |
+
|
125 |
+
def is_stagnating(self, new_action: 'Action') -> bool:
|
126 |
+
"""Detect if the system is stuck in a loop - addresses Category 1.2"""
|
127 |
+
if len(self.last_actions) < 2:
|
128 |
+
return False
|
129 |
+
|
130 |
+
# Check if the new action is identical to recent ones
|
131 |
+
for last_action in self.last_actions[-2:]:
|
132 |
+
if (last_action.action_type == new_action.action_type and
|
133 |
+
last_action.content == new_action.content):
|
134 |
+
return True
|
135 |
+
return False
|
136 |
+
|
137 |
+
def add_action(self, action: 'Action'):
|
138 |
+
"""Add action to history and maintain sliding window"""
|
139 |
+
self.last_actions.append(action)
|
140 |
+
if len(self.last_actions) > 3: # Keep only last 3 actions
|
141 |
+
self.last_actions.pop(0)
|
142 |
+
|
143 |
+
def get_max_confidence(self) -> float:
|
144 |
+
"""Get the maximum confidence from differential diagnosis"""
|
145 |
+
if not self.differential_diagnosis:
|
146 |
+
return 0.0
|
147 |
+
return max(self.differential_diagnosis.values())
|
148 |
+
|
149 |
+
def get_leading_diagnosis(self) -> str:
|
150 |
+
"""Get the diagnosis with highest confidence"""
|
151 |
+
if not self.differential_diagnosis:
|
152 |
+
return "No diagnosis formulated"
|
153 |
+
return max(self.differential_diagnosis.items(), key=lambda x: x[1])[0]
|
154 |
+
|
155 |
+
def summarize_evidence(self) -> str:
|
156 |
+
"""Create a concise summary of evidence for token efficiency"""
|
157 |
+
if len(self.evidence_log) <= 5:
|
158 |
+
return "\n".join(self.evidence_log)
|
159 |
+
|
160 |
+
# Keep first 2 and last 3 entries, summarize middle
|
161 |
+
summary_parts = []
|
162 |
+
summary_parts.extend(self.evidence_log[:2])
|
163 |
+
|
164 |
+
if len(self.evidence_log) > 5:
|
165 |
+
middle_count = len(self.evidence_log) - 5
|
166 |
+
summary_parts.append(f"[... {middle_count} additional findings ...]")
|
167 |
+
|
168 |
+
summary_parts.extend(self.evidence_log[-3:])
|
169 |
+
return "\n".join(summary_parts)
|
170 |
+
|
171 |
+
|
172 |
+
@dataclass
|
173 |
+
class DeliberationState:
|
174 |
+
"""Structured state for panel deliberation - addresses Category 1.1"""
|
175 |
+
hypothesis_analysis: str = ""
|
176 |
+
test_chooser_analysis: str = ""
|
177 |
+
challenger_analysis: str = ""
|
178 |
+
stewardship_analysis: str = ""
|
179 |
+
checklist_analysis: str = ""
|
180 |
+
situational_context: str = ""
|
181 |
+
stagnation_detected: bool = False
|
182 |
+
retry_count: int = 0
|
183 |
+
|
184 |
+
def to_consensus_prompt(self) -> str:
|
185 |
+
"""Generate a structured prompt for the consensus coordinator - no truncation, let agent self-regulate"""
|
186 |
+
|
187 |
+
prompt = f"""
|
188 |
+
You are the Consensus Coordinator. Here is the panel's analysis:
|
189 |
+
|
190 |
+
**Differential Diagnosis (Dr. Hypothesis):**
|
191 |
+
{self.hypothesis_analysis or 'Not yet formulated'}
|
192 |
+
|
193 |
+
**Test Recommendations (Dr. Test-Chooser):**
|
194 |
+
{self.test_chooser_analysis or 'None provided'}
|
195 |
+
|
196 |
+
**Critical Challenges (Dr. Challenger):**
|
197 |
+
{self.challenger_analysis or 'None identified'}
|
198 |
+
|
199 |
+
**Cost Assessment (Dr. Stewardship):**
|
200 |
+
{self.stewardship_analysis or 'Not evaluated'}
|
201 |
+
|
202 |
+
**Quality Control (Dr. Checklist):**
|
203 |
+
{self.checklist_analysis or 'No issues noted'}
|
204 |
+
"""
|
205 |
+
|
206 |
+
if self.stagnation_detected:
|
207 |
+
prompt += "\n**STAGNATION DETECTED** - The panel is repeating actions. You MUST make a decisive choice or provide final diagnosis."
|
208 |
+
|
209 |
+
if self.situational_context:
|
210 |
+
prompt += f"\n**Situational Context:** {self.situational_context}"
|
211 |
+
|
212 |
+
prompt += "\n\nBased on this comprehensive panel input, use the make_consensus_decision function to provide your structured action."
|
213 |
+
return prompt
|
214 |
+
|
215 |
+
|
216 |
@dataclass
|
217 |
class DiagnosisResult:
|
218 |
"""Stores the final result of a diagnostic session."""
|
|
|
241 |
)
|
242 |
|
243 |
|
244 |
+
# ------------------------------------------------------------------
|
245 |
+
# Strongly-typed models for function-calling arguments (type safety)
|
246 |
+
# ------------------------------------------------------------------
|
247 |
+
|
248 |
+
|
249 |
+
class ConsensusArguments(BaseModel):
|
250 |
+
"""Typed model for the `make_consensus_decision` function call."""
|
251 |
+
|
252 |
+
action_type: Literal["ask", "test", "diagnose"]
|
253 |
+
content: Union[str, List[str]]
|
254 |
+
reasoning: str
|
255 |
+
|
256 |
+
|
257 |
+
class DifferentialDiagnosisItem(BaseModel):
|
258 |
+
"""Single differential diagnosis item returned by Dr. Hypothesis."""
|
259 |
+
|
260 |
+
diagnosis: str
|
261 |
+
probability: float
|
262 |
+
rationale: str
|
263 |
+
|
264 |
+
|
265 |
+
class HypothesisArguments(BaseModel):
|
266 |
+
"""Typed model for the `update_differential_diagnosis` function call."""
|
267 |
+
|
268 |
+
summary: str
|
269 |
+
differential_diagnoses: List[DifferentialDiagnosisItem]
|
270 |
+
key_evidence: str
|
271 |
+
contradictory_evidence: Optional[str] = None
|
272 |
+
|
273 |
+
|
274 |
# --- Main Orchestrator Class ---
|
275 |
|
276 |
|
|
|
279 |
Implements the MAI Diagnostic Orchestrator (MAI-DxO) framework.
|
280 |
This class orchestrates a virtual panel of AI agents to perform sequential medical diagnosis,
|
281 |
evaluates the final diagnosis, and tracks costs.
|
282 |
+
|
283 |
+
Enhanced with structured deliberation and proper state management as per research paper.
|
284 |
"""
|
285 |
|
286 |
def __init__(
|
287 |
self,
|
288 |
+
model_name: str = "gpt-4o-mini", # Fixed: Use valid GPT-4 Turbo model name
|
289 |
max_iterations: int = 10,
|
290 |
initial_budget: int = 10000,
|
291 |
mode: str = "no_budget", # "instant", "question_only", "budgeted", "no_budget", "ensemble"
|
292 |
physician_visit_cost: int = 300,
|
293 |
enable_budget_tracking: bool = False,
|
294 |
+
request_delay: float = 8.0, # seconds to wait between model calls to mitigate rate-limits
|
295 |
):
|
296 |
"""
|
297 |
+
Initializes the MAI-DxO system with improved architecture.
|
298 |
|
299 |
Args:
|
300 |
model_name (str): The language model to be used by all agents.
|
|
|
303 |
mode (str): The operational mode of MAI-DxO.
|
304 |
physician_visit_cost (int): Cost per physician visit.
|
305 |
enable_budget_tracking (bool): Whether to enable budget tracking.
|
306 |
+
request_delay (float): Seconds to wait between model calls to mitigate rate-limits.
|
307 |
"""
|
308 |
self.model_name = model_name
|
309 |
self.max_iterations = max_iterations
|
|
|
312 |
self.physician_visit_cost = physician_visit_cost
|
313 |
self.enable_budget_tracking = enable_budget_tracking
|
314 |
|
315 |
+
# Throttle settings to avoid OpenAI TPM rate-limits
|
316 |
+
self.request_delay = max(request_delay, 0)
|
317 |
+
|
318 |
+
# Token management
|
319 |
+
self.max_total_tokens_per_request = 25000 # Safety margin below 30k limit
|
320 |
+
|
321 |
self.cumulative_cost = 0
|
322 |
self.differential_diagnosis = "Not yet formulated."
|
323 |
self.conversation = Conversation(
|
324 |
time_enabled=True, autosave=False, save_enabled=False
|
325 |
)
|
326 |
+
|
327 |
+
# Initialize case state for structured state management
|
328 |
+
self.case_state = None
|
329 |
|
330 |
# Enhanced cost model based on the paper's methodology
|
331 |
self.test_cost_db = {
|
|
|
363 |
f"π₯ MAI Diagnostic Orchestrator initialized successfully in '{mode}' mode with budget ${initial_budget:,}"
|
364 |
)
|
365 |
|
366 |
+
def _get_agent_max_tokens(self, role: AgentRole) -> int:
    """Return the per-response output-token budget for a panel role.

    Agents self-regulate verbosity via prompt guidance; these caps are the
    hard backstop. Unknown roles fall back to a conservative 600 tokens.
    """
    budgets = {
        AgentRole.HYPOTHESIS: 1200,   # structured via function calling, but needs room for quality
        AgentRole.GATEKEEPER: 1000,   # must relay detailed clinical findings
        AgentRole.TEST_CHOOSER: 800,  # rationale for each recommended test
        AgentRole.CHALLENGER: 800,    # room for critical analysis
        AgentRole.JUDGE: 700,
        AgentRole.STEWARDSHIP: 600,
        AgentRole.CONSENSUS: 500,     # function calling keeps output compact
        AgentRole.CHECKLIST: 400,
    }
    return budgets.get(role, 600)
|
380 |
+
|
381 |
+
def _estimate_tokens(self, text: str) -> int:
|
382 |
+
"""Rough token estimation (1 token β 4 characters for English)"""
|
383 |
+
return len(text) // 4
|
384 |
+
|
385 |
+
def _generate_token_guidance(self, input_tokens: int, max_output_tokens: int, total_tokens: int, agent_role: AgentRole) -> str:
    """Build a dynamic token-budget directive an agent prepends to its work.

    The urgency tier is derived from how close the combined request is to
    ``self.max_total_tokens_per_request``; the role focus line tells each
    specialist what to prioritize when it must compress its answer.
    """
    limit = self.max_total_tokens_per_request

    # Escalating tiers, checked highest-first; first threshold exceeded wins.
    tiers = (
        (limit, "CRITICAL", "Be extremely concise. Prioritize only the most essential information."),
        (limit * 0.8, "HIGH", "Be concise and focus on key points. Avoid elaborate explanations."),
        (limit * 0.6, "MODERATE", "Be reasonably concise while maintaining necessary detail."),
    )
    urgency = "LOW"
    strategy = "You can provide detailed analysis within your allocated tokens."
    for threshold, tier_label, tier_advice in tiers:
        if total_tokens > threshold:
            urgency, strategy = tier_label, tier_advice
            break

    # What each specialist should protect when trimming output.
    role_specific_guidance = {
        AgentRole.HYPOTHESIS: "Focus on top 2-3 diagnoses with probabilities. Prioritize summary over detailed pathophysiology.",
        AgentRole.TEST_CHOOSER: "Recommend 1-2 highest-yield tests. Focus on which hypotheses they'll help differentiate.",
        AgentRole.CHALLENGER: "Identify 1-2 most critical biases or alternative diagnoses. Be direct and specific.",
        AgentRole.STEWARDSHIP: "Focus on cost-effectiveness assessment. Recommend cheaper alternatives where applicable.",
        AgentRole.CHECKLIST: "Provide concise quality check. Flag critical issues only.",
        AgentRole.CONSENSUS: "Function calling enforces structure. Focus on clear reasoning.",
        AgentRole.GATEKEEPER: "Provide specific clinical findings. Be factual and complete but not verbose.",
        AgentRole.JUDGE: "Provide score and focused justification. Be systematic but concise."
    }.get(agent_role, "Be concise and focused.")

    return f"""
[TOKEN MANAGEMENT - {urgency} PRIORITY]
Input: {input_tokens} tokens | Your Output Limit: {max_output_tokens} tokens | Total: {total_tokens} tokens
Strategy: {strategy}
Role Focus: {role_specific_guidance}

IMPORTANT: Adjust your response length and detail level based on this guidance. Prioritize the most critical information for your role.
"""
|
424 |
+
|
425 |
def _init_agents(self) -> None:
    """Initializes all required agents with their specific roles and prompts.

    CONSENSUS and HYPOTHESIS are built with OpenAI-style function-calling
    tools so their output is machine-parseable; every other panel member is
    a plain-text agent. Per-role output caps come from
    ``_get_agent_max_tokens``. Populates ``self.agents``.
    """

    # Structured-output tool forcing the consensus agent to emit a
    # well-formed {action_type, content, reasoning} decision.
    consensus_tool = {
        "type": "function",
        "function": {
            "name": "make_consensus_decision",
            "description": "Make a structured consensus decision for the next diagnostic action",
            "parameters": {
                "type": "object",
                "properties": {
                    "action_type": {
                        "type": "string",
                        "enum": ["ask", "test", "diagnose"],
                        "description": "The type of action to perform"
                    },
                    "content": {
                        "type": "string",
                        "description": "The specific content of the action (question, test name, or diagnosis)"
                    },
                    "reasoning": {
                        "type": "string",
                        "description": "The detailed reasoning behind this decision, synthesizing panel input"
                    }
                },
                "required": ["action_type", "content", "reasoning"]
            }
        }
    }

    # Structured-output tool for Dr. Hypothesis' probability-ranked
    # differential (2-5 candidates, each with probability and rationale).
    hypothesis_tool = {
        "type": "function",
        "function": {
            "name": "update_differential_diagnosis",
            "description": "Update the differential diagnosis with structured probabilities and reasoning",
            "parameters": {
                "type": "object",
                "properties": {
                    "summary": {
                        "type": "string",
                        "description": "One-sentence summary of primary diagnostic conclusion and confidence"
                    },
                    "differential_diagnoses": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "diagnosis": {"type": "string", "description": "The diagnosis name"},
                                "probability": {"type": "number", "minimum": 0, "maximum": 1, "description": "Probability as decimal (0.0-1.0)"},
                                "rationale": {"type": "string", "description": "Brief rationale for this diagnosis"}
                            },
                            "required": ["diagnosis", "probability", "rationale"]
                        },
                        "minItems": 2,
                        "maxItems": 5,
                        "description": "Top 2-5 differential diagnoses with probabilities"
                    },
                    "key_evidence": {
                        "type": "string",
                        "description": "Key supporting evidence for leading hypotheses"
                    },
                    "contradictory_evidence": {
                        "type": "string",
                        "description": "Critical contradictory evidence that must be addressed"
                    }
                },
                "required": ["summary", "differential_diagnoses", "key_evidence"]
            }
        }
    }

    self.agents = {}
    for role in AgentRole:
        # Construction arguments shared by every panel member; previously
        # duplicated verbatim across three branches.
        agent_kwargs = dict(
            agent_name=role.value,
            system_prompt=self._get_prompt_for_role(role),
            model_name=self.model_name,
            max_loops=1,
            print_on=True,
            max_tokens=self._get_agent_max_tokens(role),
        )
        if role == AgentRole.CONSENSUS:
            # Function calling ensures a structured consensus decision.
            agent_kwargs["tools_list_dictionary"] = [consensus_tool]  # swarms expects tools_list_dictionary
            agent_kwargs["tool_choice"] = "auto"  # let the model choose to use the tool
        elif role == AgentRole.HYPOTHESIS:
            # Function calling ensures a structured differential diagnosis.
            agent_kwargs["tools_list_dictionary"] = [hypothesis_tool]
            agent_kwargs["tool_choice"] = "auto"
        else:
            # Regular agents without function calling return plain text.
            agent_kwargs["output_type"] = "str"
        self.agents[role] = Agent(**agent_kwargs)

    logger.info(
        f"🏥 {len(self.agents)} virtual physician agents initialized and ready for consultation"
    )
|
539 |
|
540 |
+
def _get_dynamic_context(self, role: AgentRole, case_state: CaseState) -> str:
    """Generate dynamic context for agents based on current situation - addresses Category 4.2.

    Returns a situational preamble for the given role, or "" when no
    special situation applies.
    """
    remaining_budget = self.initial_budget - case_state.cumulative_cost

    # Highest probability in the current differential (0 when none exists yet).
    max_confidence = max(case_state.differential_diagnosis.values()) if case_state.differential_diagnosis else 0

    # Guard-clause style: return the first applicable situational banner.
    if role == AgentRole.STEWARDSHIP and remaining_budget < 1000:
        return f"""
**SITUATIONAL CONTEXT: URGENT**
The remaining budget is critically low (${remaining_budget}). All recommendations must be focused on maximum cost-effectiveness. Veto any non-essential or high-cost tests.
"""

    if role == AgentRole.HYPOTHESIS and max_confidence > 0.75:
        return f"""
**SITUATIONAL CONTEXT: FINAL STAGES**
The panel is converging on a diagnosis (current max confidence: {max_confidence:.0%}). Your primary role now is to confirm the leading hypothesis or state what single piece of evidence is needed to reach >85% confidence.
"""

    if role == AgentRole.CONSENSUS and case_state.iteration > 5:
        return f"""
**SITUATIONAL CONTEXT: EXTENDED CASE**
This case has gone through {case_state.iteration} iterations. Focus on decisive actions that will lead to a definitive diagnosis rather than additional exploratory steps.
"""

    return ""
|
568 |
+
|
569 |
+
def _get_prompt_for_role(self, role: AgentRole, case_state: Optional[CaseState] = None) -> str:
    """Returns the system prompt for a given agent role with dynamic context.

    Args:
        role: Panel role whose system prompt is requested.
        case_state: Optional current case state; when provided, a
            situational preamble from ``_get_dynamic_context`` is prepended.

    Returns:
        The compact system prompt for ``role``; roles without a compact
        prompt fall back to their original (verbose) prompt.
    """

    # Add dynamic context if case_state is provided
    dynamic_context = ""
    if case_state:
        dynamic_context = self._get_dynamic_context(role, case_state)

    # --- Compact, token-efficient prompts ---
    # Only the five specialists + consensus have compact variants so far.
    base_prompts = {
        AgentRole.HYPOTHESIS: f"""{dynamic_context}

MANDATE: Keep an up-to-date, probability-ranked differential.

DIRECTIVES:
1. Return 2-5 diagnoses (prob 0-1) with 1-line rationale.
2. List key supporting & contradictory evidence.

You MUST call update_differential_diagnosis().""",

        AgentRole.TEST_CHOOSER: f"""{dynamic_context}

MANDATE: Pick the highest-yield tests.

DIRECTIVES:
1. Suggest ≤3 tests that best separate current diagnoses.
2. Note target hypothesis & info-gain vs cost.

Limit: focus on top 1-2 critical points.""",

        AgentRole.CHALLENGER: f"""{dynamic_context}

MANDATE: Expose the biggest flaw or bias.

DIRECTIVES:
1. Name the key bias/contradiction.
2. Propose an alternate diagnosis or falsifying test.

Reply concisely (top 2 issues).""",

        AgentRole.STEWARDSHIP: f"""{dynamic_context}

MANDATE: Ensure cost-effective care.

DIRECTIVES:
1. Rate proposed tests (High/Mod/Low value).
2. Suggest cheaper equivalents where possible.

Be brief; highlight savings.""",

        AgentRole.CHECKLIST: f"""{dynamic_context}

MANDATE: Guarantee quality & consistency.

DIRECTIVES:
1. Flag invalid tests or logic gaps.
2. Note safety concerns.

Return bullet list of critical items.""",

        AgentRole.CONSENSUS: f"""{dynamic_context}

MANDATE: Decide the next action.

DECISION RULES:
1. If confidence >85% & no major objection → diagnose.
2. Else address Challenger's top concern.
3. Else order highest info-gain (cheapest) test.
4. Else ask the most informative question.

You MUST call make_consensus_decision().""",
    }

    # Use existing prompts for other roles, just add dynamic context
    if role not in base_prompts:
        return dynamic_context + self._get_original_prompt_for_role(role)

    return base_prompts[role]
|
647 |
+
|
648 |
+
def _get_original_prompt_for_role(self, role: AgentRole) -> str:
|
649 |
+
"""Returns original system prompts for roles not yet updated"""
|
650 |
prompts = {
|
651 |
AgentRole.HYPOTHESIS: (
|
652 |
"""
|
|
|
953 |
}
|
954 |
return prompts[role]
|
955 |
|
956 |
+
def _parse_json_response(self, response: str, retry_count: int = 0) -> Dict[str, Any]:
|
957 |
+
"""Safely parses a JSON string with retry logic - addresses Category 3.2"""
|
958 |
try:
|
959 |
+
# Handle agent response wrapper - extract actual content
|
960 |
+
if isinstance(response, dict):
|
961 |
+
# Handle swarms Agent response format
|
962 |
+
if 'role' in response and 'content' in response:
|
963 |
+
response = response['content']
|
964 |
+
elif 'content' in response:
|
965 |
+
response = response['content']
|
966 |
+
else:
|
967 |
+
# Try to extract any string value from dict
|
968 |
+
response = str(response)
|
969 |
+
elif hasattr(response, 'content'):
|
970 |
+
response = response.content
|
971 |
+
elif not isinstance(response, str):
|
972 |
+
# Convert to string if it's some other type
|
973 |
+
response = str(response)
|
974 |
+
|
975 |
# Extract the actual response content from the agent response
|
976 |
if isinstance(response, str):
|
977 |
# Handle markdown-formatted JSON
|
|
|
1020 |
# Try to extract JSON from text that might contain other content
|
1021 |
import re
|
1022 |
|
1023 |
+
# Look for JSON pattern in the text - more comprehensive regex
|
1024 |
+
json_pattern = r'\{(?:[^{}]|(?:\{[^{}]*\}))*\}'
|
1025 |
+
matches = re.findall(json_pattern, response, re.DOTALL)
|
|
|
|
|
1026 |
|
1027 |
for match in matches:
|
1028 |
try:
|
1029 |
+
parsed = json.loads(match)
|
1030 |
+
# Validate that it has the expected action structure
|
1031 |
+
if isinstance(parsed, dict) and 'action_type' in parsed:
|
1032 |
+
return parsed
|
1033 |
except json.JSONDecodeError:
|
1034 |
continue
|
1035 |
|
1036 |
# Direct parsing attempt as fallback
|
1037 |
+
try:
|
1038 |
+
return json.loads(response)
|
1039 |
+
except json.JSONDecodeError:
|
1040 |
+
# --- Fallback Sanitization ---
|
1041 |
+
# Attempt to strip any leading table/frame characters (e.g., β, β, β°) that may wrap each line
|
1042 |
+
try:
|
1043 |
+
# Extract everything between the first '{' and last '}'
|
1044 |
+
start_curly = response.index('{')
|
1045 |
+
end_curly = response.rindex('}')
|
1046 |
+
candidate = response[start_curly:end_curly + 1]
|
1047 |
+
sanitized_lines = []
|
1048 |
+
for line in candidate.splitlines():
|
1049 |
+
# Remove common frame characters and leading whitespace
|
1050 |
+
line = line.lstrip('β|ββ°β―βββ€ ').rstrip('β|ββ°β―βββ€ ')
|
1051 |
+
sanitized_lines.append(line)
|
1052 |
+
candidate_clean = '\n'.join(sanitized_lines)
|
1053 |
+
return json.loads(candidate_clean)
|
1054 |
+
except Exception as inner_e:
|
1055 |
+
# Still failing, raise original error to trigger retry logic
|
1056 |
+
try:
|
1057 |
+
# --- Ultimate Fallback: Regex extraction ---
|
1058 |
+
import re
|
1059 |
+
atype = re.search(r'"action_type"\s*:\s*"(ask|test|diagnose)"', response, re.IGNORECASE)
|
1060 |
+
content_match = re.search(r'"content"\s*:\s*"([^"]+?)"', response, re.IGNORECASE | re.DOTALL)
|
1061 |
+
reasoning_match = re.search(r'"reasoning"\s*:\s*"([^"]+?)"', response, re.IGNORECASE | re.DOTALL)
|
1062 |
+
if atype and content_match and reasoning_match:
|
1063 |
+
return {
|
1064 |
+
"action_type": atype.group(1).lower(),
|
1065 |
+
"content": content_match.group(1).strip(),
|
1066 |
+
"reasoning": reasoning_match.group(1).strip(),
|
1067 |
+
}
|
1068 |
+
except Exception:
|
1069 |
+
pass
|
1070 |
+
raise e
|
1071 |
|
1072 |
except (
|
1073 |
json.JSONDecodeError,
|
|
|
1078 |
logger.debug(
|
1079 |
f"Response content: {response[:500]}..."
|
1080 |
) # Log first 500 chars
|
1081 |
+
|
1082 |
+
# Return the error for potential retry instead of immediately falling back
|
1083 |
+
raise e
|
1084 |
+
|
1085 |
+
def _parse_json_with_retry(self, consensus_agent: Agent, consensus_prompt: str, max_retries: int = 3) -> Dict[str, Any]:
    """Parse JSON with retry logic for robustness - addresses Category 3.2.

    Runs the consensus agent up to ``max_retries + 1`` times; retries send
    an increasingly strict "JSON only" prompt. Always returns an action
    dict with ``action_type``/``content``/``reasoning`` keys, falling back
    to a safe "ask" action if every attempt fails.
    """
    for attempt in range(max_retries + 1):
        try:
            if attempt == 0:
                response = consensus_agent.run(consensus_prompt)
            else:
                # Retry with error feedback: restate the schema and demand
                # a raw JSON object only ({{ / }} are literal braces).
                retry_prompt = f"""
{consensus_prompt}

**CRITICAL: RETRY REQUIRED - ATTEMPT {attempt + 1}**
Your previous response could not be parsed as JSON. You MUST respond with ONLY a valid JSON object in exactly this format:

{{
"action_type": "ask" | "test" | "diagnose",
"content": "your content here",
"reasoning": "your reasoning here"
}}

Do NOT include any other text, markdown formatting, or explanations. Only the raw JSON object.
NO SYSTEM MESSAGES, NO WRAPPER FORMAT. JUST THE JSON.
"""
                response = consensus_agent.run(retry_prompt)

            # Handle different response types from swarms Agent:
            # object with .content, dict wrapper, plain string, or other.
            response_text = ""
            if hasattr(response, 'content'):
                response_text = response.content
            elif isinstance(response, dict):
                # Handle swarms Agent response wrapper
                if 'role' in response and 'content' in response:
                    response_text = response['content']
                elif 'content' in response:
                    response_text = response['content']
                else:
                    response_text = str(response)
            elif isinstance(response, str):
                response_text = response
            else:
                response_text = str(response)

            # Log the response for debugging
            logger.debug(f"Parsing attempt {attempt + 1}, response type: {type(response)}")
            logger.debug(f"Response content preview: {str(response_text)[:200]}...")

            # _parse_json_response raises on failure, which re-enters the
            # retry loop below.
            return self._parse_json_response(response_text, attempt)

        except Exception as e:
            logger.warning(f"JSON parsing attempt {attempt + 1} failed: {e}")
            if attempt == max_retries:
                # Final fallback after all retries
                logger.error("All JSON parsing attempts failed, using fallback")
                return {
                    "action_type": "ask",
                    "content": "Could you please clarify the next best step? The previous analysis was inconclusive.",
                    "reasoning": f"Fallback due to JSON parsing error after {max_retries + 1} attempts.",
                }

    # Should never reach here, but just in case
    return {
        "action_type": "ask",
        "content": "Please provide more information about the patient's condition.",
        "reasoning": "Unexpected fallback in JSON parsing.",
    }
|
1150 |
|
1151 |
def _estimate_cost(self, tests: Union[List[str], str]) -> int:
|
1152 |
"""Estimates the cost of diagnostic tests."""
|
|
|
1235 |
|
1236 |
return cost
|
1237 |
|
1238 |
+
def _run_panel_deliberation(self, case_state: CaseState) -> Action:
|
1239 |
+
"""Orchestrates one round of structured debate among the virtual panel - addresses Category 1.1"""
|
1240 |
logger.info(
|
1241 |
"π©Ί Virtual medical panel deliberation commenced - analyzing patient case"
|
1242 |
)
|
1243 |
logger.debug(
|
1244 |
"Panel members: Dr. Hypothesis, Dr. Test-Chooser, Dr. Challenger, Dr. Stewardship, Dr. Checklist"
|
1245 |
)
|
|
|
|
|
|
|
1246 |
|
1247 |
+
# Initialize structured deliberation state instead of conversational chaining
|
1248 |
+
deliberation_state = DeliberationState()
|
1249 |
+
|
1250 |
+
# Prepare concise case context for each agent (token-optimized)
|
1251 |
+
remaining_budget = self.initial_budget - case_state.cumulative_cost
|
1252 |
budget_status = (
|
1253 |
"EXCEEDED"
|
1254 |
if remaining_budget < 0
|
1255 |
else f"${remaining_budget:,}"
|
1256 |
)
|
1257 |
|
1258 |
+
# Full context - let agents self-regulate based on token guidance
|
1259 |
+
base_context = f"""
|
1260 |
+
=== DIAGNOSTIC CASE STATUS - ROUND {case_state.iteration} ===
|
1261 |
+
|
1262 |
+
INITIAL PRESENTATION:
|
1263 |
+
{case_state.initial_vignette}
|
1264 |
+
|
1265 |
+
EVIDENCE GATHERED:
|
1266 |
+
{case_state.summarize_evidence()}
|
1267 |
+
|
1268 |
+
CURRENT STATE:
|
1269 |
+
- Tests Performed: {', '.join(case_state.tests_performed) if case_state.tests_performed else 'None'}
|
1270 |
+
- Questions Asked: {len(case_state.questions_asked)}
|
1271 |
+
- Cumulative Cost: ${case_state.cumulative_cost:,}
|
1272 |
+
- Remaining Budget: {budget_status}
|
1273 |
+
- Mode: {self.mode}
|
1274 |
"""
|
|
|
1275 |
|
1276 |
# Check mode-specific constraints
|
1277 |
if self.mode == "instant":
|
1278 |
# For instant mode, skip deliberation and go straight to diagnosis
|
1279 |
action_dict = {
|
1280 |
"action_type": "diagnose",
|
1281 |
+
"content": case_state.get_leading_diagnosis(),
|
|
|
|
|
|
|
|
|
1282 |
"reasoning": (
|
1283 |
"Instant diagnosis mode - providing immediate assessment based on initial presentation"
|
1284 |
),
|
1285 |
}
|
1286 |
return Action(**action_dict)
|
1287 |
|
1288 |
+
# Check for stagnation before running deliberation
|
1289 |
+
stagnation_detected = False
|
1290 |
+
if len(case_state.last_actions) >= 2:
|
1291 |
+
last_action = case_state.last_actions[-1]
|
1292 |
+
stagnation_detected = case_state.is_stagnating(last_action)
|
1293 |
+
deliberation_state.stagnation_detected = stagnation_detected
|
1294 |
+
if stagnation_detected:
|
1295 |
+
logger.warning("π Stagnation detected - will force different action")
|
1296 |
|
1297 |
+
# Generate dynamic situational context for all agents
|
1298 |
+
deliberation_state.situational_context = self._generate_situational_context(case_state, remaining_budget)
|
1299 |
+
|
1300 |
+
# Run each specialist agent in parallel-like fashion with structured output
|
1301 |
+
# Each agent gets the same base context plus their role-specific dynamic prompt
|
1302 |
try:
|
1303 |
# Dr. Hypothesis - Differential diagnosis and probability assessment
|
1304 |
+
logger.info("π§ Dr. Hypothesis analyzing differential diagnosis...")
|
1305 |
+
hypothesis_prompt = self._get_prompt_for_role(AgentRole.HYPOTHESIS, case_state) + "\n\n" + base_context
|
1306 |
+
hypothesis_response = self._safe_agent_run(
|
1307 |
+
self.agents[AgentRole.HYPOTHESIS], hypothesis_prompt, agent_role=AgentRole.HYPOTHESIS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1308 |
)
|
1309 |
+
|
1310 |
+
# Update case state with new differential (supports both function calls and text)
|
1311 |
+
self._update_differential_from_hypothesis(case_state, hypothesis_response)
|
1312 |
+
|
1313 |
+
# Store the analysis for deliberation state (convert to text format for other agents)
|
1314 |
+
if hasattr(hypothesis_response, 'content'):
|
1315 |
+
deliberation_state.hypothesis_analysis = hypothesis_response.content
|
1316 |
+
else:
|
1317 |
+
deliberation_state.hypothesis_analysis = str(hypothesis_response)
|
1318 |
|
1319 |
# Dr. Test-Chooser - Information value optimization
|
1320 |
+
logger.info("π¬ Dr. Test-Chooser selecting optimal tests...")
|
1321 |
+
test_chooser_prompt = self._get_prompt_for_role(AgentRole.TEST_CHOOSER, case_state) + "\n\n" + base_context
|
1322 |
+
if self.mode == "question_only":
|
1323 |
+
test_chooser_prompt += "\n\nIMPORTANT: This is QUESTION-ONLY mode. You may ONLY recommend patient questions, not diagnostic tests."
|
1324 |
+
deliberation_state.test_chooser_analysis = self._safe_agent_run(
|
1325 |
+
self.agents[AgentRole.TEST_CHOOSER], test_chooser_prompt, agent_role=AgentRole.TEST_CHOOSER
|
|
|
|
|
|
|
1326 |
)
|
1327 |
|
1328 |
# Dr. Challenger - Bias identification and alternative hypotheses
|
1329 |
+
logger.info("π€ Dr. Challenger challenging assumptions...")
|
1330 |
+
challenger_prompt = self._get_prompt_for_role(AgentRole.CHALLENGER, case_state) + "\n\n" + base_context
|
1331 |
+
deliberation_state.challenger_analysis = self._safe_agent_run(
|
1332 |
+
self.agents[AgentRole.CHALLENGER], challenger_prompt, agent_role=AgentRole.CHALLENGER
|
|
|
|
|
|
|
|
|
|
|
1333 |
)
|
1334 |
|
1335 |
# Dr. Stewardship - Cost-effectiveness analysis
|
1336 |
+
logger.info("π° Dr. Stewardship evaluating cost-effectiveness...")
|
1337 |
+
stewardship_prompt = self._get_prompt_for_role(AgentRole.STEWARDSHIP, case_state) + "\n\n" + base_context
|
|
|
|
|
1338 |
if self.enable_budget_tracking:
|
1339 |
+
stewardship_prompt += f"\n\nBUDGET TRACKING ENABLED - Current cost: ${case_state.cumulative_cost}, Remaining: ${remaining_budget}"
|
1340 |
+
deliberation_state.stewardship_analysis = self._safe_agent_run(
|
1341 |
+
self.agents[AgentRole.STEWARDSHIP], stewardship_prompt, agent_role=AgentRole.STEWARDSHIP
|
|
|
|
|
|
|
|
|
1342 |
)
|
1343 |
|
1344 |
# Dr. Checklist - Quality assurance
|
1345 |
+
logger.info("β
Dr. Checklist performing quality control...")
|
1346 |
+
checklist_prompt = self._get_prompt_for_role(AgentRole.CHECKLIST, case_state) + "\n\n" + base_context
|
1347 |
+
deliberation_state.checklist_analysis = self._safe_agent_run(
|
1348 |
+
self.agents[AgentRole.CHECKLIST], checklist_prompt, agent_role=AgentRole.CHECKLIST
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1349 |
)
|
|
|
1350 |
|
1351 |
+
# Consensus Coordinator - Final decision synthesis using structured state
|
1352 |
+
logger.info("π€ Consensus Coordinator synthesizing panel decision...")
|
1353 |
+
|
1354 |
+
# Generate the structured consensus prompt
|
1355 |
+
consensus_prompt = deliberation_state.to_consensus_prompt()
|
1356 |
+
|
1357 |
# Add mode-specific constraints to consensus
|
1358 |
if self.mode == "budgeted" and remaining_budget <= 0:
|
1359 |
+
consensus_prompt += "\n\nBUDGET CONSTRAINT: Budget exceeded - must either ask questions or provide final diagnosis."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1360 |
|
1361 |
+
# Use function calling with retry logic for robust structured output
|
1362 |
+
action_dict = self._get_consensus_with_retry(consensus_prompt)
|
1363 |
|
1364 |
# Validate action based on mode constraints
|
1365 |
action = Action(**action_dict)
|
1366 |
+
|
1367 |
+
# Apply mode-specific validation and corrections
|
1368 |
+
action = self._validate_and_correct_action(action, case_state, remaining_budget)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1369 |
|
1370 |
return action
|
1371 |
|
|
|
1377 |
content="Could you please provide more information about the patient's current condition?",
|
1378 |
reasoning=f"Fallback due to panel deliberation error: {str(e)}",
|
1379 |
)
|
1380 |
+
|
1381 |
+
def _generate_situational_context(self, case_state: CaseState, remaining_budget: int) -> str:
    """Generate dynamic situational context based on current case state - addresses Category 4.2.

    Collects applicable status banners (budget pressure, diagnostic
    confidence, round count, test volume) and joins them with " | ".
    """
    notes = []

    # Budget pressure: tiers are mutually exclusive, most severe first.
    if remaining_budget < 1000:
        notes.append(f"URGENT: Remaining budget critically low (${remaining_budget}). Focus on cost-effective actions.")
    elif remaining_budget < 2000:
        notes.append(f"WARNING: Budget running low (${remaining_budget}). Prioritize high-value tests.")

    # Diagnostic confidence: nudge toward confirmation/closure as the
    # leading hypothesis strengthens.
    top_confidence = case_state.get_max_confidence()
    if top_confidence > 0.85:
        notes.append(f"FINAL STAGES: High confidence diagnosis available ({top_confidence:.0%}). Consider definitive action.")
    elif top_confidence > 0.70:
        notes.append(f"CONVERGING: Moderate confidence in leading diagnosis ({top_confidence:.0%}). Focus on confirmation.")

    # Round count: discourage open-ended exploration in long cases.
    if case_state.iteration > 7:
        notes.append(f"EXTENDED CASE: {case_state.iteration} rounds completed. Move toward decisive action.")
    elif case_state.iteration > 5:
        notes.append(f"PROLONGED: {case_state.iteration} rounds. Avoid further exploratory steps unless critical.")

    # Test volume: after many tests, push synthesis over more ordering.
    if len(case_state.tests_performed) > 5:
        notes.append("EXTENSIVE TESTING: Many tests completed. Focus on synthesis rather than additional testing.")

    return " | ".join(notes) if notes else ""
|
1409 |
+
|
1410 |
+
def _update_differential_from_hypothesis(self, case_state: CaseState, hypothesis_response):
|
1411 |
+
"""Extract and update differential diagnosis from Dr. Hypothesis analysis - now supports both function calls and text"""
|
1412 |
+
try:
|
1413 |
+
# Try to extract structured data from function call first
|
1414 |
+
if hasattr(hypothesis_response, '__dict__') or isinstance(hypothesis_response, dict):
|
1415 |
+
structured_data = self._extract_function_call_output(hypothesis_response)
|
1416 |
+
|
1417 |
+
# Validate the structured data using the HypothesisArguments schema
|
1418 |
+
try:
|
1419 |
+
_ = HypothesisArguments(**structured_data)
|
1420 |
+
except ValidationError as e:
|
1421 |
+
logger.warning(f"HypothesisArguments validation failed: {e}")
|
1422 |
+
|
1423 |
+
# Check if we got structured differential data
|
1424 |
+
if "differential_diagnoses" in structured_data:
|
1425 |
+
# Update case state with structured data
|
1426 |
+
new_differential = {}
|
1427 |
+
for dx in structured_data["differential_diagnoses"]:
|
1428 |
+
new_differential[dx["diagnosis"]] = dx["probability"]
|
1429 |
+
|
1430 |
+
case_state.update_differential(new_differential)
|
1431 |
+
|
1432 |
+
# Update the main differential for backward compatibility
|
1433 |
+
summary = structured_data.get("summary", "Differential diagnosis updated")
|
1434 |
+
dx_text = f"{summary}\n\nTop Diagnoses:\n"
|
1435 |
+
for dx in structured_data["differential_diagnoses"]:
|
1436 |
+
dx_text += f"- {dx['diagnosis']}: {dx['probability']:.0%} - {dx['rationale']}\n"
|
1437 |
+
|
1438 |
+
if "key_evidence" in structured_data:
|
1439 |
+
dx_text += f"\nKey Evidence: {structured_data['key_evidence']}"
|
1440 |
+
if "contradictory_evidence" in structured_data:
|
1441 |
+
dx_text += f"\nContradictory Evidence: {structured_data['contradictory_evidence']}"
|
1442 |
+
|
1443 |
+
self.differential_diagnosis = dx_text
|
1444 |
+
logger.debug(f"Updated differential from function call: {new_differential}")
|
1445 |
+
return
|
1446 |
+
|
1447 |
+
# Fallback to text-based extraction
|
1448 |
+
hypothesis_text = str(hypothesis_response)
|
1449 |
+
if hasattr(hypothesis_response, 'content'):
|
1450 |
+
hypothesis_text = hypothesis_response.content
|
1451 |
+
|
1452 |
+
# Simple extraction - look for percentage patterns in the text
|
1453 |
+
import re
|
1454 |
+
|
1455 |
+
# Update the main differential diagnosis for backward compatibility
|
1456 |
+
self.differential_diagnosis = hypothesis_text
|
1457 |
+
|
1458 |
+
# Try to extract structured probabilities
|
1459 |
+
# Look for patterns like "Diagnosis: 85%" or "Disease (70%)"
|
1460 |
+
percentage_pattern = r'([A-Za-z][^:(\n]*?)[\s:]*[\(]?(\d{1,3})%[\)]?'
|
1461 |
+
matches = re.findall(percentage_pattern, hypothesis_text)
|
1462 |
+
|
1463 |
+
new_differential = {}
|
1464 |
+
for match in matches:
|
1465 |
+
diagnosis = match[0].strip().rstrip(':-()').strip()
|
1466 |
+
probability = float(match[1]) / 100.0
|
1467 |
+
if 0 <= probability <= 1.0 and len(diagnosis) > 3: # Basic validation
|
1468 |
+
new_differential[diagnosis] = probability
|
1469 |
+
|
1470 |
+
if new_differential:
|
1471 |
+
case_state.update_differential(new_differential)
|
1472 |
+
logger.debug(f"Updated differential from text parsing: {new_differential}")
|
1473 |
+
|
1474 |
+
except Exception as e:
|
1475 |
+
logger.debug(f"Could not extract structured differential: {e}")
|
1476 |
+
# Still update the text version for display
|
1477 |
+
hypothesis_text = str(hypothesis_response)
|
1478 |
+
if hasattr(hypothesis_response, 'content'):
|
1479 |
+
hypothesis_text = hypothesis_response.content
|
1480 |
+
self.differential_diagnosis = hypothesis_text
|
1481 |
+
|
1482 |
+
def _validate_and_correct_action(self, action: Action, case_state: CaseState, remaining_budget: int) -> Action:
|
1483 |
+
"""Validate and correct actions based on mode constraints and context"""
|
1484 |
+
|
1485 |
+
# Mode-specific validations
|
1486 |
+
if self.mode == "question_only" and action.action_type == "test":
|
1487 |
+
logger.warning("Test ordering attempted in question-only mode, converting to ask action")
|
1488 |
+
action.action_type = "ask"
|
1489 |
+
action.content = "Can you provide more details about the patient's symptoms and history?"
|
1490 |
+
action.reasoning = "Mode constraint: question-only mode active"
|
1491 |
+
|
1492 |
+
if self.mode == "budgeted" and action.action_type == "test" and remaining_budget <= 0:
|
1493 |
+
logger.warning("Test ordering attempted with insufficient budget, converting to diagnose action")
|
1494 |
+
action.action_type = "diagnose"
|
1495 |
+
action.content = case_state.get_leading_diagnosis()
|
1496 |
+
action.reasoning = "Budget constraint: insufficient funds for additional testing"
|
1497 |
+
|
1498 |
+
# Stagnation handling - ensure we have a valid diagnosis
|
1499 |
+
if case_state.is_stagnating(action):
|
1500 |
+
logger.warning("Stagnation detected, forcing diagnostic decision")
|
1501 |
+
action.action_type = "diagnose"
|
1502 |
+
leading_diagnosis = case_state.get_leading_diagnosis()
|
1503 |
+
# Ensure the diagnosis is meaningful, not corrupted
|
1504 |
+
if leading_diagnosis == "No diagnosis formulated" or len(leading_diagnosis) < 10 or any(char in leading_diagnosis for char in ['x10^9', 'β', '40β']):
|
1505 |
+
# Use a fallback diagnosis based on the case context
|
1506 |
+
action.content = "Unable to establish definitive diagnosis - further evaluation needed"
|
1507 |
+
else:
|
1508 |
+
action.content = leading_diagnosis
|
1509 |
+
action.reasoning = "Forced diagnosis due to detected stagnation in diagnostic process"
|
1510 |
+
|
1511 |
+
# High confidence threshold
|
1512 |
+
if action.action_type != "diagnose" and case_state.get_max_confidence() > 0.90:
|
1513 |
+
logger.info("Very high confidence reached, recommending diagnosis")
|
1514 |
+
action.action_type = "diagnose"
|
1515 |
+
action.content = case_state.get_leading_diagnosis()
|
1516 |
+
action.reasoning = "High confidence threshold reached - proceeding to final diagnosis"
|
1517 |
+
|
1518 |
+
return action
|
1519 |
|
1520 |
def _interact_with_gatekeeper(
|
1521 |
self, action: Action, full_case_details: str
|
|
|
1541 |
{request}
|
1542 |
"""
|
1543 |
|
1544 |
+
response = self._safe_agent_run(gatekeeper, prompt, agent_role=AgentRole.GATEKEEPER)
|
1545 |
return response
|
1546 |
|
1547 |
def _judge_diagnosis(
|
|
|
1553 |
Please evaluate the following diagnosis.
|
1554 |
Ground Truth: "{ground_truth}"
|
1555 |
Candidate Diagnosis: "{candidate_diagnosis}"
|
1556 |
+
|
1557 |
+
You must provide your evaluation in exactly this format:
|
1558 |
+
Score: [number from 1-5]
|
1559 |
+
Justification: [detailed reasoning for the score]
|
1560 |
"""
|
1561 |
+
response = self._safe_agent_run(judge, prompt, agent_role=AgentRole.JUDGE)
|
1562 |
+
|
1563 |
+
# Handle different response types from swarms Agent
|
1564 |
+
response_text = ""
|
1565 |
+
if hasattr(response, 'content'):
|
1566 |
+
response_text = response.content
|
1567 |
+
elif isinstance(response, dict):
|
1568 |
+
if 'role' in response and 'content' in response:
|
1569 |
+
response_text = response['content']
|
1570 |
+
elif 'content' in response:
|
1571 |
+
response_text = response['content']
|
1572 |
+
else:
|
1573 |
+
response_text = str(response)
|
1574 |
+
elif isinstance(response, str):
|
1575 |
+
response_text = response
|
1576 |
+
else:
|
1577 |
+
response_text = str(response)
|
1578 |
|
1579 |
+
# Enhanced parsing for demonstration; a more robust solution would use structured output.
|
1580 |
try:
|
1581 |
+
# Look for score patterns
|
1582 |
+
import re
|
1583 |
+
|
1584 |
+
# Try multiple score patterns
|
1585 |
+
score_patterns = [
|
1586 |
+
r"Score:\s*(\d+(?:\.\d+)?)",
|
1587 |
+
r"Score\s*(\d+(?:\.\d+)?)",
|
1588 |
+
r"(\d+(?:\.\d+)?)/5",
|
1589 |
+
r"Score.*?(\d+(?:\.\d+)?)",
|
1590 |
+
]
|
1591 |
+
|
1592 |
+
score = 0.0
|
1593 |
+
for pattern in score_patterns:
|
1594 |
+
match = re.search(pattern, response_text, re.IGNORECASE)
|
1595 |
+
if match:
|
1596 |
+
score = float(match.group(1))
|
1597 |
+
break
|
1598 |
+
|
1599 |
+
# Extract reasoning
|
1600 |
+
reasoning_patterns = [
|
1601 |
+
r"Justification:\s*(.+?)(?:\n\n|\Z)",
|
1602 |
+
r"Reasoning:\s*(.+?)(?:\n\n|\Z)",
|
1603 |
+
r"Explanation:\s*(.+?)(?:\n\n|\Z)",
|
1604 |
+
]
|
1605 |
+
|
1606 |
+
reasoning = "Could not parse judge's reasoning."
|
1607 |
+
for pattern in reasoning_patterns:
|
1608 |
+
match = re.search(pattern, response_text, re.IGNORECASE | re.DOTALL)
|
1609 |
+
if match:
|
1610 |
+
reasoning = match.group(1).strip()
|
1611 |
+
break
|
1612 |
+
|
1613 |
+
# If no specific reasoning found, use the whole response after score
|
1614 |
+
if reasoning == "Could not parse judge's reasoning." and score > 0:
|
1615 |
+
# Try to extract everything after the score
|
1616 |
+
score_match = re.search(r"Score:?\s*\d+(?:\.\d+)?", response_text, re.IGNORECASE)
|
1617 |
+
if score_match:
|
1618 |
+
reasoning = response_text[score_match.end():].strip()
|
1619 |
+
# Clean up common prefixes
|
1620 |
+
reasoning = re.sub(r"^(Justification|Reasoning|Explanation):\s*", "", reasoning, flags=re.IGNORECASE)
|
1621 |
+
|
1622 |
+
# Final fallback - use the whole response if we have a score
|
1623 |
+
if reasoning == "Could not parse judge's reasoning." and score > 0:
|
1624 |
+
reasoning = response_text
|
1625 |
+
|
1626 |
+
except (IndexError, ValueError, AttributeError) as e:
|
1627 |
+
logger.error(f"Error parsing judge response: {e}")
|
1628 |
score = 0.0
|
1629 |
+
reasoning = f"Could not parse judge's response: {str(e)}"
|
1630 |
|
1631 |
+
logger.info(f"Judge evaluation: Score={score}, Reasoning preview: {reasoning[:100]}...")
|
1632 |
return {"score": score, "reasoning": reasoning}
|
1633 |
|
1634 |
def run(
|
|
|
1638 |
ground_truth_diagnosis: str,
|
1639 |
) -> DiagnosisResult:
|
1640 |
"""
|
1641 |
+
Executes the full sequential diagnostic process with structured state management.
|
1642 |
|
1643 |
Args:
|
1644 |
initial_case_info (str): The initial abstract of the case.
|
|
|
1649 |
DiagnosisResult: An object containing the final diagnosis, evaluation, cost, and history.
|
1650 |
"""
|
1651 |
start_time = time.time()
|
1652 |
+
|
1653 |
+
# Initialize structured case state
|
1654 |
+
case_state = CaseState(initial_vignette=initial_case_info)
|
1655 |
+
case_state.cumulative_cost = self.physician_visit_cost # Add initial visit cost
|
1656 |
+
self.cumulative_cost = case_state.cumulative_cost
|
1657 |
+
|
1658 |
+
# Store for potential use by other methods
|
1659 |
+
self.case_state = case_state
|
1660 |
+
|
1661 |
+
# Add to conversation for history tracking
|
1662 |
self.conversation.add(
|
1663 |
"Gatekeeper",
|
1664 |
f"Initial Case Information: {initial_case_info}",
|
1665 |
)
|
1666 |
+
case_state.add_evidence(f"Initial presentation: {initial_case_info}")
|
1667 |
|
|
|
|
|
1668 |
logger.info(
|
1669 |
f"Initial physician visit cost: ${self.physician_visit_cost}"
|
1670 |
)
|
|
|
1674 |
|
1675 |
for i in range(self.max_iterations):
|
1676 |
iteration_count = i + 1
|
1677 |
+
case_state.iteration = iteration_count
|
1678 |
+
|
1679 |
logger.info(
|
1680 |
f"--- Starting Diagnostic Loop {iteration_count}/{self.max_iterations} ---"
|
1681 |
)
|
1682 |
logger.info(
|
1683 |
+
f"Current cost: ${case_state.cumulative_cost:,} | Remaining budget: ${self.initial_budget - case_state.cumulative_cost:,}"
|
1684 |
)
|
1685 |
|
1686 |
try:
|
1687 |
+
# Panel deliberates to decide on the next action using structured state
|
1688 |
+
action = self._run_panel_deliberation(case_state)
|
1689 |
logger.info(
|
1690 |
f"βοΈ Panel decision: {action.action_type.upper()} -> {action.content}"
|
1691 |
)
|
|
|
1693 |
f"π Medical reasoning: {action.reasoning}"
|
1694 |
)
|
1695 |
|
1696 |
+
# Add action to case state for stagnation detection
|
1697 |
+
case_state.add_action(action)
|
1698 |
+
|
1699 |
if action.action_type == "diagnose":
|
1700 |
final_diagnosis = action.content
|
1701 |
logger.info(
|
|
|
1703 |
)
|
1704 |
break
|
1705 |
|
1706 |
+
# Handle mode-specific constraints (most are now handled in validation)
|
1707 |
if (
|
1708 |
self.mode == "question_only"
|
1709 |
and action.action_type == "test"
|
|
|
1722 |
action.content
|
1723 |
)
|
1724 |
if (
|
1725 |
+
case_state.cumulative_cost + estimated_test_cost
|
1726 |
> self.initial_budget
|
1727 |
):
|
1728 |
logger.warning(
|
|
|
1735 |
action, full_case_details
|
1736 |
)
|
1737 |
self.conversation.add("Gatekeeper", response)
|
1738 |
+
case_state.add_evidence(response)
|
1739 |
|
1740 |
+
# Update costs and state based on action type
|
1741 |
if action.action_type == "test":
|
1742 |
test_cost = self._estimate_cost(action.content)
|
1743 |
+
case_state.cumulative_cost += test_cost
|
1744 |
+
case_state.add_test(str(action.content))
|
1745 |
+
self.cumulative_cost = case_state.cumulative_cost # Keep backward compatibility
|
1746 |
+
|
1747 |
logger.info(f"Tests ordered: {action.content}")
|
1748 |
logger.info(
|
1749 |
+
f"Test cost: ${test_cost:,} | Cumulative cost: ${case_state.cumulative_cost:,}"
|
1750 |
)
|
1751 |
elif action.action_type == "ask":
|
1752 |
+
case_state.add_question(str(action.content))
|
1753 |
# Questions are part of the same visit until tests are ordered
|
1754 |
logger.info(f"Questions asked: {action.content}")
|
1755 |
logger.info(
|
|
|
1759 |
# Check budget constraints for budgeted mode
|
1760 |
if (
|
1761 |
self.mode == "budgeted"
|
1762 |
+
and case_state.cumulative_cost >= self.initial_budget
|
1763 |
):
|
1764 |
logger.warning(
|
1765 |
"Budget limit reached. Forcing final diagnosis."
|
1766 |
)
|
1767 |
+
# Use current leading diagnosis
|
1768 |
+
final_diagnosis = case_state.get_leading_diagnosis()
|
|
|
|
|
|
|
|
|
1769 |
break
|
1770 |
|
1771 |
except Exception as e:
|
|
|
1777 |
|
1778 |
else:
|
1779 |
# Max iterations reached without diagnosis
|
1780 |
+
final_diagnosis = case_state.get_leading_diagnosis()
|
1781 |
+
if final_diagnosis == "No diagnosis formulated":
|
1782 |
+
final_diagnosis = "Diagnosis not reached within maximum iterations."
|
|
|
|
|
1783 |
logger.warning(
|
1784 |
f"Max iterations ({self.max_iterations}) reached. Using best available diagnosis."
|
1785 |
)
|
|
|
1815 |
ground_truth=ground_truth_diagnosis,
|
1816 |
accuracy_score=judgement["score"],
|
1817 |
accuracy_reasoning=judgement["reasoning"],
|
1818 |
+
total_cost=case_state.cumulative_cost,
|
1819 |
iterations=iteration_count,
|
1820 |
conversation_history=self.conversation.get_str(),
|
1821 |
)
|
|
|
1824 |
logger.info(f" Final diagnosis: {final_diagnosis}")
|
1825 |
logger.info(f" Ground truth: {ground_truth_diagnosis}")
|
1826 |
logger.info(f" Accuracy score: {judgement['score']}/5.0")
|
1827 |
+
logger.info(f" Total cost: ${case_state.cumulative_cost:,}")
|
1828 |
logger.info(f" Iterations: {iteration_count}")
|
1829 |
|
1830 |
return result
|
|
|
1971 |
print_on=True, # Enable printing for aggregator agent
|
1972 |
)
|
1973 |
|
1974 |
+
agg_resp = self._safe_agent_run(aggregator, aggregator_prompt)
|
1975 |
+
if hasattr(agg_resp, "content"):
|
1976 |
+
return agg_resp.content.strip()
|
1977 |
+
return str(agg_resp).strip()
|
1978 |
|
1979 |
except Exception as e:
|
1980 |
logger.error(f"Error in ensemble aggregation: {e}")
|
|
|
2012 |
"mode": "budgeted",
|
2013 |
"max_iterations": 10,
|
2014 |
"enable_budget_tracking": True,
|
2015 |
+
"initial_budget": kwargs.get("budget", 5000), # Fixed: map budget to initial_budget
|
2016 |
},
|
2017 |
"no_budget": {
|
2018 |
"mode": "no_budget",
|
|
|
2033 |
|
2034 |
config = variant_configs[variant]
|
2035 |
config.update(kwargs) # Allow overrides
|
2036 |
+
|
2037 |
+
# Remove 'budget' parameter if present, as it's mapped to 'initial_budget'
|
2038 |
+
config.pop('budget', None)
|
2039 |
|
2040 |
return cls(**config)
|
2041 |
|
2042 |
+
# ------------------------------------------------------------------
|
2043 |
+
# Helper utilities β throttling & robust JSON parsing
|
2044 |
+
# ------------------------------------------------------------------
|
2045 |
+
|
2046 |
+
def _safe_agent_run(
    self,
    agent: "Agent",  # type: ignore - forward reference to swarms Agent
    prompt: str,
    retries: int = 3,
    agent_role: AgentRole = None,
) -> Any:
    """Call ``agent.run`` with rate-limit-aware pacing and retries.

    The prompt is prefixed with token-budget guidance so the agent can
    self-regulate its output length instead of us truncating. Every
    attempt is preceded by a delay that grows 3x per retry; rate-limit
    errors trigger another attempt, any other exception propagates
    immediately.

    Raises:
        RuntimeError: when every attempt is consumed by rate-limit errors.
    """
    # Default role used for token accounting when none is supplied.
    role = agent_role if agent_role is not None else AgentRole.CONSENSUS

    # Budget the request: estimated input size plus the role's output cap.
    input_tokens = self._estimate_tokens(prompt)
    output_cap = self._get_agent_max_tokens(role)
    total_tokens = input_tokens + output_cap

    # Prepend dynamic guidance rather than truncating the prompt.
    guidance = self._generate_token_guidance(
        input_tokens, output_cap, total_tokens, role
    )
    final_prompt = f"{guidance}\n\n{prompt}"

    logger.debug(f"Agent {role.value}: Input={input_tokens}, Output={output_cap}, Total={total_tokens}")

    # Never fire requests faster than one per 5 seconds.
    base_delay = max(self.request_delay, 5.0)

    attempt = 0
    while attempt <= retries:
        # Delay schedule: base, base*3, base*9, base*27 (e.g. 5s/15s/45s/135s).
        wait_for = base_delay * (3 ** attempt)

        logger.info(f"Request attempt {attempt + 1}/{retries + 1}, waiting {wait_for:.1f}s...")
        time.sleep(wait_for)

        try:
            return agent.run(final_prompt)
        except Exception as e:
            lowered = str(e).lower()
            is_rate_limit = (
                "rate_limit" in lowered
                or "ratelimiterror" in lowered
                or "429" in str(e)
            )
            if not is_rate_limit:
                # Non-rate-limit failures indicate real errors: surface now.
                raise
            logger.warning(
                f"Rate-limit encountered (attempt {attempt + 1}/{retries + 1}). "
                f"Will retry after {base_delay * (3 ** (attempt + 1)):.1f}s..."
            )
            attempt += 1

    # All retries exhausted
    raise RuntimeError("Maximum retries exceeded for agent.run β aborting call")
|
2105 |
+
|
2106 |
+
def _robust_parse_action(self, raw_response: str) -> Dict[str, Any]:
|
2107 |
+
"""Extract a JSON *action* object from `raw_response`.
|
2108 |
+
|
2109 |
+
The function tries multiple strategies and finally returns a default
|
2110 |
+
*ask* action if no valid JSON can be located.
|
2111 |
+
"""
|
2112 |
+
|
2113 |
+
import json, re
|
2114 |
+
|
2115 |
+
# Strip common markdown fences
|
2116 |
+
if raw_response.strip().startswith("```"):
|
2117 |
+
segments = raw_response.split("```")
|
2118 |
+
for seg in segments:
|
2119 |
+
seg = seg.strip()
|
2120 |
+
if seg.startswith("{") and seg.endswith("}"):
|
2121 |
+
raw_response = seg
|
2122 |
+
break
|
2123 |
+
|
2124 |
+
# 1) Fast path β direct JSON decode
|
2125 |
+
try:
|
2126 |
+
data = json.loads(raw_response)
|
2127 |
+
if isinstance(data, dict) and "action_type" in data:
|
2128 |
+
return data
|
2129 |
+
except Exception:
|
2130 |
+
pass
|
2131 |
+
|
2132 |
+
# 2) Regex search for the first balanced curly block
|
2133 |
+
match = re.search(r"\{[\s\S]*?\}", raw_response)
|
2134 |
+
if match:
|
2135 |
+
candidate = match.group(0)
|
2136 |
+
# Remove leading drawing characters (e.g., table borders)
|
2137 |
+
candidate = "\n".join(line.lstrip("β| ").rstrip("β| ") for line in candidate.splitlines())
|
2138 |
+
try:
|
2139 |
+
data = json.loads(candidate)
|
2140 |
+
if isinstance(data, dict) and "action_type" in data:
|
2141 |
+
return data
|
2142 |
+
except Exception:
|
2143 |
+
pass
|
2144 |
+
|
2145 |
+
logger.error("Failed to parse a valid action JSON. Falling back to default ask action")
|
2146 |
+
return {
|
2147 |
+
"action_type": "ask",
|
2148 |
+
"content": "Could you please clarify the next best step? The previous analysis was inconclusive.",
|
2149 |
+
"reasoning": "Fallback generated due to JSON parsing failure.",
|
2150 |
+
}
|
2151 |
+
|
2152 |
+
def _extract_function_call_output(self, agent_response) -> Dict[str, Any]:
|
2153 |
+
"""Extract structured output from agent function call response.
|
2154 |
+
|
2155 |
+
This method handles the swarms Agent response format when using function calling.
|
2156 |
+
The response should contain tool calls with the structured data.
|
2157 |
+
"""
|
2158 |
+
try:
|
2159 |
+
# Handle different response formats from swarms Agent
|
2160 |
+
if isinstance(agent_response, dict):
|
2161 |
+
# Check for tool calls in the response
|
2162 |
+
if "tool_calls" in agent_response and agent_response["tool_calls"]:
|
2163 |
+
tool_call = agent_response["tool_calls"][0] # Get first tool call
|
2164 |
+
if "function" in tool_call and "arguments" in tool_call["function"]:
|
2165 |
+
arguments = tool_call["function"]["arguments"]
|
2166 |
+
if isinstance(arguments, str):
|
2167 |
+
# Parse JSON string arguments
|
2168 |
+
import json
|
2169 |
+
arguments = json.loads(arguments)
|
2170 |
+
return arguments
|
2171 |
+
|
2172 |
+
# Check for direct arguments in response
|
2173 |
+
if "arguments" in agent_response:
|
2174 |
+
arguments = agent_response["arguments"]
|
2175 |
+
if isinstance(arguments, str):
|
2176 |
+
import json
|
2177 |
+
arguments = json.loads(arguments)
|
2178 |
+
return arguments
|
2179 |
+
|
2180 |
+
# Check if response itself has the expected structure
|
2181 |
+
if all(key in agent_response for key in ["action_type", "content", "reasoning"]):
|
2182 |
+
return {
|
2183 |
+
"action_type": agent_response["action_type"],
|
2184 |
+
"content": agent_response["content"],
|
2185 |
+
"reasoning": agent_response["reasoning"]
|
2186 |
+
}
|
2187 |
+
|
2188 |
+
# Handle Agent object response
|
2189 |
+
elif hasattr(agent_response, "__dict__"):
|
2190 |
+
# Check for tool_calls attribute
|
2191 |
+
if hasattr(agent_response, "tool_calls") and agent_response.tool_calls:
|
2192 |
+
tool_call = agent_response.tool_calls[0]
|
2193 |
+
if hasattr(tool_call, "function") and hasattr(tool_call.function, "arguments"):
|
2194 |
+
arguments = tool_call.function.arguments
|
2195 |
+
if isinstance(arguments, str):
|
2196 |
+
import json
|
2197 |
+
arguments = json.loads(arguments)
|
2198 |
+
return arguments
|
2199 |
+
|
2200 |
+
# Check for direct function call response
|
2201 |
+
if hasattr(agent_response, "function_call"):
|
2202 |
+
function_call = agent_response.function_call
|
2203 |
+
if hasattr(function_call, "arguments"):
|
2204 |
+
arguments = function_call.arguments
|
2205 |
+
if isinstance(arguments, str):
|
2206 |
+
import json
|
2207 |
+
arguments = json.loads(arguments)
|
2208 |
+
return arguments
|
2209 |
+
|
2210 |
+
# Try to extract from response content
|
2211 |
+
if hasattr(agent_response, "content"):
|
2212 |
+
content = agent_response.content
|
2213 |
+
if isinstance(content, dict) and all(key in content for key in ["action_type", "content", "reasoning"]):
|
2214 |
+
return content
|
2215 |
+
|
2216 |
+
# Handle string response (fallback to regex parsing)
|
2217 |
+
elif isinstance(agent_response, str):
|
2218 |
+
# Try to parse as JSON first
|
2219 |
+
try:
|
2220 |
+
import json
|
2221 |
+
parsed = json.loads(agent_response)
|
2222 |
+
if isinstance(parsed, dict) and all(key in parsed for key in ["action_type", "content", "reasoning"]):
|
2223 |
+
return parsed
|
2224 |
+
except:
|
2225 |
+
pass
|
2226 |
+
|
2227 |
+
# Fallback to regex extraction
|
2228 |
+
import re
|
2229 |
+
action_type_match = re.search(r'"action_type":\s*"(ask|test|diagnose)"', agent_response, re.IGNORECASE)
|
2230 |
+
content_match = re.search(r'"content":\s*"([^"]+)"', agent_response, re.IGNORECASE | re.DOTALL)
|
2231 |
+
reasoning_match = re.search(r'"reasoning":\s*"([^"]+)"', agent_response, re.IGNORECASE | re.DOTALL)
|
2232 |
+
|
2233 |
+
if action_type_match and content_match and reasoning_match:
|
2234 |
+
return {
|
2235 |
+
"action_type": action_type_match.group(1).lower(),
|
2236 |
+
"content": content_match.group(1).strip(),
|
2237 |
+
"reasoning": reasoning_match.group(1).strip()
|
2238 |
+
}
|
2239 |
+
|
2240 |
+
logger.warning(f"Could not extract function call output from response type: {type(agent_response)}")
|
2241 |
+
logger.debug(f"Response content: {str(agent_response)[:500]}...")
|
2242 |
+
|
2243 |
+
except Exception as e:
|
2244 |
+
logger.error(f"Error extracting function call output: {e}")
|
2245 |
+
logger.debug(f"Response: {str(agent_response)[:500]}...")
|
2246 |
+
|
2247 |
+
# Final fallback
|
2248 |
+
return {
|
2249 |
+
"action_type": "ask",
|
2250 |
+
"content": "Could you please provide more information to help guide the next diagnostic step?",
|
2251 |
+
"reasoning": "Fallback action due to function call parsing error."
|
2252 |
+
}
|
2253 |
+
|
2254 |
+
def _get_consensus_with_retry(self, consensus_prompt: str, max_retries: int = 2) -> Dict[str, Any]:
    """Get consensus decision with function call retry logic.

    Runs the Consensus agent up to ``max_retries + 1`` times. The first
    attempt uses ``consensus_prompt`` verbatim; retries append an explicit
    instruction to call ``make_consensus_decision``. Each response is
    validated against the ConsensusArguments schema (validation failures
    are logged but not fatal). If every function-call attempt yields the
    parsing-error fallback, the last raw response is re-parsed as JSON,
    and as a final resort a generic "ask" action is returned.

    NOTE(review): if the very first ``_safe_agent_run`` call raises,
    ``response`` is never bound; the final fallback then hits a NameError
    which is caught by the surrounding ``except Exception`` and converted
    into the generic fallback dict — functional, but worth confirming.
    """

    for attempt in range(max_retries + 1):
        try:
            if attempt == 0:
                # First attempt - use original prompt
                response = self._safe_agent_run(
                    self.agents[AgentRole.CONSENSUS], consensus_prompt, agent_role=AgentRole.CONSENSUS
                )
            else:
                # Retry with explicit function call instruction
                retry_prompt = f"""
{consensus_prompt}

**CRITICAL: RETRY ATTEMPT {attempt}**
Your previous response failed to use the required `make_consensus_decision` function.
You MUST call the make_consensus_decision function with the appropriate parameters:
- action_type: "ask", "test", or "diagnose"
- content: specific question, test name, or diagnosis
- reasoning: your detailed reasoning

Please try again and ensure you call the function correctly.
"""
                response = self._safe_agent_run(
                    self.agents[AgentRole.CONSENSUS], retry_prompt, agent_role=AgentRole.CONSENSUS
                )

            logger.debug(f"Consensus attempt {attempt + 1}, response type: {type(response)}")

            # Try to extract function call output
            action_dict = self._extract_function_call_output(response)

            # Validate and enforce schema using ConsensusArguments for type safety
            # (failure keeps the unvalidated dict and merely logs a warning).
            try:
                validated_args = ConsensusArguments(**action_dict)
                action_dict = validated_args.dict()
            except ValidationError as e:
                logger.warning(f"ConsensusArguments validation failed: {e}")

            # Check if we got a valid response (not a fallback): the extractor
            # marks its parsing-failure fallback via this reasoning prefix.
            if not action_dict.get("reasoning", "").startswith("Fallback action due to function call parsing error"):
                logger.debug(f"Consensus function call successful on attempt {attempt + 1}")
                return action_dict

            logger.warning(f"Function call failed on attempt {attempt + 1}, will retry")

        except Exception as e:
            logger.error(f"Error in consensus attempt {attempt + 1}: {e}")

    # Final fallback to JSON parsing if all function call attempts failed
    logger.warning("All function call attempts failed, falling back to JSON parsing")
    try:
        # Use the last response and try JSON parsing
        consensus_text = (
            response.content if hasattr(response, "content") else str(response)
        )
        return self._robust_parse_action(consensus_text)
    except Exception as e:
        logger.error(f"Both function calling and JSON parsing failed: {e}")
        return {
            "action_type": "ask",
            "content": "Could you please provide more information to guide the diagnostic process?",
            "reasoning": f"Final fallback after {max_retries + 1} function call attempts and JSON parsing failure."
        }
|
2319 |
+
|
2320 |
|
2321 |
def run_mai_dxo_demo(
|
2322 |
case_info: str = None,
|
|
|
2371 |
orchestrator = MaiDxOrchestrator.create_variant(
|
2372 |
variant,
|
2373 |
budget=3000,
|
2374 |
+
model_name="gemini/gemini-2.5-flash", # Fixed: Use valid model name
|
2375 |
+
max_iterations=3,
|
2376 |
)
|
2377 |
else:
|
2378 |
orchestrator = MaiDxOrchestrator.create_variant(
|
2379 |
+
variant,
|
2380 |
+
model_name="gemini/gemini-2.5-flash", # Fixed: Use valid model name
|
2381 |
+
max_iterations=3,
|
2382 |
)
|
2383 |
|
2384 |
result = orchestrator.run(
|
|
|
2453 |
# orchestrator = MaiDxOrchestrator.create_variant(
|
2454 |
# variant_name,
|
2455 |
# budget=3000,
|
2456 |
+
# model_name="gpt-4.1", # Fixed: Use valid model name
|
2457 |
+
# max_iterations=3,
|
2458 |
# )
|
2459 |
# else:
|
2460 |
# orchestrator = MaiDxOrchestrator.create_variant(
|
2461 |
# variant_name,
|
2462 |
+
# model_name="gpt-4.1", # Fixed: Use valid model name
|
2463 |
+
# max_iterations=3,
|
2464 |
# )
|
2465 |
|
2466 |
# # Run the diagnostic process
|
|
|
2491 |
|
2492 |
# ensemble_orchestrator = MaiDxOrchestrator.create_variant(
|
2493 |
# "ensemble",
|
2494 |
+
# model_name="gpt-4.1", # Fixed: Use valid model name
|
2495 |
# max_iterations=3, # Shorter iterations for ensemble
|
2496 |
# )
|
2497 |
|