Spaces:
Sleeping
Sleeping
Kye Gomez
committed on
Commit
·
69f6ed5
1
Parent(s):
5239f6e
__init__ and requirements.txt
Browse files- mai_dx/__init__.py +3 -0
- mai_dx/main.py +601 -298
- requirements.txt +2 -2
mai_dx/__init__.py
CHANGED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from mai_dx.main import MaiDxOrchestrator, run_mai_dxo_demo
|
2 |
+
|
3 |
+
__all__ = ["MaiDxOrchestrator", "run_mai_dxo_demo"]
|
mai_dx/main.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
"""
|
2 |
MAI Diagnostic Orchestrator (MAI-DxO)
|
3 |
|
4 |
-
This script provides a complete implementation of the "Sequential Diagnosis with Language Models"
|
5 |
-
paper, using the `swarms` framework. It simulates a virtual panel of physician-agents to perform
|
6 |
iterative medical diagnosis with cost-effectiveness optimization.
|
7 |
|
8 |
Based on the paper: "Sequential Diagnosis with Language Models"
|
@@ -20,20 +20,22 @@ Example Usage:
|
|
20 |
# Standard MAI-DxO usage
|
21 |
orchestrator = MaiDxOrchestrator(model_name="gemini/gemini-2.5-flash")
|
22 |
result = orchestrator.run(initial_case_info, full_case_details, ground_truth)
|
23 |
-
|
24 |
# Budget-constrained variant
|
25 |
budgeted_orchestrator = MaiDxOrchestrator.create_variant("budgeted", budget=5000)
|
26 |
-
|
27 |
# Ensemble approach
|
28 |
ensemble_result = orchestrator.run_ensemble(initial_case_info, full_case_details, ground_truth)
|
29 |
"""
|
30 |
|
|
|
|
|
31 |
import json
|
32 |
import sys
|
33 |
import time
|
34 |
from dataclasses import dataclass
|
35 |
from enum import Enum
|
36 |
-
from typing import Any, Dict, List,
|
37 |
|
38 |
from loguru import logger
|
39 |
from pydantic import BaseModel, Field
|
@@ -47,26 +49,27 @@ logger.add(
|
|
47 |
sys.stdout,
|
48 |
level="INFO",
|
49 |
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
|
50 |
-
colorize=True
|
51 |
)
|
52 |
|
53 |
-
|
54 |
-
import os
|
55 |
if os.getenv("MAIDX_DEBUG", "").lower() in ("1", "true", "yes"):
|
56 |
logger.add(
|
57 |
"logs/maidx_debug_{time:YYYY-MM-DD}.log",
|
58 |
level="DEBUG",
|
59 |
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
|
60 |
rotation="1 day",
|
61 |
-
retention="3 days"
|
|
|
|
|
|
|
62 |
)
|
63 |
-
logger.info("π Debug logging enabled - logs will be written to logs/ directory")
|
64 |
|
65 |
# File handler for persistent logging (optional - uncomment if needed)
|
66 |
# logger.add(
|
67 |
# "logs/mai_dxo_{time:YYYY-MM-DD}.log",
|
68 |
# rotation="1 day",
|
69 |
-
# retention="7 days",
|
70 |
# level="DEBUG",
|
71 |
# format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
|
72 |
# compression="zip"
|
@@ -74,8 +77,10 @@ if os.getenv("MAIDX_DEBUG", "").lower() in ("1", "true", "yes"):
|
|
74 |
|
75 |
# --- Data Structures and Enums ---
|
76 |
|
|
|
77 |
class AgentRole(Enum):
|
78 |
"""Enumeration of roles for the virtual physician panel."""
|
|
|
79 |
HYPOTHESIS = "Dr. Hypothesis"
|
80 |
TEST_CHOOSER = "Dr. Test-Chooser"
|
81 |
CHALLENGER = "Dr. Challenger"
|
@@ -85,9 +90,11 @@ class AgentRole(Enum):
|
|
85 |
GATEKEEPER = "Gatekeeper"
|
86 |
JUDGE = "Judge"
|
87 |
|
|
|
88 |
@dataclass
|
89 |
class DiagnosisResult:
|
90 |
"""Stores the final result of a diagnostic session."""
|
|
|
91 |
final_diagnosis: str
|
92 |
ground_truth: str
|
93 |
accuracy_score: float
|
@@ -96,20 +103,32 @@ class DiagnosisResult:
|
|
96 |
iterations: int
|
97 |
conversation_history: str
|
98 |
|
|
|
99 |
class Action(BaseModel):
|
100 |
"""Pydantic model for a structured action decided by the consensus agent."""
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
# --- Main Orchestrator Class ---
|
106 |
|
|
|
107 |
class MaiDxOrchestrator:
|
108 |
"""
|
109 |
Implements the MAI Diagnostic Orchestrator (MAI-DxO) framework.
|
110 |
This class orchestrates a virtual panel of AI agents to perform sequential medical diagnosis,
|
111 |
evaluates the final diagnosis, and tracks costs.
|
112 |
"""
|
|
|
113 |
def __init__(
|
114 |
self,
|
115 |
model_name: str = "gemini/gemini-2.5-flash",
|
@@ -136,15 +155,13 @@ class MaiDxOrchestrator:
|
|
136 |
self.mode = mode
|
137 |
self.physician_visit_cost = physician_visit_cost
|
138 |
self.enable_budget_tracking = enable_budget_tracking
|
139 |
-
|
140 |
self.cumulative_cost = 0
|
141 |
self.differential_diagnosis = "Not yet formulated."
|
142 |
self.conversation = Conversation(
|
143 |
-
time_enabled=True,
|
144 |
-
autosave=False,
|
145 |
-
save_enabled=False
|
146 |
)
|
147 |
-
|
148 |
# Enhanced cost model based on the paper's methodology
|
149 |
self.test_cost_db = {
|
150 |
"default": 150,
|
@@ -177,7 +194,9 @@ class MaiDxOrchestrator:
|
|
177 |
}
|
178 |
|
179 |
self._init_agents()
|
180 |
-
logger.info(
|
|
|
|
|
181 |
|
182 |
def _init_agents(self):
|
183 |
"""Initializes all required agents with their specific roles and prompts."""
|
@@ -187,16 +206,22 @@ class MaiDxOrchestrator:
|
|
187 |
system_prompt=self._get_prompt_for_role(role),
|
188 |
model_name=self.model_name,
|
189 |
max_loops=1,
|
190 |
-
output_type=
|
|
|
|
|
191 |
print_on=True, # Enable printing for all agents to see outputs
|
192 |
-
)
|
|
|
193 |
}
|
194 |
-
logger.info(
|
|
|
|
|
195 |
|
196 |
def _get_prompt_for_role(self, role: AgentRole) -> str:
|
197 |
"""Returns the system prompt for a given agent role."""
|
198 |
prompts = {
|
199 |
-
AgentRole.HYPOTHESIS:
|
|
|
200 |
You are Dr. Hypothesis, a specialist in maintaining differential diagnoses. Your role is critical to the diagnostic process.
|
201 |
|
202 |
CORE RESPONSIBILITIES:
|
@@ -221,9 +246,10 @@ class MaiDxOrchestrator:
|
|
221 |
- Evidence that contradicts or challenges each hypothesis
|
222 |
|
223 |
Remember: Your differential drives the entire diagnostic process. Be thorough, evidence-based, and adaptive.
|
224 |
-
"""
|
225 |
-
|
226 |
-
AgentRole.TEST_CHOOSER:
|
|
|
227 |
You are Dr. Test-Chooser, a specialist in diagnostic test selection and information theory.
|
228 |
|
229 |
CORE RESPONSIBILITIES:
|
@@ -252,9 +278,10 @@ class MaiDxOrchestrator:
|
|
252 |
- How results will change management decisions
|
253 |
|
254 |
Focus on tests that will most efficiently narrow the differential diagnosis.
|
255 |
-
"""
|
256 |
-
|
257 |
-
AgentRole.CHALLENGER:
|
|
|
258 |
You are Dr. Challenger, the critical thinking specialist and devil's advocate.
|
259 |
|
260 |
CORE RESPONSIBILITIES:
|
@@ -285,9 +312,10 @@ class MaiDxOrchestrator:
|
|
285 |
- Red flags or concerning patterns that need attention
|
286 |
|
287 |
Be constructively critical - your role is to strengthen diagnostic accuracy through rigorous challenge.
|
288 |
-
"""
|
289 |
-
|
290 |
-
AgentRole.STEWARDSHIP:
|
|
|
291 |
You are Dr. Stewardship, the resource optimization and cost-effectiveness specialist.
|
292 |
|
293 |
CORE RESPONSIBILITIES:
|
@@ -323,9 +351,10 @@ class MaiDxOrchestrator:
|
|
323 |
- Cumulative cost considerations
|
324 |
|
325 |
Your goal: Maximum diagnostic accuracy at minimum necessary cost.
|
326 |
-
"""
|
327 |
-
|
328 |
-
AgentRole.CHECKLIST:
|
|
|
329 |
You are Dr. Checklist, the quality assurance and consistency specialist.
|
330 |
|
331 |
CORE RESPONSIBILITIES:
|
@@ -356,9 +385,10 @@ class MaiDxOrchestrator:
|
|
356 |
- Process improvement suggestions
|
357 |
|
358 |
Keep your feedback concise but comprehensive. Flag any issues that could compromise diagnostic quality.
|
359 |
-
"""
|
360 |
-
|
361 |
-
AgentRole.CONSENSUS:
|
|
|
362 |
You are the Consensus Coordinator, responsible for synthesizing the virtual panel's expertise into a single, optimal decision.
|
363 |
|
364 |
CORE RESPONSIBILITIES:
|
@@ -392,9 +422,10 @@ class MaiDxOrchestrator:
|
|
392 |
For action_type "diagnose": content should be the complete, specific final diagnosis
|
393 |
|
394 |
Make the decision that best advances accurate, cost-effective diagnosis.
|
395 |
-
"""
|
396 |
-
|
397 |
-
AgentRole.GATEKEEPER:
|
|
|
398 |
You are the Gatekeeper, the clinical information oracle with complete access to the patient case file.
|
399 |
|
400 |
CORE RESPONSIBILITIES:
|
@@ -431,9 +462,10 @@ class MaiDxOrchestrator:
|
|
431 |
- Professional medical terminology
|
432 |
|
433 |
Your role is crucial: provide complete, accurate clinical information while maintaining the challenge of the diagnostic process.
|
434 |
-
"""
|
435 |
-
|
436 |
-
AgentRole.JUDGE:
|
|
|
437 |
You are the Judge, the diagnostic accuracy evaluation specialist.
|
438 |
|
439 |
CORE RESPONSIBILITIES:
|
@@ -489,9 +521,10 @@ class MaiDxOrchestrator:
|
|
489 |
|
490 |
Maintain high standards while recognizing legitimate diagnostic variability in medical practice.
|
491 |
"""
|
|
|
492 |
}
|
493 |
return prompts[role]
|
494 |
-
|
495 |
def _parse_json_response(self, response: str) -> Dict[str, Any]:
|
496 |
"""Safely parses a JSON string, returning a dictionary."""
|
497 |
try:
|
@@ -507,74 +540,93 @@ class MaiDxOrchestrator:
|
|
507 |
start_idx += len(start_marker)
|
508 |
end_idx = response.find(end_marker, start_idx)
|
509 |
if end_idx != -1:
|
510 |
-
json_content = response[
|
|
|
|
|
511 |
return json.loads(json_content)
|
512 |
-
|
513 |
# Try to find JSON-like content in the response
|
514 |
-
lines = response.split(
|
515 |
json_lines = []
|
516 |
in_json = False
|
517 |
brace_count = 0
|
518 |
-
|
519 |
for line in lines:
|
520 |
stripped_line = line.strip()
|
521 |
-
if stripped_line.startswith(
|
522 |
in_json = True
|
523 |
json_lines = [line] # Start fresh
|
524 |
-
brace_count = line.count(
|
|
|
|
|
525 |
elif in_json:
|
526 |
json_lines.append(line)
|
527 |
-
brace_count += line.count(
|
528 |
-
|
|
|
|
|
|
|
|
|
529 |
break
|
530 |
-
|
531 |
if json_lines and in_json:
|
532 |
-
json_content =
|
533 |
return json.loads(json_content)
|
534 |
-
|
535 |
# Try to extract JSON from text that might contain other content
|
536 |
import re
|
|
|
537 |
# Look for JSON pattern in the text
|
538 |
-
json_pattern = r
|
539 |
-
matches = re.findall(
|
540 |
-
|
|
|
|
|
541 |
for match in matches:
|
542 |
try:
|
543 |
return json.loads(match)
|
544 |
except json.JSONDecodeError:
|
545 |
continue
|
546 |
-
|
547 |
# Direct parsing attempt as fallback
|
548 |
return json.loads(response)
|
549 |
-
|
550 |
-
except (
|
|
|
|
|
|
|
|
|
551 |
logger.error(f"Failed to parse JSON response. Error: {e}")
|
552 |
-
logger.debug(
|
|
|
|
|
553 |
# Fallback to a default action if parsing fails
|
554 |
return {
|
555 |
"action_type": "ask",
|
556 |
-
"content":
|
557 |
-
|
|
|
|
|
558 |
}
|
559 |
-
|
560 |
def _estimate_cost(self, tests: Union[List[str], str]) -> int:
|
561 |
"""Estimates the cost of diagnostic tests."""
|
562 |
if isinstance(tests, str):
|
563 |
tests = [tests]
|
564 |
-
|
565 |
cost = 0
|
566 |
for test in tests:
|
567 |
test_lower = test.lower().strip()
|
568 |
-
|
569 |
# Enhanced cost matching with multiple strategies
|
570 |
cost_found = False
|
571 |
-
|
572 |
# Strategy 1: Exact match
|
573 |
if test_lower in self.test_cost_db:
|
574 |
cost += self.test_cost_db[test_lower]
|
575 |
cost_found = True
|
576 |
continue
|
577 |
-
|
578 |
# Strategy 2: Partial match (find best matching key)
|
579 |
best_match = None
|
580 |
best_match_length = 0
|
@@ -583,56 +635,87 @@ class MaiDxOrchestrator:
|
|
583 |
if len(cost_key) > best_match_length:
|
584 |
best_match = cost_key
|
585 |
best_match_length = len(cost_key)
|
586 |
-
|
587 |
if best_match:
|
588 |
cost += self.test_cost_db[best_match]
|
589 |
cost_found = True
|
590 |
continue
|
591 |
-
|
592 |
# Strategy 3: Keyword-based matching
|
593 |
-
if any(
|
594 |
-
|
|
|
|
|
|
|
595 |
cost_found = True
|
596 |
-
elif any(
|
597 |
-
|
|
|
|
|
|
|
598 |
cost_found = True
|
599 |
-
elif any(
|
600 |
-
|
|
|
|
|
|
|
601 |
cost_found = True
|
602 |
-
elif any(
|
603 |
-
|
|
|
|
|
|
|
604 |
cost_found = True
|
605 |
-
elif any(
|
|
|
|
|
|
|
606 |
cost += 100 # Basic blood test cost
|
607 |
cost_found = True
|
608 |
-
elif any(
|
609 |
-
|
|
|
|
|
|
|
610 |
cost_found = True
|
611 |
-
elif any(
|
612 |
-
|
|
|
|
|
|
|
|
|
|
|
613 |
cost_found = True
|
614 |
-
|
615 |
# Strategy 4: Default cost for unknown tests
|
616 |
if not cost_found:
|
617 |
-
cost += self.test_cost_db[
|
618 |
-
logger.debug(
|
619 |
-
|
|
|
|
|
620 |
return cost
|
621 |
|
622 |
def _run_panel_deliberation(self) -> Action:
|
623 |
"""Orchestrates one round of debate among the virtual panel to decide the next action."""
|
624 |
-
logger.info(
|
625 |
-
|
|
|
|
|
|
|
|
|
626 |
panel_conversation = Conversation(
|
627 |
-
time_enabled=True,
|
628 |
-
autosave=False,
|
629 |
-
save_enabled=False
|
630 |
)
|
631 |
-
|
632 |
# Prepare comprehensive panel context
|
633 |
remaining_budget = self.initial_budget - self.cumulative_cost
|
634 |
-
budget_status =
|
635 |
-
|
|
|
|
|
|
|
|
|
636 |
panel_context = f"""
|
637 |
DIAGNOSTIC CASE STATUS - ROUND {len(self.conversation.return_history_as_string().split('Gatekeeper:')) - 1}
|
638 |
|
@@ -657,11 +740,17 @@ class MaiDxOrchestrator:
|
|
657 |
# For instant mode, skip deliberation and go straight to diagnosis
|
658 |
action_dict = {
|
659 |
"action_type": "diagnose",
|
660 |
-
"content":
|
661 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
662 |
}
|
663 |
return Action(**action_dict)
|
664 |
-
|
665 |
if self.mode == "question_only":
|
666 |
# For question-only mode, prevent test ordering
|
667 |
panel_context += "\n\nIMPORTANT: This is QUESTION-ONLY mode. You may ONLY ask patient questions, not order diagnostic tests."
|
@@ -670,67 +759,127 @@ class MaiDxOrchestrator:
|
|
670 |
# Sequential expert deliberation with enhanced methodology
|
671 |
try:
|
672 |
# Dr. Hypothesis - Differential diagnosis and probability assessment
|
673 |
-
logger.info(
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
678 |
# Dr. Test-Chooser - Information value optimization
|
679 |
-
logger.info(
|
680 |
-
|
681 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
682 |
|
683 |
# Dr. Challenger - Bias identification and alternative hypotheses
|
684 |
-
logger.info(
|
685 |
-
|
686 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
687 |
|
688 |
# Dr. Stewardship - Cost-effectiveness analysis
|
689 |
-
logger.info(
|
|
|
|
|
690 |
stewardship_context = panel_conversation.get_str()
|
691 |
if self.enable_budget_tracking:
|
692 |
stewardship_context += f"\n\nBUDGET TRACKING ENABLED - Current cost: ${self.cumulative_cost}, Remaining: ${remaining_budget}"
|
693 |
-
stewardship_rec = self.agents[AgentRole.STEWARDSHIP].run(
|
694 |
-
|
695 |
-
|
|
|
|
|
|
|
|
|
|
|
696 |
# Dr. Checklist - Quality assurance
|
697 |
-
logger.info(
|
698 |
-
|
699 |
-
|
700 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
701 |
# Consensus Coordinator - Final decision synthesis
|
702 |
-
logger.info(
|
|
|
|
|
703 |
consensus_context = panel_conversation.get_str()
|
704 |
-
|
705 |
# Add mode-specific constraints to consensus
|
706 |
if self.mode == "budgeted" and remaining_budget <= 0:
|
707 |
consensus_context += "\n\nBUDGET CONSTRAINT: Budget exceeded - must either ask questions or provide final diagnosis."
|
708 |
-
|
709 |
-
consensus_response = self.agents[AgentRole.CONSENSUS].run(
|
710 |
-
|
711 |
-
|
|
|
|
|
|
|
|
|
712 |
# Extract the actual text content from agent response
|
713 |
-
if hasattr(consensus_response,
|
714 |
response_text = consensus_response.content
|
715 |
elif isinstance(consensus_response, str):
|
716 |
response_text = consensus_response
|
717 |
else:
|
718 |
response_text = str(consensus_response)
|
719 |
-
|
720 |
action_dict = self._parse_json_response(response_text)
|
721 |
|
722 |
# Validate action based on mode constraints
|
723 |
action = Action(**action_dict)
|
724 |
-
if
|
725 |
-
|
|
|
|
|
|
|
|
|
|
|
726 |
action.action_type = "ask"
|
727 |
action.content = "Can you provide more details about the patient's symptoms and history?"
|
728 |
-
action.reasoning =
|
|
|
|
|
729 |
|
730 |
-
if
|
731 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
732 |
action.action_type = "diagnose"
|
733 |
-
action.content =
|
|
|
|
|
|
|
|
|
734 |
action.reasoning = "Budget constraint: insufficient funds for additional testing"
|
735 |
|
736 |
return action
|
@@ -741,13 +890,15 @@ class MaiDxOrchestrator:
|
|
741 |
return Action(
|
742 |
action_type="ask",
|
743 |
content="Could you please provide more information about the patient's current condition?",
|
744 |
-
reasoning=f"Fallback due to panel deliberation error: {str(e)}"
|
745 |
)
|
746 |
|
747 |
-
def _interact_with_gatekeeper(
|
|
|
|
|
748 |
"""Sends the panel's action to the Gatekeeper and returns its response."""
|
749 |
gatekeeper = self.agents[AgentRole.GATEKEEPER]
|
750 |
-
|
751 |
if action.action_type == "ask":
|
752 |
request = f"Question: {action.content}"
|
753 |
elif action.action_type == "test":
|
@@ -765,11 +916,13 @@ class MaiDxOrchestrator:
|
|
765 |
Request from Diagnostic Agent:
|
766 |
{request}
|
767 |
"""
|
768 |
-
|
769 |
response = gatekeeper.run(prompt)
|
770 |
return response
|
771 |
|
772 |
-
def _judge_diagnosis(
|
|
|
|
|
773 |
"""Uses the Judge agent to evaluate the final diagnosis."""
|
774 |
judge = self.agents[AgentRole.JUDGE]
|
775 |
prompt = f"""
|
@@ -778,18 +931,25 @@ class MaiDxOrchestrator:
|
|
778 |
Candidate Diagnosis: "{candidate_diagnosis}"
|
779 |
"""
|
780 |
response = judge.run(prompt)
|
781 |
-
|
782 |
# Simple parsing for demonstration; a more robust solution would use structured output.
|
783 |
try:
|
784 |
-
score = float(
|
|
|
|
|
785 |
reasoning = response.split("Justification:")[1].strip()
|
786 |
except (IndexError, ValueError):
|
787 |
score = 0.0
|
788 |
reasoning = "Could not parse judge's response."
|
789 |
-
|
790 |
return {"score": score, "reasoning": reasoning}
|
791 |
|
792 |
-
def run(
|
|
|
|
|
|
|
|
|
|
|
793 |
"""
|
794 |
Executes the full sequential diagnostic process.
|
795 |
|
@@ -802,90 +962,152 @@ class MaiDxOrchestrator:
|
|
802 |
DiagnosisResult: An object containing the final diagnosis, evaluation, cost, and history.
|
803 |
"""
|
804 |
start_time = time.time()
|
805 |
-
self.conversation.add(
|
806 |
-
|
|
|
|
|
|
|
807 |
# Add initial physician visit cost
|
808 |
self.cumulative_cost += self.physician_visit_cost
|
809 |
-
logger.info(
|
810 |
-
|
|
|
|
|
811 |
final_diagnosis = None
|
812 |
iteration_count = 0
|
813 |
-
|
814 |
for i in range(self.max_iterations):
|
815 |
iteration_count = i + 1
|
816 |
-
logger.info(
|
817 |
-
|
818 |
-
|
|
|
|
|
|
|
|
|
819 |
try:
|
820 |
# Panel deliberates to decide on the next action
|
821 |
action = self._run_panel_deliberation()
|
822 |
-
logger.info(
|
823 |
-
|
824 |
-
|
|
|
|
|
|
|
|
|
825 |
if action.action_type == "diagnose":
|
826 |
final_diagnosis = action.content
|
827 |
-
logger.info(
|
|
|
|
|
828 |
break
|
829 |
|
830 |
# Handle mode-specific constraints
|
831 |
-
if
|
832 |
-
|
|
|
|
|
|
|
|
|
|
|
833 |
continue
|
834 |
-
|
835 |
-
if
|
|
|
|
|
|
|
836 |
# Check if we can afford the tests
|
837 |
-
estimated_test_cost = self._estimate_cost(
|
838 |
-
|
839 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
840 |
continue
|
841 |
|
842 |
# Interact with the Gatekeeper
|
843 |
-
response = self._interact_with_gatekeeper(
|
|
|
|
|
844 |
self.conversation.add("Gatekeeper", response)
|
845 |
-
|
846 |
# Update costs based on action type
|
847 |
if action.action_type == "test":
|
848 |
test_cost = self._estimate_cost(action.content)
|
849 |
self.cumulative_cost += test_cost
|
850 |
logger.info(f"Tests ordered: {action.content}")
|
851 |
-
logger.info(
|
|
|
|
|
852 |
elif action.action_type == "ask":
|
853 |
# Questions are part of the same visit until tests are ordered
|
854 |
logger.info(f"Questions asked: {action.content}")
|
855 |
-
logger.info(
|
|
|
|
|
856 |
|
857 |
# Check budget constraints for budgeted mode
|
858 |
-
if
|
859 |
-
|
|
|
|
|
|
|
|
|
|
|
860 |
# Use current differential diagnosis or make best guess
|
861 |
-
final_diagnosis =
|
|
|
|
|
|
|
|
|
862 |
break
|
863 |
-
|
864 |
except Exception as e:
|
865 |
-
logger.error(
|
|
|
|
|
866 |
# Continue to next iteration or break if critical error
|
867 |
continue
|
868 |
-
|
869 |
else:
|
870 |
# Max iterations reached without diagnosis
|
871 |
-
final_diagnosis =
|
872 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
873 |
|
874 |
# Ensure we have a final diagnosis
|
875 |
if not final_diagnosis or final_diagnosis.strip() == "":
|
876 |
-
final_diagnosis =
|
877 |
-
|
|
|
|
|
878 |
# Calculate total time
|
879 |
total_time = time.time() - start_time
|
880 |
-
logger.info(
|
|
|
|
|
881 |
|
882 |
# Judge the final diagnosis
|
883 |
logger.info("Evaluating final diagnosis...")
|
884 |
try:
|
885 |
-
judgement = self._judge_diagnosis(
|
|
|
|
|
886 |
except Exception as e:
|
887 |
logger.error(f"Error in diagnosis evaluation: {e}")
|
888 |
-
judgement = {
|
|
|
|
|
|
|
889 |
|
890 |
# Create comprehensive result
|
891 |
result = DiagnosisResult(
|
@@ -895,39 +1117,49 @@ class MaiDxOrchestrator:
|
|
895 |
accuracy_reasoning=judgement["reasoning"],
|
896 |
total_cost=self.cumulative_cost,
|
897 |
iterations=iteration_count,
|
898 |
-
conversation_history=self.conversation.get_str()
|
899 |
)
|
900 |
-
|
901 |
-
logger.info(
|
902 |
logger.info(f" Final diagnosis: {final_diagnosis}")
|
903 |
logger.info(f" Ground truth: {ground_truth_diagnosis}")
|
904 |
logger.info(f" Accuracy score: {judgement['score']}/5.0")
|
905 |
logger.info(f" Total cost: ${self.cumulative_cost:,}")
|
906 |
logger.info(f" Iterations: {iteration_count}")
|
907 |
-
|
908 |
return result
|
909 |
-
|
910 |
-
def run_ensemble(
|
|
|
|
|
|
|
|
|
|
|
|
|
911 |
"""
|
912 |
Runs multiple independent diagnostic sessions and aggregates the results.
|
913 |
-
|
914 |
Args:
|
915 |
initial_case_info (str): The initial abstract of the case.
|
916 |
full_case_details (str): The complete case file for the Gatekeeper.
|
917 |
ground_truth_diagnosis (str): The correct final diagnosis for evaluation.
|
918 |
num_runs (int): Number of independent runs to perform.
|
919 |
-
|
920 |
Returns:
|
921 |
DiagnosisResult: Aggregated result from ensemble runs.
|
922 |
"""
|
923 |
-
logger.info(
|
924 |
-
|
|
|
|
|
925 |
ensemble_results = []
|
926 |
total_cost = 0
|
927 |
-
|
928 |
for run_id in range(num_runs):
|
929 |
-
logger.info(
|
930 |
-
|
|
|
|
|
931 |
# Create a fresh orchestrator instance for each run
|
932 |
run_orchestrator = MaiDxOrchestrator(
|
933 |
model_name=self.model_name,
|
@@ -935,38 +1167,52 @@ class MaiDxOrchestrator:
|
|
935 |
initial_budget=self.initial_budget,
|
936 |
mode="no_budget", # Use no_budget for ensemble runs
|
937 |
physician_visit_cost=self.physician_visit_cost,
|
938 |
-
enable_budget_tracking=False
|
939 |
)
|
940 |
-
|
941 |
# Run the diagnostic session
|
942 |
-
result = run_orchestrator.run(
|
|
|
|
|
|
|
|
|
943 |
ensemble_results.append(result)
|
944 |
total_cost += result.total_cost
|
945 |
-
|
946 |
-
logger.info(
|
947 |
-
|
|
|
|
|
948 |
# Aggregate results using consensus
|
949 |
-
final_diagnosis = self._aggregate_ensemble_diagnoses(
|
950 |
-
|
|
|
|
|
951 |
# Judge the aggregated diagnosis
|
952 |
-
judgement = self._judge_diagnosis(
|
953 |
-
|
|
|
|
|
954 |
# Calculate average metrics
|
955 |
-
avg_iterations = sum(
|
956 |
-
|
|
|
|
|
957 |
# Combine conversation histories
|
958 |
combined_history = "\n\n=== ENSEMBLE RESULTS ===\n"
|
959 |
for i, result in enumerate(ensemble_results):
|
960 |
combined_history += f"\n--- Run {i+1} ---\n"
|
961 |
-
combined_history +=
|
|
|
|
|
962 |
combined_history += f"Score: {result.accuracy_score}\n"
|
963 |
combined_history += f"Cost: ${result.total_cost:,}\n"
|
964 |
combined_history += f"Iterations: {result.iterations}\n"
|
965 |
-
|
966 |
-
combined_history +=
|
967 |
combined_history += f"Final Diagnosis: {final_diagnosis}\n"
|
968 |
combined_history += f"Reasoning: {judgement['reasoning']}\n"
|
969 |
-
|
970 |
ensemble_result = DiagnosisResult(
|
971 |
final_diagnosis=final_diagnosis,
|
972 |
ground_truth=ground_truth_diagnosis,
|
@@ -974,28 +1220,36 @@ class MaiDxOrchestrator:
|
|
974 |
accuracy_reasoning=judgement["reasoning"],
|
975 |
total_cost=total_cost, # Sum of all runs
|
976 |
iterations=int(avg_iterations),
|
977 |
-
conversation_history=combined_history
|
|
|
|
|
|
|
|
|
978 |
)
|
979 |
-
|
980 |
-
logger.info(f"Ensemble completed: {final_diagnosis} (Score: {judgement['score']})")
|
981 |
return ensemble_result
|
982 |
-
|
983 |
-
def _aggregate_ensemble_diagnoses(
|
|
|
|
|
984 |
"""Aggregates multiple diagnoses from ensemble runs."""
|
985 |
# Simple majority voting or use the most confident diagnosis
|
986 |
if not diagnoses:
|
987 |
return "No diagnosis available"
|
988 |
-
|
989 |
# Remove any empty or invalid diagnoses
|
990 |
-
valid_diagnoses = [
|
991 |
-
|
|
|
|
|
|
|
|
|
992 |
if not valid_diagnoses:
|
993 |
return diagnoses[0] if diagnoses else "No valid diagnosis"
|
994 |
-
|
995 |
# If all diagnoses are the same, return that
|
996 |
if len(set(valid_diagnoses)) == 1:
|
997 |
return valid_diagnoses[0]
|
998 |
-
|
999 |
# Use an aggregator agent to select the best diagnosis
|
1000 |
try:
|
1001 |
aggregator_prompt = f"""
|
@@ -1008,32 +1262,35 @@ class MaiDxOrchestrator:
|
|
1008 |
Provide the single best diagnosis that represents the medical consensus.
|
1009 |
Consider clinical accuracy, specificity, and completeness.
|
1010 |
"""
|
1011 |
-
|
1012 |
aggregator = Agent(
|
1013 |
agent_name="Ensemble Aggregator",
|
1014 |
system_prompt=aggregator_prompt,
|
1015 |
model_name=self.model_name,
|
1016 |
max_loops=1,
|
1017 |
-
print_on=True # Enable printing for aggregator agent
|
1018 |
)
|
1019 |
-
|
1020 |
return aggregator.run(aggregator_prompt).strip()
|
1021 |
-
|
1022 |
except Exception as e:
|
1023 |
logger.error(f"Error in ensemble aggregation: {e}")
|
1024 |
# Fallback to most common diagnosis
|
1025 |
from collections import Counter
|
|
|
1026 |
return Counter(valid_diagnoses).most_common(1)[0][0]
|
1027 |
-
|
1028 |
@classmethod
|
1029 |
-
def create_variant(
|
|
|
|
|
1030 |
"""
|
1031 |
Factory method to create different MAI-DxO variants as described in the paper.
|
1032 |
-
|
1033 |
Args:
|
1034 |
variant (str): One of 'instant', 'question_only', 'budgeted', 'no_budget', 'ensemble'
|
1035 |
**kwargs: Additional parameters for the orchestrator
|
1036 |
-
|
1037 |
Returns:
|
1038 |
MaiDxOrchestrator: Configured orchestrator instance
|
1039 |
"""
|
@@ -1041,49 +1298,55 @@ class MaiDxOrchestrator:
|
|
1041 |
"instant": {
|
1042 |
"mode": "instant",
|
1043 |
"max_iterations": 1,
|
1044 |
-
"enable_budget_tracking": False
|
1045 |
},
|
1046 |
"question_only": {
|
1047 |
"mode": "question_only",
|
1048 |
"max_iterations": 10,
|
1049 |
-
"enable_budget_tracking": False
|
1050 |
},
|
1051 |
"budgeted": {
|
1052 |
"mode": "budgeted",
|
1053 |
"max_iterations": 10,
|
1054 |
"enable_budget_tracking": True,
|
1055 |
-
"initial_budget": kwargs.get("budget", 5000)
|
1056 |
},
|
1057 |
"no_budget": {
|
1058 |
"mode": "no_budget",
|
1059 |
"max_iterations": 10,
|
1060 |
-
"enable_budget_tracking": False
|
1061 |
},
|
1062 |
"ensemble": {
|
1063 |
"mode": "no_budget",
|
1064 |
"max_iterations": 10,
|
1065 |
-
"enable_budget_tracking": False
|
1066 |
-
}
|
1067 |
}
|
1068 |
-
|
1069 |
if variant not in variant_configs:
|
1070 |
-
raise ValueError(
|
1071 |
-
|
|
|
|
|
1072 |
config = variant_configs[variant]
|
1073 |
config.update(kwargs) # Allow overrides
|
1074 |
-
|
1075 |
return cls(**config)
|
1076 |
|
1077 |
|
1078 |
-
def run_mai_dxo_demo(
|
|
|
|
|
|
|
|
|
1079 |
"""
|
1080 |
Convenience function to run a quick demonstration of MAI-DxO variants.
|
1081 |
-
|
1082 |
Args:
|
1083 |
case_info (str): Initial case information. Uses default if None.
|
1084 |
case_details (str): Full case details. Uses default if None.
|
1085 |
ground_truth (str): Ground truth diagnosis. Uses default if None.
|
1086 |
-
|
1087 |
Returns:
|
1088 |
Dict[str, DiagnosisResult]: Results from different MAI-DxO variants
|
1089 |
"""
|
@@ -1093,7 +1356,7 @@ def run_mai_dxo_demo(case_info: str = None, case_details: str = None, ground_tru
|
|
1093 |
"A 29-year-old woman was admitted to the hospital because of sore throat and peritonsillar swelling "
|
1094 |
"and bleeding. Symptoms did not abate with antimicrobial therapy."
|
1095 |
)
|
1096 |
-
|
1097 |
if not case_details:
|
1098 |
case_details = """
|
1099 |
Patient: 29-year-old female.
|
@@ -1107,31 +1370,39 @@ def run_mai_dxo_demo(case_info: str = None, case_details: str = None, ground_tru
|
|
1107 |
Biopsy (FISH): No FOXO1 (13q14) rearrangements detected.
|
1108 |
Final Diagnosis from Pathology: Embryonal rhabdomyosarcoma of the pharynx.
|
1109 |
"""
|
1110 |
-
|
1111 |
if not ground_truth:
|
1112 |
ground_truth = "Embryonal rhabdomyosarcoma of the pharynx"
|
1113 |
-
|
1114 |
results = {}
|
1115 |
-
|
1116 |
# Test key variants
|
1117 |
variants = ["no_budget", "budgeted", "question_only"]
|
1118 |
-
|
1119 |
for variant in variants:
|
1120 |
try:
|
1121 |
logger.info(f"Running MAI-DxO variant: {variant}")
|
1122 |
-
|
1123 |
if variant == "budgeted":
|
1124 |
-
orchestrator = MaiDxOrchestrator.create_variant(
|
|
|
|
|
|
|
|
|
1125 |
else:
|
1126 |
-
orchestrator = MaiDxOrchestrator.create_variant(
|
1127 |
-
|
1128 |
-
|
|
|
|
|
|
|
|
|
1129 |
results[variant] = result
|
1130 |
-
|
1131 |
except Exception as e:
|
1132 |
logger.error(f"Error running variant {variant}: {e}")
|
1133 |
results[variant] = None
|
1134 |
-
|
1135 |
return results
|
1136 |
|
1137 |
|
@@ -1141,7 +1412,7 @@ if __name__ == "__main__":
|
|
1141 |
"A 29-year-old woman was admitted to the hospital because of sore throat and peritonsillar swelling "
|
1142 |
"and bleeding. Symptoms did not abate with antimicrobial therapy."
|
1143 |
)
|
1144 |
-
|
1145 |
full_case = """
|
1146 |
Patient: 29-year-old female.
|
1147 |
History: Onset of sore throat 7 weeks prior to admission. Worsening right-sided pain and swelling.
|
@@ -1155,55 +1426,65 @@ if __name__ == "__main__":
|
|
1155 |
Biopsy (FISH): No FOXO1 (13q14) rearrangements detected.
|
1156 |
Final Diagnosis from Pathology: Embryonal rhabdomyosarcoma of the pharynx.
|
1157 |
"""
|
1158 |
-
|
1159 |
ground_truth = "Embryonal rhabdomyosarcoma of the pharynx"
|
1160 |
-
|
1161 |
# --- Demonstrate Different MAI-DxO Variants ---
|
1162 |
try:
|
1163 |
-
print("\n" + "="*80)
|
1164 |
-
print(
|
1165 |
-
|
1166 |
-
|
1167 |
-
|
|
|
|
|
|
|
|
|
1168 |
# Test different variants as described in the paper
|
1169 |
variants_to_test = [
|
1170 |
-
(
|
|
|
|
|
|
|
1171 |
("budgeted", "Budget-constrained MAI-DxO ($3000 limit)"),
|
1172 |
-
(
|
|
|
|
|
|
|
1173 |
]
|
1174 |
-
|
1175 |
results = {}
|
1176 |
-
|
1177 |
for variant_name, description in variants_to_test:
|
1178 |
print(f"\n{'='*60}")
|
1179 |
print(f"Testing Variant: {variant_name.upper()}")
|
1180 |
print(f"Description: {description}")
|
1181 |
-
print(
|
1182 |
-
|
1183 |
# Create the variant
|
1184 |
if variant_name == "budgeted":
|
1185 |
orchestrator = MaiDxOrchestrator.create_variant(
|
1186 |
-
variant_name,
|
1187 |
budget=3000,
|
1188 |
model_name="gemini/gemini-2.5-flash",
|
1189 |
-
max_iterations=5
|
1190 |
)
|
1191 |
else:
|
1192 |
orchestrator = MaiDxOrchestrator.create_variant(
|
1193 |
variant_name,
|
1194 |
-
model_name="gemini/gemini-2.5-flash",
|
1195 |
-
max_iterations=5
|
1196 |
)
|
1197 |
-
|
1198 |
# Run the diagnostic process
|
1199 |
result = orchestrator.run(
|
1200 |
initial_case_info=initial_info,
|
1201 |
full_case_details=full_case,
|
1202 |
-
ground_truth_diagnosis=ground_truth
|
1203 |
)
|
1204 |
-
|
1205 |
results[variant_name] = result
|
1206 |
-
|
1207 |
# Display results
|
1208 |
print(f"\nπ Final Diagnosis: {result.final_diagnosis}")
|
1209 |
print(f"π― Ground Truth: {result.ground_truth}")
|
@@ -1212,50 +1493,72 @@ if __name__ == "__main__":
|
|
1212 |
print(f"π° Total Cost: ${result.total_cost:,}")
|
1213 |
print(f"π Iterations: {result.iterations}")
|
1214 |
print(f"β±οΈ Mode: {orchestrator.mode}")
|
1215 |
-
|
1216 |
# Demonstrate ensemble approach
|
1217 |
print(f"\n{'='*60}")
|
1218 |
print("Testing Variant: ENSEMBLE")
|
1219 |
-
print(
|
1220 |
-
|
1221 |
-
|
|
|
|
|
1222 |
ensemble_orchestrator = MaiDxOrchestrator.create_variant(
|
1223 |
"ensemble",
|
1224 |
model_name="gemini/gemini-2.5-flash",
|
1225 |
-
max_iterations=3 # Shorter iterations for ensemble
|
1226 |
)
|
1227 |
-
|
1228 |
ensemble_result = ensemble_orchestrator.run_ensemble(
|
1229 |
initial_case_info=initial_info,
|
1230 |
full_case_details=full_case,
|
1231 |
ground_truth_diagnosis=ground_truth,
|
1232 |
-
num_runs=2 # Reduced for demo
|
1233 |
)
|
1234 |
-
|
1235 |
results["ensemble"] = ensemble_result
|
1236 |
-
|
1237 |
-
print(
|
|
|
|
|
1238 |
print(f"π― Ground Truth: {ensemble_result.ground_truth}")
|
1239 |
-
print(
|
1240 |
-
|
1241 |
-
|
|
|
|
|
|
|
|
|
1242 |
# --- Summary Comparison ---
|
1243 |
print(f"\n{'='*80}")
|
1244 |
print(" RESULTS SUMMARY")
|
1245 |
-
print(
|
1246 |
-
print(
|
1247 |
-
|
1248 |
-
|
|
|
|
|
1249 |
for variant_name, result in results.items():
|
1250 |
-
match_status =
|
1251 |
-
|
1252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1253 |
print(f"\n{'='*80}")
|
1254 |
-
print(
|
1255 |
-
|
1256 |
-
|
|
|
|
|
|
|
|
|
1257 |
|
1258 |
except Exception as e:
|
1259 |
-
logger.exception(
|
|
|
|
|
1260 |
print(f"\nβ Error occurred: {e}")
|
1261 |
-
print("Please check your model configuration and API keys.")
|
|
|
1 |
"""
|
2 |
MAI Diagnostic Orchestrator (MAI-DxO)
|
3 |
|
4 |
+
This script provides a complete implementation of the "Sequential Diagnosis with Language Models"
|
5 |
+
paper, using the `swarms` framework. It simulates a virtual panel of physician-agents to perform
|
6 |
iterative medical diagnosis with cost-effectiveness optimization.
|
7 |
|
8 |
Based on the paper: "Sequential Diagnosis with Language Models"
|
|
|
20 |
# Standard MAI-DxO usage
|
21 |
orchestrator = MaiDxOrchestrator(model_name="gemini/gemini-2.5-flash")
|
22 |
result = orchestrator.run(initial_case_info, full_case_details, ground_truth)
|
23 |
+
|
24 |
# Budget-constrained variant
|
25 |
budgeted_orchestrator = MaiDxOrchestrator.create_variant("budgeted", budget=5000)
|
26 |
+
|
27 |
# Ensemble approach
|
28 |
ensemble_result = orchestrator.run_ensemble(initial_case_info, full_case_details, ground_truth)
|
29 |
"""
|
30 |
|
31 |
+
# Enable debug mode if environment variable is set
|
32 |
+
import os
|
33 |
import json
|
34 |
import sys
|
35 |
import time
|
36 |
from dataclasses import dataclass
|
37 |
from enum import Enum
|
38 |
+
from typing import Any, Dict, List, Union, Literal
|
39 |
|
40 |
from loguru import logger
|
41 |
from pydantic import BaseModel, Field
|
|
|
49 |
sys.stdout,
|
50 |
level="INFO",
|
51 |
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
|
52 |
+
colorize=True,
|
53 |
)
|
54 |
|
55 |
+
|
|
|
56 |
if os.getenv("MAIDX_DEBUG", "").lower() in ("1", "true", "yes"):
|
57 |
logger.add(
|
58 |
"logs/maidx_debug_{time:YYYY-MM-DD}.log",
|
59 |
level="DEBUG",
|
60 |
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
|
61 |
rotation="1 day",
|
62 |
+
retention="3 days",
|
63 |
+
)
|
64 |
+
logger.info(
|
65 |
+
"π Debug logging enabled - logs will be written to logs/ directory"
|
66 |
)
|
|
|
67 |
|
68 |
# File handler for persistent logging (optional - uncomment if needed)
|
69 |
# logger.add(
|
70 |
# "logs/mai_dxo_{time:YYYY-MM-DD}.log",
|
71 |
# rotation="1 day",
|
72 |
+
# retention="7 days",
|
73 |
# level="DEBUG",
|
74 |
# format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
|
75 |
# compression="zip"
|
|
|
77 |
|
78 |
# --- Data Structures and Enums ---
|
79 |
|
80 |
+
|
81 |
class AgentRole(Enum):
|
82 |
"""Enumeration of roles for the virtual physician panel."""
|
83 |
+
|
84 |
HYPOTHESIS = "Dr. Hypothesis"
|
85 |
TEST_CHOOSER = "Dr. Test-Chooser"
|
86 |
CHALLENGER = "Dr. Challenger"
|
|
|
90 |
GATEKEEPER = "Gatekeeper"
|
91 |
JUDGE = "Judge"
|
92 |
|
93 |
+
|
94 |
@dataclass
|
95 |
class DiagnosisResult:
|
96 |
"""Stores the final result of a diagnostic session."""
|
97 |
+
|
98 |
final_diagnosis: str
|
99 |
ground_truth: str
|
100 |
accuracy_score: float
|
|
|
103 |
iterations: int
|
104 |
conversation_history: str
|
105 |
|
106 |
+
|
107 |
class Action(BaseModel):
|
108 |
"""Pydantic model for a structured action decided by the consensus agent."""
|
109 |
+
|
110 |
+
action_type: Literal["ask", "test", "diagnose"] = Field(
|
111 |
+
..., description="The type of action to perform."
|
112 |
+
)
|
113 |
+
content: Union[str, List[str]] = Field(
|
114 |
+
...,
|
115 |
+
description="The content of the action (question, test name, or diagnosis).",
|
116 |
+
)
|
117 |
+
reasoning: str = Field(
|
118 |
+
..., description="The reasoning behind choosing this action."
|
119 |
+
)
|
120 |
+
|
121 |
|
122 |
# --- Main Orchestrator Class ---
|
123 |
|
124 |
+
|
125 |
class MaiDxOrchestrator:
|
126 |
"""
|
127 |
Implements the MAI Diagnostic Orchestrator (MAI-DxO) framework.
|
128 |
This class orchestrates a virtual panel of AI agents to perform sequential medical diagnosis,
|
129 |
evaluates the final diagnosis, and tracks costs.
|
130 |
"""
|
131 |
+
|
132 |
def __init__(
|
133 |
self,
|
134 |
model_name: str = "gemini/gemini-2.5-flash",
|
|
|
155 |
self.mode = mode
|
156 |
self.physician_visit_cost = physician_visit_cost
|
157 |
self.enable_budget_tracking = enable_budget_tracking
|
158 |
+
|
159 |
self.cumulative_cost = 0
|
160 |
self.differential_diagnosis = "Not yet formulated."
|
161 |
self.conversation = Conversation(
|
162 |
+
time_enabled=True, autosave=False, save_enabled=False
|
|
|
|
|
163 |
)
|
164 |
+
|
165 |
# Enhanced cost model based on the paper's methodology
|
166 |
self.test_cost_db = {
|
167 |
"default": 150,
|
|
|
194 |
}
|
195 |
|
196 |
self._init_agents()
|
197 |
+
logger.info(
|
198 |
+
f"π₯ MAI Diagnostic Orchestrator initialized successfully in '{mode}' mode with budget ${initial_budget:,}"
|
199 |
+
)
|
200 |
|
201 |
def _init_agents(self):
|
202 |
"""Initializes all required agents with their specific roles and prompts."""
|
|
|
206 |
system_prompt=self._get_prompt_for_role(role),
|
207 |
model_name=self.model_name,
|
208 |
max_loops=1,
|
209 |
+
output_type=(
|
210 |
+
"json" if role == AgentRole.CONSENSUS else "str"
|
211 |
+
),
|
212 |
print_on=True, # Enable printing for all agents to see outputs
|
213 |
+
)
|
214 |
+
for role in AgentRole
|
215 |
}
|
216 |
+
logger.info(
|
217 |
+
f"π₯ {len(self.agents)} virtual physician agents initialized and ready for consultation"
|
218 |
+
)
|
219 |
|
220 |
def _get_prompt_for_role(self, role: AgentRole) -> str:
|
221 |
"""Returns the system prompt for a given agent role."""
|
222 |
prompts = {
|
223 |
+
AgentRole.HYPOTHESIS: (
|
224 |
+
"""
|
225 |
You are Dr. Hypothesis, a specialist in maintaining differential diagnoses. Your role is critical to the diagnostic process.
|
226 |
|
227 |
CORE RESPONSIBILITIES:
|
|
|
246 |
- Evidence that contradicts or challenges each hypothesis
|
247 |
|
248 |
Remember: Your differential drives the entire diagnostic process. Be thorough, evidence-based, and adaptive.
|
249 |
+
"""
|
250 |
+
),
|
251 |
+
AgentRole.TEST_CHOOSER: (
|
252 |
+
"""
|
253 |
You are Dr. Test-Chooser, a specialist in diagnostic test selection and information theory.
|
254 |
|
255 |
CORE RESPONSIBILITIES:
|
|
|
278 |
- How results will change management decisions
|
279 |
|
280 |
Focus on tests that will most efficiently narrow the differential diagnosis.
|
281 |
+
"""
|
282 |
+
),
|
283 |
+
AgentRole.CHALLENGER: (
|
284 |
+
"""
|
285 |
You are Dr. Challenger, the critical thinking specialist and devil's advocate.
|
286 |
|
287 |
CORE RESPONSIBILITIES:
|
|
|
312 |
- Red flags or concerning patterns that need attention
|
313 |
|
314 |
Be constructively critical - your role is to strengthen diagnostic accuracy through rigorous challenge.
|
315 |
+
"""
|
316 |
+
),
|
317 |
+
AgentRole.STEWARDSHIP: (
|
318 |
+
"""
|
319 |
You are Dr. Stewardship, the resource optimization and cost-effectiveness specialist.
|
320 |
|
321 |
CORE RESPONSIBILITIES:
|
|
|
351 |
- Cumulative cost considerations
|
352 |
|
353 |
Your goal: Maximum diagnostic accuracy at minimum necessary cost.
|
354 |
+
"""
|
355 |
+
),
|
356 |
+
AgentRole.CHECKLIST: (
|
357 |
+
"""
|
358 |
You are Dr. Checklist, the quality assurance and consistency specialist.
|
359 |
|
360 |
CORE RESPONSIBILITIES:
|
|
|
385 |
- Process improvement suggestions
|
386 |
|
387 |
Keep your feedback concise but comprehensive. Flag any issues that could compromise diagnostic quality.
|
388 |
+
"""
|
389 |
+
),
|
390 |
+
AgentRole.CONSENSUS: (
|
391 |
+
"""
|
392 |
You are the Consensus Coordinator, responsible for synthesizing the virtual panel's expertise into a single, optimal decision.
|
393 |
|
394 |
CORE RESPONSIBILITIES:
|
|
|
422 |
For action_type "diagnose": content should be the complete, specific final diagnosis
|
423 |
|
424 |
Make the decision that best advances accurate, cost-effective diagnosis.
|
425 |
+
"""
|
426 |
+
),
|
427 |
+
AgentRole.GATEKEEPER: (
|
428 |
+
"""
|
429 |
You are the Gatekeeper, the clinical information oracle with complete access to the patient case file.
|
430 |
|
431 |
CORE RESPONSIBILITIES:
|
|
|
462 |
- Professional medical terminology
|
463 |
|
464 |
Your role is crucial: provide complete, accurate clinical information while maintaining the challenge of the diagnostic process.
|
465 |
+
"""
|
466 |
+
),
|
467 |
+
AgentRole.JUDGE: (
|
468 |
+
"""
|
469 |
You are the Judge, the diagnostic accuracy evaluation specialist.
|
470 |
|
471 |
CORE RESPONSIBILITIES:
|
|
|
521 |
|
522 |
Maintain high standards while recognizing legitimate diagnostic variability in medical practice.
|
523 |
"""
|
524 |
+
),
|
525 |
}
|
526 |
return prompts[role]
|
527 |
+
|
528 |
def _parse_json_response(self, response: str) -> Dict[str, Any]:
|
529 |
"""Safely parses a JSON string, returning a dictionary."""
|
530 |
try:
|
|
|
540 |
start_idx += len(start_marker)
|
541 |
end_idx = response.find(end_marker, start_idx)
|
542 |
if end_idx != -1:
|
543 |
+
json_content = response[
|
544 |
+
start_idx:end_idx
|
545 |
+
].strip()
|
546 |
return json.loads(json_content)
|
547 |
+
|
548 |
# Try to find JSON-like content in the response
|
549 |
+
lines = response.split("\n")
|
550 |
json_lines = []
|
551 |
in_json = False
|
552 |
brace_count = 0
|
553 |
+
|
554 |
for line in lines:
|
555 |
stripped_line = line.strip()
|
556 |
+
if stripped_line.startswith("{") and not in_json:
|
557 |
in_json = True
|
558 |
json_lines = [line] # Start fresh
|
559 |
+
brace_count = line.count("{") - line.count(
|
560 |
+
"}"
|
561 |
+
)
|
562 |
elif in_json:
|
563 |
json_lines.append(line)
|
564 |
+
brace_count += line.count("{") - line.count(
|
565 |
+
"}"
|
566 |
+
)
|
567 |
+
if (
|
568 |
+
brace_count <= 0
|
569 |
+
): # Balanced braces, end of JSON
|
570 |
break
|
571 |
+
|
572 |
if json_lines and in_json:
|
573 |
+
json_content = "\n".join(json_lines)
|
574 |
return json.loads(json_content)
|
575 |
+
|
576 |
# Try to extract JSON from text that might contain other content
|
577 |
import re
|
578 |
+
|
579 |
# Look for JSON pattern in the text
|
580 |
+
json_pattern = r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}"
|
581 |
+
matches = re.findall(
|
582 |
+
json_pattern, response, re.DOTALL
|
583 |
+
)
|
584 |
+
|
585 |
for match in matches:
|
586 |
try:
|
587 |
return json.loads(match)
|
588 |
except json.JSONDecodeError:
|
589 |
continue
|
590 |
+
|
591 |
# Direct parsing attempt as fallback
|
592 |
return json.loads(response)
|
593 |
+
|
594 |
+
except (
|
595 |
+
json.JSONDecodeError,
|
596 |
+
IndexError,
|
597 |
+
AttributeError,
|
598 |
+
) as e:
|
599 |
logger.error(f"Failed to parse JSON response. Error: {e}")
|
600 |
+
logger.debug(
|
601 |
+
f"Response content: {response[:500]}..."
|
602 |
+
) # Log first 500 chars
|
603 |
# Fallback to a default action if parsing fails
|
604 |
return {
|
605 |
"action_type": "ask",
|
606 |
+
"content": (
|
607 |
+
"Could you please clarify the next best step? The previous analysis was inconclusive."
|
608 |
+
),
|
609 |
+
"reasoning": "Fallback due to parsing error.",
|
610 |
}
|
611 |
+
|
612 |
def _estimate_cost(self, tests: Union[List[str], str]) -> int:
|
613 |
"""Estimates the cost of diagnostic tests."""
|
614 |
if isinstance(tests, str):
|
615 |
tests = [tests]
|
616 |
+
|
617 |
cost = 0
|
618 |
for test in tests:
|
619 |
test_lower = test.lower().strip()
|
620 |
+
|
621 |
# Enhanced cost matching with multiple strategies
|
622 |
cost_found = False
|
623 |
+
|
624 |
# Strategy 1: Exact match
|
625 |
if test_lower in self.test_cost_db:
|
626 |
cost += self.test_cost_db[test_lower]
|
627 |
cost_found = True
|
628 |
continue
|
629 |
+
|
630 |
# Strategy 2: Partial match (find best matching key)
|
631 |
best_match = None
|
632 |
best_match_length = 0
|
|
|
635 |
if len(cost_key) > best_match_length:
|
636 |
best_match = cost_key
|
637 |
best_match_length = len(cost_key)
|
638 |
+
|
639 |
if best_match:
|
640 |
cost += self.test_cost_db[best_match]
|
641 |
cost_found = True
|
642 |
continue
|
643 |
+
|
644 |
# Strategy 3: Keyword-based matching
|
645 |
+
if any(
|
646 |
+
keyword in test_lower
|
647 |
+
for keyword in ["biopsy", "tissue"]
|
648 |
+
):
|
649 |
+
cost += self.test_cost_db.get("biopsy", 800)
|
650 |
cost_found = True
|
651 |
+
elif any(
|
652 |
+
keyword in test_lower
|
653 |
+
for keyword in ["mri", "magnetic"]
|
654 |
+
):
|
655 |
+
cost += self.test_cost_db.get("mri", 1500)
|
656 |
cost_found = True
|
657 |
+
elif any(
|
658 |
+
keyword in test_lower
|
659 |
+
for keyword in ["ct", "computed tomography"]
|
660 |
+
):
|
661 |
+
cost += self.test_cost_db.get("ct scan", 1200)
|
662 |
cost_found = True
|
663 |
+
elif any(
|
664 |
+
keyword in test_lower
|
665 |
+
for keyword in ["xray", "x-ray", "radiograph"]
|
666 |
+
):
|
667 |
+
cost += self.test_cost_db.get("chest x-ray", 200)
|
668 |
cost_found = True
|
669 |
+
elif any(
|
670 |
+
keyword in test_lower
|
671 |
+
for keyword in ["blood", "serum", "plasma"]
|
672 |
+
):
|
673 |
cost += 100 # Basic blood test cost
|
674 |
cost_found = True
|
675 |
+
elif any(
|
676 |
+
keyword in test_lower
|
677 |
+
for keyword in ["culture", "sensitivity"]
|
678 |
+
):
|
679 |
+
cost += self.test_cost_db.get("culture", 150)
|
680 |
cost_found = True
|
681 |
+
elif any(
|
682 |
+
keyword in test_lower
|
683 |
+
for keyword in ["immunohistochemistry", "ihc"]
|
684 |
+
):
|
685 |
+
cost += self.test_cost_db.get(
|
686 |
+
"immunohistochemistry", 400
|
687 |
+
)
|
688 |
cost_found = True
|
689 |
+
|
690 |
# Strategy 4: Default cost for unknown tests
|
691 |
if not cost_found:
|
692 |
+
cost += self.test_cost_db["default"]
|
693 |
+
logger.debug(
|
694 |
+
f"Using default cost for unknown test: {test}"
|
695 |
+
)
|
696 |
+
|
697 |
return cost
|
698 |
|
699 |
def _run_panel_deliberation(self) -> Action:
|
700 |
"""Orchestrates one round of debate among the virtual panel to decide the next action."""
|
701 |
+
logger.info(
|
702 |
+
"π©Ί Virtual medical panel deliberation commenced - analyzing patient case"
|
703 |
+
)
|
704 |
+
logger.debug(
|
705 |
+
"Panel members: Dr. Hypothesis, Dr. Test-Chooser, Dr. Challenger, Dr. Stewardship, Dr. Checklist"
|
706 |
+
)
|
707 |
panel_conversation = Conversation(
|
708 |
+
time_enabled=True, autosave=False, save_enabled=False
|
|
|
|
|
709 |
)
|
710 |
+
|
711 |
# Prepare comprehensive panel context
|
712 |
remaining_budget = self.initial_budget - self.cumulative_cost
|
713 |
+
budget_status = (
|
714 |
+
"EXCEEDED"
|
715 |
+
if remaining_budget < 0
|
716 |
+
else f"${remaining_budget:,}"
|
717 |
+
)
|
718 |
+
|
719 |
panel_context = f"""
|
720 |
DIAGNOSTIC CASE STATUS - ROUND {len(self.conversation.return_history_as_string().split('Gatekeeper:')) - 1}
|
721 |
|
|
|
740 |
# For instant mode, skip deliberation and go straight to diagnosis
|
741 |
action_dict = {
|
742 |
"action_type": "diagnose",
|
743 |
+
"content": (
|
744 |
+
self.differential_diagnosis.split("\n")[0]
|
745 |
+
if "\n" in self.differential_diagnosis
|
746 |
+
else self.differential_diagnosis
|
747 |
+
),
|
748 |
+
"reasoning": (
|
749 |
+
"Instant diagnosis mode - providing immediate assessment based on initial presentation"
|
750 |
+
),
|
751 |
}
|
752 |
return Action(**action_dict)
|
753 |
+
|
754 |
if self.mode == "question_only":
|
755 |
# For question-only mode, prevent test ordering
|
756 |
panel_context += "\n\nIMPORTANT: This is QUESTION-ONLY mode. You may ONLY ask patient questions, not order diagnostic tests."
|
|
|
759 |
# Sequential expert deliberation with enhanced methodology
|
760 |
try:
|
761 |
# Dr. Hypothesis - Differential diagnosis and probability assessment
|
762 |
+
logger.info(
|
763 |
+
"π§ Dr. Hypothesis analyzing differential diagnosis..."
|
764 |
+
)
|
765 |
+
hypothesis = self.agents[AgentRole.HYPOTHESIS].run(
|
766 |
+
panel_conversation.get_str()
|
767 |
+
)
|
768 |
+
self.differential_diagnosis = (
|
769 |
+
hypothesis # Update main state
|
770 |
+
)
|
771 |
+
panel_conversation.add(
|
772 |
+
self.agents[AgentRole.HYPOTHESIS].agent_name,
|
773 |
+
hypothesis,
|
774 |
+
)
|
775 |
+
|
776 |
# Dr. Test-Chooser - Information value optimization
|
777 |
+
logger.info(
|
778 |
+
"π¬ Dr. Test-Chooser selecting optimal tests..."
|
779 |
+
)
|
780 |
+
test_choices = self.agents[AgentRole.TEST_CHOOSER].run(
|
781 |
+
panel_conversation.get_str()
|
782 |
+
)
|
783 |
+
panel_conversation.add(
|
784 |
+
self.agents[AgentRole.TEST_CHOOSER].agent_name,
|
785 |
+
test_choices,
|
786 |
+
)
|
787 |
|
788 |
# Dr. Challenger - Bias identification and alternative hypotheses
|
789 |
+
logger.info(
|
790 |
+
"π€ Dr. Challenger challenging assumptions..."
|
791 |
+
)
|
792 |
+
challenges = self.agents[AgentRole.CHALLENGER].run(
|
793 |
+
panel_conversation.get_str()
|
794 |
+
)
|
795 |
+
panel_conversation.add(
|
796 |
+
self.agents[AgentRole.CHALLENGER].agent_name,
|
797 |
+
challenges,
|
798 |
+
)
|
799 |
|
800 |
# Dr. Stewardship - Cost-effectiveness analysis
|
801 |
+
logger.info(
|
802 |
+
"π° Dr. Stewardship evaluating cost-effectiveness..."
|
803 |
+
)
|
804 |
stewardship_context = panel_conversation.get_str()
|
805 |
if self.enable_budget_tracking:
|
806 |
stewardship_context += f"\n\nBUDGET TRACKING ENABLED - Current cost: ${self.cumulative_cost}, Remaining: ${remaining_budget}"
|
807 |
+
stewardship_rec = self.agents[AgentRole.STEWARDSHIP].run(
|
808 |
+
stewardship_context
|
809 |
+
)
|
810 |
+
panel_conversation.add(
|
811 |
+
self.agents[AgentRole.STEWARDSHIP].agent_name,
|
812 |
+
stewardship_rec,
|
813 |
+
)
|
814 |
+
|
815 |
# Dr. Checklist - Quality assurance
|
816 |
+
logger.info(
|
817 |
+
"β
Dr. Checklist performing quality control..."
|
818 |
+
)
|
819 |
+
checklist_rep = self.agents[AgentRole.CHECKLIST].run(
|
820 |
+
panel_conversation.get_str()
|
821 |
+
)
|
822 |
+
panel_conversation.add(
|
823 |
+
self.agents[AgentRole.CHECKLIST].agent_name,
|
824 |
+
checklist_rep,
|
825 |
+
)
|
826 |
+
|
827 |
# Consensus Coordinator - Final decision synthesis
|
828 |
+
logger.info(
|
829 |
+
"π€ Consensus Coordinator synthesizing panel decision..."
|
830 |
+
)
|
831 |
consensus_context = panel_conversation.get_str()
|
832 |
+
|
833 |
# Add mode-specific constraints to consensus
|
834 |
if self.mode == "budgeted" and remaining_budget <= 0:
|
835 |
consensus_context += "\n\nBUDGET CONSTRAINT: Budget exceeded - must either ask questions or provide final diagnosis."
|
836 |
+
|
837 |
+
consensus_response = self.agents[AgentRole.CONSENSUS].run(
|
838 |
+
consensus_context
|
839 |
+
)
|
840 |
+
logger.debug(
|
841 |
+
f"Raw consensus response: {consensus_response}"
|
842 |
+
)
|
843 |
+
|
844 |
# Extract the actual text content from agent response
|
845 |
+
if hasattr(consensus_response, "content"):
|
846 |
response_text = consensus_response.content
|
847 |
elif isinstance(consensus_response, str):
|
848 |
response_text = consensus_response
|
849 |
else:
|
850 |
response_text = str(consensus_response)
|
851 |
+
|
852 |
action_dict = self._parse_json_response(response_text)
|
853 |
|
854 |
# Validate action based on mode constraints
|
855 |
action = Action(**action_dict)
|
856 |
+
if (
|
857 |
+
self.mode == "question_only"
|
858 |
+
and action.action_type == "test"
|
859 |
+
):
|
860 |
+
logger.warning(
|
861 |
+
"Test ordering attempted in question-only mode, converting to ask action"
|
862 |
+
)
|
863 |
action.action_type = "ask"
|
864 |
action.content = "Can you provide more details about the patient's symptoms and history?"
|
865 |
+
action.reasoning = (
|
866 |
+
"Mode constraint: question-only mode active"
|
867 |
+
)
|
868 |
|
869 |
+
if (
|
870 |
+
self.mode == "budgeted"
|
871 |
+
and action.action_type == "test"
|
872 |
+
and remaining_budget <= 0
|
873 |
+
):
|
874 |
+
logger.warning(
|
875 |
+
"Test ordering attempted with insufficient budget, converting to diagnose action"
|
876 |
+
)
|
877 |
action.action_type = "diagnose"
|
878 |
+
action.content = (
|
879 |
+
self.differential_diagnosis.split("\n")[0]
|
880 |
+
if "\n" in self.differential_diagnosis
|
881 |
+
else self.differential_diagnosis
|
882 |
+
)
|
883 |
action.reasoning = "Budget constraint: insufficient funds for additional testing"
|
884 |
|
885 |
return action
|
|
|
890 |
return Action(
|
891 |
action_type="ask",
|
892 |
content="Could you please provide more information about the patient's current condition?",
|
893 |
+
reasoning=f"Fallback due to panel deliberation error: {str(e)}",
|
894 |
)
|
895 |
|
896 |
+
def _interact_with_gatekeeper(
|
897 |
+
self, action: Action, full_case_details: str
|
898 |
+
) -> str:
|
899 |
"""Sends the panel's action to the Gatekeeper and returns its response."""
|
900 |
gatekeeper = self.agents[AgentRole.GATEKEEPER]
|
901 |
+
|
902 |
if action.action_type == "ask":
|
903 |
request = f"Question: {action.content}"
|
904 |
elif action.action_type == "test":
|
|
|
916 |
Request from Diagnostic Agent:
|
917 |
{request}
|
918 |
"""
|
919 |
+
|
920 |
response = gatekeeper.run(prompt)
|
921 |
return response
|
922 |
|
923 |
+
def _judge_diagnosis(
|
924 |
+
self, candidate_diagnosis: str, ground_truth: str
|
925 |
+
) -> Dict[str, Any]:
|
926 |
"""Uses the Judge agent to evaluate the final diagnosis."""
|
927 |
judge = self.agents[AgentRole.JUDGE]
|
928 |
prompt = f"""
|
|
|
931 |
Candidate Diagnosis: "{candidate_diagnosis}"
|
932 |
"""
|
933 |
response = judge.run(prompt)
|
934 |
+
|
935 |
# Simple parsing for demonstration; a more robust solution would use structured output.
|
936 |
try:
|
937 |
+
score = float(
|
938 |
+
response.split("Score:")[1].split("/")[0].strip()
|
939 |
+
)
|
940 |
reasoning = response.split("Justification:")[1].strip()
|
941 |
except (IndexError, ValueError):
|
942 |
score = 0.0
|
943 |
reasoning = "Could not parse judge's response."
|
944 |
+
|
945 |
return {"score": score, "reasoning": reasoning}
|
946 |
|
947 |
+
def run(
|
948 |
+
self,
|
949 |
+
initial_case_info: str,
|
950 |
+
full_case_details: str,
|
951 |
+
ground_truth_diagnosis: str,
|
952 |
+
) -> DiagnosisResult:
|
953 |
"""
|
954 |
Executes the full sequential diagnostic process.
|
955 |
|
|
|
962 |
DiagnosisResult: An object containing the final diagnosis, evaluation, cost, and history.
|
963 |
"""
|
964 |
start_time = time.time()
|
965 |
+
self.conversation.add(
|
966 |
+
"Gatekeeper",
|
967 |
+
f"Initial Case Information: {initial_case_info}",
|
968 |
+
)
|
969 |
+
|
970 |
# Add initial physician visit cost
|
971 |
self.cumulative_cost += self.physician_visit_cost
|
972 |
+
logger.info(
|
973 |
+
f"Initial physician visit cost: ${self.physician_visit_cost}"
|
974 |
+
)
|
975 |
+
|
976 |
final_diagnosis = None
|
977 |
iteration_count = 0
|
978 |
+
|
979 |
for i in range(self.max_iterations):
|
980 |
iteration_count = i + 1
|
981 |
+
logger.info(
|
982 |
+
f"--- Starting Diagnostic Loop {iteration_count}/{self.max_iterations} ---"
|
983 |
+
)
|
984 |
+
logger.info(
|
985 |
+
f"Current cost: ${self.cumulative_cost:,} | Remaining budget: ${self.initial_budget - self.cumulative_cost:,}"
|
986 |
+
)
|
987 |
+
|
988 |
try:
|
989 |
# Panel deliberates to decide on the next action
|
990 |
action = self._run_panel_deliberation()
|
991 |
+
logger.info(
|
992 |
+
f"βοΈ Panel decision: {action.action_type.upper()} -> {action.content}"
|
993 |
+
)
|
994 |
+
logger.info(
|
995 |
+
f"π Medical reasoning: {action.reasoning}"
|
996 |
+
)
|
997 |
+
|
998 |
if action.action_type == "diagnose":
|
999 |
final_diagnosis = action.content
|
1000 |
+
logger.info(
|
1001 |
+
f"Final diagnosis proposed: {final_diagnosis}"
|
1002 |
+
)
|
1003 |
break
|
1004 |
|
1005 |
# Handle mode-specific constraints
|
1006 |
+
if (
|
1007 |
+
self.mode == "question_only"
|
1008 |
+
and action.action_type == "test"
|
1009 |
+
):
|
1010 |
+
logger.warning(
|
1011 |
+
"Test ordering blocked in question-only mode"
|
1012 |
+
)
|
1013 |
continue
|
1014 |
+
|
1015 |
+
if (
|
1016 |
+
self.mode == "budgeted"
|
1017 |
+
and action.action_type == "test"
|
1018 |
+
):
|
1019 |
# Check if we can afford the tests
|
1020 |
+
estimated_test_cost = self._estimate_cost(
|
1021 |
+
action.content
|
1022 |
+
)
|
1023 |
+
if (
|
1024 |
+
self.cumulative_cost + estimated_test_cost
|
1025 |
+
> self.initial_budget
|
1026 |
+
):
|
1027 |
+
logger.warning(
|
1028 |
+
f"Test cost ${estimated_test_cost} would exceed budget. Skipping tests."
|
1029 |
+
)
|
1030 |
continue
|
1031 |
|
1032 |
# Interact with the Gatekeeper
|
1033 |
+
response = self._interact_with_gatekeeper(
|
1034 |
+
action, full_case_details
|
1035 |
+
)
|
1036 |
self.conversation.add("Gatekeeper", response)
|
1037 |
+
|
1038 |
# Update costs based on action type
|
1039 |
if action.action_type == "test":
|
1040 |
test_cost = self._estimate_cost(action.content)
|
1041 |
self.cumulative_cost += test_cost
|
1042 |
logger.info(f"Tests ordered: {action.content}")
|
1043 |
+
logger.info(
|
1044 |
+
f"Test cost: ${test_cost:,} | Cumulative cost: ${self.cumulative_cost:,}"
|
1045 |
+
)
|
1046 |
elif action.action_type == "ask":
|
1047 |
# Questions are part of the same visit until tests are ordered
|
1048 |
logger.info(f"Questions asked: {action.content}")
|
1049 |
+
logger.info(
|
1050 |
+
"No additional cost for questions in same visit"
|
1051 |
+
)
|
1052 |
|
1053 |
# Check budget constraints for budgeted mode
|
1054 |
+
if (
|
1055 |
+
self.mode == "budgeted"
|
1056 |
+
and self.cumulative_cost >= self.initial_budget
|
1057 |
+
):
|
1058 |
+
logger.warning(
|
1059 |
+
"Budget limit reached. Forcing final diagnosis."
|
1060 |
+
)
|
1061 |
# Use current differential diagnosis or make best guess
|
1062 |
+
final_diagnosis = (
|
1063 |
+
self.differential_diagnosis.split("\n")[0]
|
1064 |
+
if "\n" in self.differential_diagnosis
|
1065 |
+
else "Diagnosis not reached within budget constraints."
|
1066 |
+
)
|
1067 |
break
|
1068 |
+
|
1069 |
except Exception as e:
|
1070 |
+
logger.error(
|
1071 |
+
f"Error in diagnostic loop {iteration_count}: {e}"
|
1072 |
+
)
|
1073 |
# Continue to next iteration or break if critical error
|
1074 |
continue
|
1075 |
+
|
1076 |
else:
|
1077 |
# Max iterations reached without diagnosis
|
1078 |
+
final_diagnosis = (
|
1079 |
+
self.differential_diagnosis.split("\n")[0]
|
1080 |
+
if "\n" in self.differential_diagnosis
|
1081 |
+
else "Diagnosis not reached within maximum iterations."
|
1082 |
+
)
|
1083 |
+
logger.warning(
|
1084 |
+
f"Max iterations ({self.max_iterations}) reached. Using best available diagnosis."
|
1085 |
+
)
|
1086 |
|
1087 |
# Ensure we have a final diagnosis
|
1088 |
if not final_diagnosis or final_diagnosis.strip() == "":
|
1089 |
+
final_diagnosis = (
|
1090 |
+
"Unable to determine diagnosis within constraints."
|
1091 |
+
)
|
1092 |
+
|
1093 |
# Calculate total time
|
1094 |
total_time = time.time() - start_time
|
1095 |
+
logger.info(
|
1096 |
+
f"Diagnostic session completed in {total_time:.2f} seconds"
|
1097 |
+
)
|
1098 |
|
1099 |
# Judge the final diagnosis
|
1100 |
logger.info("Evaluating final diagnosis...")
|
1101 |
try:
|
1102 |
+
judgement = self._judge_diagnosis(
|
1103 |
+
final_diagnosis, ground_truth_diagnosis
|
1104 |
+
)
|
1105 |
except Exception as e:
|
1106 |
logger.error(f"Error in diagnosis evaluation: {e}")
|
1107 |
+
judgement = {
|
1108 |
+
"score": 0.0,
|
1109 |
+
"reasoning": f"Evaluation error: {str(e)}",
|
1110 |
+
}
|
1111 |
|
1112 |
# Create comprehensive result
|
1113 |
result = DiagnosisResult(
|
|
|
1117 |
accuracy_reasoning=judgement["reasoning"],
|
1118 |
total_cost=self.cumulative_cost,
|
1119 |
iterations=iteration_count,
|
1120 |
+
conversation_history=self.conversation.get_str(),
|
1121 |
)
|
1122 |
+
|
1123 |
+
logger.info("Diagnostic process completed:")
|
1124 |
logger.info(f" Final diagnosis: {final_diagnosis}")
|
1125 |
logger.info(f" Ground truth: {ground_truth_diagnosis}")
|
1126 |
logger.info(f" Accuracy score: {judgement['score']}/5.0")
|
1127 |
logger.info(f" Total cost: ${self.cumulative_cost:,}")
|
1128 |
logger.info(f" Iterations: {iteration_count}")
|
1129 |
+
|
1130 |
return result
|
1131 |
+
|
1132 |
+
def run_ensemble(
|
1133 |
+
self,
|
1134 |
+
initial_case_info: str,
|
1135 |
+
full_case_details: str,
|
1136 |
+
ground_truth_diagnosis: str,
|
1137 |
+
num_runs: int = 3,
|
1138 |
+
) -> DiagnosisResult:
|
1139 |
"""
|
1140 |
Runs multiple independent diagnostic sessions and aggregates the results.
|
1141 |
+
|
1142 |
Args:
|
1143 |
initial_case_info (str): The initial abstract of the case.
|
1144 |
full_case_details (str): The complete case file for the Gatekeeper.
|
1145 |
ground_truth_diagnosis (str): The correct final diagnosis for evaluation.
|
1146 |
num_runs (int): Number of independent runs to perform.
|
1147 |
+
|
1148 |
Returns:
|
1149 |
DiagnosisResult: Aggregated result from ensemble runs.
|
1150 |
"""
|
1151 |
+
logger.info(
|
1152 |
+
f"Starting ensemble run with {num_runs} independent sessions"
|
1153 |
+
)
|
1154 |
+
|
1155 |
ensemble_results = []
|
1156 |
total_cost = 0
|
1157 |
+
|
1158 |
for run_id in range(num_runs):
|
1159 |
+
logger.info(
|
1160 |
+
f"=== Ensemble Run {run_id + 1}/{num_runs} ==="
|
1161 |
+
)
|
1162 |
+
|
1163 |
# Create a fresh orchestrator instance for each run
|
1164 |
run_orchestrator = MaiDxOrchestrator(
|
1165 |
model_name=self.model_name,
|
|
|
1167 |
initial_budget=self.initial_budget,
|
1168 |
mode="no_budget", # Use no_budget for ensemble runs
|
1169 |
physician_visit_cost=self.physician_visit_cost,
|
1170 |
+
enable_budget_tracking=False,
|
1171 |
)
|
1172 |
+
|
1173 |
# Run the diagnostic session
|
1174 |
+
result = run_orchestrator.run(
|
1175 |
+
initial_case_info,
|
1176 |
+
full_case_details,
|
1177 |
+
ground_truth_diagnosis,
|
1178 |
+
)
|
1179 |
ensemble_results.append(result)
|
1180 |
total_cost += result.total_cost
|
1181 |
+
|
1182 |
+
logger.info(
|
1183 |
+
f"Run {run_id + 1} completed: {result.final_diagnosis} (Score: {result.accuracy_score})"
|
1184 |
+
)
|
1185 |
+
|
1186 |
# Aggregate results using consensus
|
1187 |
+
final_diagnosis = self._aggregate_ensemble_diagnoses(
|
1188 |
+
[r.final_diagnosis for r in ensemble_results]
|
1189 |
+
)
|
1190 |
+
|
1191 |
# Judge the aggregated diagnosis
|
1192 |
+
judgement = self._judge_diagnosis(
|
1193 |
+
final_diagnosis, ground_truth_diagnosis
|
1194 |
+
)
|
1195 |
+
|
1196 |
# Calculate average metrics
|
1197 |
+
avg_iterations = sum(
|
1198 |
+
r.iterations for r in ensemble_results
|
1199 |
+
) / len(ensemble_results)
|
1200 |
+
|
1201 |
# Combine conversation histories
|
1202 |
combined_history = "\n\n=== ENSEMBLE RESULTS ===\n"
|
1203 |
for i, result in enumerate(ensemble_results):
|
1204 |
combined_history += f"\n--- Run {i+1} ---\n"
|
1205 |
+
combined_history += (
|
1206 |
+
f"Diagnosis: {result.final_diagnosis}\n"
|
1207 |
+
)
|
1208 |
combined_history += f"Score: {result.accuracy_score}\n"
|
1209 |
combined_history += f"Cost: ${result.total_cost:,}\n"
|
1210 |
combined_history += f"Iterations: {result.iterations}\n"
|
1211 |
+
|
1212 |
+
combined_history += "\n--- Aggregated Result ---\n"
|
1213 |
combined_history += f"Final Diagnosis: {final_diagnosis}\n"
|
1214 |
combined_history += f"Reasoning: {judgement['reasoning']}\n"
|
1215 |
+
|
1216 |
ensemble_result = DiagnosisResult(
|
1217 |
final_diagnosis=final_diagnosis,
|
1218 |
ground_truth=ground_truth_diagnosis,
|
|
|
1220 |
accuracy_reasoning=judgement["reasoning"],
|
1221 |
total_cost=total_cost, # Sum of all runs
|
1222 |
iterations=int(avg_iterations),
|
1223 |
+
conversation_history=combined_history,
|
1224 |
+
)
|
1225 |
+
|
1226 |
+
logger.info(
|
1227 |
+
f"Ensemble completed: {final_diagnosis} (Score: {judgement['score']})"
|
1228 |
)
|
|
|
|
|
1229 |
return ensemble_result
|
1230 |
+
|
1231 |
+
def _aggregate_ensemble_diagnoses(
|
1232 |
+
self, diagnoses: List[str]
|
1233 |
+
) -> str:
|
1234 |
"""Aggregates multiple diagnoses from ensemble runs."""
|
1235 |
# Simple majority voting or use the most confident diagnosis
|
1236 |
if not diagnoses:
|
1237 |
return "No diagnosis available"
|
1238 |
+
|
1239 |
# Remove any empty or invalid diagnoses
|
1240 |
+
valid_diagnoses = [
|
1241 |
+
d
|
1242 |
+
for d in diagnoses
|
1243 |
+
if d and d.strip() and "not reached" not in d.lower()
|
1244 |
+
]
|
1245 |
+
|
1246 |
if not valid_diagnoses:
|
1247 |
return diagnoses[0] if diagnoses else "No valid diagnosis"
|
1248 |
+
|
1249 |
# If all diagnoses are the same, return that
|
1250 |
if len(set(valid_diagnoses)) == 1:
|
1251 |
return valid_diagnoses[0]
|
1252 |
+
|
1253 |
# Use an aggregator agent to select the best diagnosis
|
1254 |
try:
|
1255 |
aggregator_prompt = f"""
|
|
|
1262 |
Provide the single best diagnosis that represents the medical consensus.
|
1263 |
Consider clinical accuracy, specificity, and completeness.
|
1264 |
"""
|
1265 |
+
|
1266 |
aggregator = Agent(
|
1267 |
agent_name="Ensemble Aggregator",
|
1268 |
system_prompt=aggregator_prompt,
|
1269 |
model_name=self.model_name,
|
1270 |
max_loops=1,
|
1271 |
+
print_on=True, # Enable printing for aggregator agent
|
1272 |
)
|
1273 |
+
|
1274 |
return aggregator.run(aggregator_prompt).strip()
|
1275 |
+
|
1276 |
except Exception as e:
|
1277 |
logger.error(f"Error in ensemble aggregation: {e}")
|
1278 |
# Fallback to most common diagnosis
|
1279 |
from collections import Counter
|
1280 |
+
|
1281 |
return Counter(valid_diagnoses).most_common(1)[0][0]
|
1282 |
+
|
1283 |
@classmethod
def create_variant(
    cls, variant: str, **kwargs
) -> "MaiDxOrchestrator":
    """
    Factory method to create different MAI-DxO variants as described in the paper.

    Args:
        variant (str): One of 'instant', 'question_only', 'budgeted', 'no_budget', 'ensemble'
        **kwargs: Additional parameters for the orchestrator. They are applied
            on top of the variant preset, so an explicit keyword always wins.
            ``budget`` is a factory-level convenience alias: it seeds
            ``initial_budget`` for the 'budgeted' variant (default 5000) and is
            never forwarded to the constructor under the name ``budget``.

    Returns:
        MaiDxOrchestrator: Configured orchestrator instance

    Raises:
        ValueError: If ``variant`` is not one of the known variant names.
    """
    # Consume the alias up front. Previously it stayed in kwargs, was merged
    # into the config by update(), and reached cls(**config) as an extra
    # `budget` keyword next to `initial_budget`.
    budget = kwargs.pop("budget", 5000)

    variant_configs = {
        "instant": {
            "mode": "instant",
            "max_iterations": 1,
            "enable_budget_tracking": False,
        },
        "question_only": {
            "mode": "question_only",
            "max_iterations": 10,
            "enable_budget_tracking": False,
        },
        "budgeted": {
            "mode": "budgeted",
            "max_iterations": 10,
            "enable_budget_tracking": True,
            "initial_budget": budget,
        },
        "no_budget": {
            "mode": "no_budget",
            "max_iterations": 10,
            "enable_budget_tracking": False,
        },
        # The ensemble variant is a plain no-budget orchestrator; the
        # multi-run aggregation itself happens in run_ensemble().
        "ensemble": {
            "mode": "no_budget",
            "max_iterations": 10,
            "enable_budget_tracking": False,
        },
    }

    if variant not in variant_configs:
        raise ValueError(
            f"Unknown variant: {variant}. Choose from: {list(variant_configs.keys())}"
        )

    config = variant_configs[variant]
    config.update(kwargs)  # Allow overrides

    return cls(**config)
|
1335 |
|
1336 |
|
1337 |
+
def run_mai_dxo_demo(
|
1338 |
+
case_info: str = None,
|
1339 |
+
case_details: str = None,
|
1340 |
+
ground_truth: str = None,
|
1341 |
+
) -> Dict[str, DiagnosisResult]:
|
1342 |
"""
|
1343 |
Convenience function to run a quick demonstration of MAI-DxO variants.
|
1344 |
+
|
1345 |
Args:
|
1346 |
case_info (str): Initial case information. Uses default if None.
|
1347 |
case_details (str): Full case details. Uses default if None.
|
1348 |
ground_truth (str): Ground truth diagnosis. Uses default if None.
|
1349 |
+
|
1350 |
Returns:
|
1351 |
Dict[str, DiagnosisResult]: Results from different MAI-DxO variants
|
1352 |
"""
|
|
|
1356 |
"A 29-year-old woman was admitted to the hospital because of sore throat and peritonsillar swelling "
|
1357 |
"and bleeding. Symptoms did not abate with antimicrobial therapy."
|
1358 |
)
|
1359 |
+
|
1360 |
if not case_details:
|
1361 |
case_details = """
|
1362 |
Patient: 29-year-old female.
|
|
|
1370 |
Biopsy (FISH): No FOXO1 (13q14) rearrangements detected.
|
1371 |
Final Diagnosis from Pathology: Embryonal rhabdomyosarcoma of the pharynx.
|
1372 |
"""
|
1373 |
+
|
1374 |
if not ground_truth:
|
1375 |
ground_truth = "Embryonal rhabdomyosarcoma of the pharynx"
|
1376 |
+
|
1377 |
results = {}
|
1378 |
+
|
1379 |
# Test key variants
|
1380 |
variants = ["no_budget", "budgeted", "question_only"]
|
1381 |
+
|
1382 |
for variant in variants:
|
1383 |
try:
|
1384 |
logger.info(f"Running MAI-DxO variant: {variant}")
|
1385 |
+
|
1386 |
if variant == "budgeted":
|
1387 |
+
orchestrator = MaiDxOrchestrator.create_variant(
|
1388 |
+
variant,
|
1389 |
+
budget=3000,
|
1390 |
+
model_name="gemini/gemini-2.5-flash",
|
1391 |
+
)
|
1392 |
else:
|
1393 |
+
orchestrator = MaiDxOrchestrator.create_variant(
|
1394 |
+
variant, model_name="gemini/gemini-2.5-flash"
|
1395 |
+
)
|
1396 |
+
|
1397 |
+
result = orchestrator.run(
|
1398 |
+
case_info, case_details, ground_truth
|
1399 |
+
)
|
1400 |
results[variant] = result
|
1401 |
+
|
1402 |
except Exception as e:
|
1403 |
logger.error(f"Error running variant {variant}: {e}")
|
1404 |
results[variant] = None
|
1405 |
+
|
1406 |
return results
|
1407 |
|
1408 |
|
|
|
1412 |
"A 29-year-old woman was admitted to the hospital because of sore throat and peritonsillar swelling "
|
1413 |
"and bleeding. Symptoms did not abate with antimicrobial therapy."
|
1414 |
)
|
1415 |
+
|
1416 |
full_case = """
|
1417 |
Patient: 29-year-old female.
|
1418 |
History: Onset of sore throat 7 weeks prior to admission. Worsening right-sided pain and swelling.
|
|
|
1426 |
Biopsy (FISH): No FOXO1 (13q14) rearrangements detected.
|
1427 |
Final Diagnosis from Pathology: Embryonal rhabdomyosarcoma of the pharynx.
|
1428 |
"""
|
1429 |
+
|
1430 |
ground_truth = "Embryonal rhabdomyosarcoma of the pharynx"
|
1431 |
+
|
1432 |
# --- Demonstrate Different MAI-DxO Variants ---
|
1433 |
try:
|
1434 |
+
print("\n" + "=" * 80)
|
1435 |
+
print(
|
1436 |
+
" MAI DIAGNOSTIC ORCHESTRATOR (MAI-DxO) - SEQUENTIAL DIAGNOSIS BENCHMARK"
|
1437 |
+
)
|
1438 |
+
print(
|
1439 |
+
" Implementation based on the NEJM Research Paper"
|
1440 |
+
)
|
1441 |
+
print("=" * 80)
|
1442 |
+
|
1443 |
# Test different variants as described in the paper
|
1444 |
variants_to_test = [
|
1445 |
+
(
|
1446 |
+
"no_budget",
|
1447 |
+
"Standard MAI-DxO with no budget constraints",
|
1448 |
+
),
|
1449 |
("budgeted", "Budget-constrained MAI-DxO ($3000 limit)"),
|
1450 |
+
(
|
1451 |
+
"question_only",
|
1452 |
+
"Question-only variant (no diagnostic tests)",
|
1453 |
+
),
|
1454 |
]
|
1455 |
+
|
1456 |
results = {}
|
1457 |
+
|
1458 |
for variant_name, description in variants_to_test:
|
1459 |
print(f"\n{'='*60}")
|
1460 |
print(f"Testing Variant: {variant_name.upper()}")
|
1461 |
print(f"Description: {description}")
|
1462 |
+
print("=" * 60)
|
1463 |
+
|
1464 |
# Create the variant
|
1465 |
if variant_name == "budgeted":
|
1466 |
orchestrator = MaiDxOrchestrator.create_variant(
|
1467 |
+
variant_name,
|
1468 |
budget=3000,
|
1469 |
model_name="gemini/gemini-2.5-flash",
|
1470 |
+
max_iterations=5,
|
1471 |
)
|
1472 |
else:
|
1473 |
orchestrator = MaiDxOrchestrator.create_variant(
|
1474 |
variant_name,
|
1475 |
+
model_name="gemini/gemini-2.5-flash",
|
1476 |
+
max_iterations=5,
|
1477 |
)
|
1478 |
+
|
1479 |
# Run the diagnostic process
|
1480 |
result = orchestrator.run(
|
1481 |
initial_case_info=initial_info,
|
1482 |
full_case_details=full_case,
|
1483 |
+
ground_truth_diagnosis=ground_truth,
|
1484 |
)
|
1485 |
+
|
1486 |
results[variant_name] = result
|
1487 |
+
|
1488 |
# Display results
|
1489 |
print(f"\nπ Final Diagnosis: {result.final_diagnosis}")
|
1490 |
print(f"π― Ground Truth: {result.ground_truth}")
|
|
|
1493 |
print(f"π° Total Cost: ${result.total_cost:,}")
|
1494 |
print(f"π Iterations: {result.iterations}")
|
1495 |
print(f"β±οΈ Mode: {orchestrator.mode}")
|
1496 |
+
|
1497 |
# Demonstrate ensemble approach
|
1498 |
print(f"\n{'='*60}")
|
1499 |
print("Testing Variant: ENSEMBLE")
|
1500 |
+
print(
|
1501 |
+
"Description: Multiple independent runs with consensus aggregation"
|
1502 |
+
)
|
1503 |
+
print("=" * 60)
|
1504 |
+
|
1505 |
ensemble_orchestrator = MaiDxOrchestrator.create_variant(
|
1506 |
"ensemble",
|
1507 |
model_name="gemini/gemini-2.5-flash",
|
1508 |
+
max_iterations=3, # Shorter iterations for ensemble
|
1509 |
)
|
1510 |
+
|
1511 |
ensemble_result = ensemble_orchestrator.run_ensemble(
|
1512 |
initial_case_info=initial_info,
|
1513 |
full_case_details=full_case,
|
1514 |
ground_truth_diagnosis=ground_truth,
|
1515 |
+
num_runs=2, # Reduced for demo
|
1516 |
)
|
1517 |
+
|
1518 |
results["ensemble"] = ensemble_result
|
1519 |
+
|
1520 |
+
print(
|
1521 |
+
f"\nπ Ensemble Diagnosis: {ensemble_result.final_diagnosis}"
|
1522 |
+
)
|
1523 |
print(f"π― Ground Truth: {ensemble_result.ground_truth}")
|
1524 |
+
print(
|
1525 |
+
f"β Ensemble Score: {ensemble_result.accuracy_score}/5.0"
|
1526 |
+
)
|
1527 |
+
print(
|
1528 |
+
f"π° Total Ensemble Cost: ${ensemble_result.total_cost:,}"
|
1529 |
+
)
|
1530 |
+
|
1531 |
# --- Summary Comparison ---
|
1532 |
print(f"\n{'='*80}")
|
1533 |
print(" RESULTS SUMMARY")
|
1534 |
+
print("=" * 80)
|
1535 |
+
print(
|
1536 |
+
f"{'Variant':<15} {'Diagnosis Match':<15} {'Score':<8} {'Cost':<12} {'Iterations':<12}"
|
1537 |
+
)
|
1538 |
+
print("-" * 80)
|
1539 |
+
|
1540 |
for variant_name, result in results.items():
|
1541 |
+
match_status = (
|
1542 |
+
"β Match"
|
1543 |
+
if result.accuracy_score >= 4.0
|
1544 |
+
else "β No Match"
|
1545 |
+
)
|
1546 |
+
print(
|
1547 |
+
f"{variant_name:<15} {match_status:<15} {result.accuracy_score:<8.1f} ${result.total_cost:<11,} {result.iterations:<12}"
|
1548 |
+
)
|
1549 |
+
|
1550 |
print(f"\n{'='*80}")
|
1551 |
+
print(
|
1552 |
+
"Implementation successfully demonstrates the MAI-DxO framework"
|
1553 |
+
)
|
1554 |
+
print(
|
1555 |
+
"as described in 'Sequential Diagnosis with Language Models' paper"
|
1556 |
+
)
|
1557 |
+
print("=" * 80)
|
1558 |
|
1559 |
except Exception as e:
|
1560 |
+
logger.exception(
|
1561 |
+
f"An error occurred during the diagnostic session: {e}"
|
1562 |
+
)
|
1563 |
print(f"\nβ Error occurred: {e}")
|
1564 |
+
print("Please check your model configuration and API keys.")
|
requirements.txt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
|
2 |
-
zetascale
|
3 |
swarms
|
|
|
|
1 |
+
loguru
|
|
|
2 |
swarms
|
3 |
+
pydantic
|