Kye Gomez commited on
Commit
69f6ed5
Β·
1 Parent(s): 5239f6e

__init__ and requirements.txr

Browse files
Files changed (3) hide show
  1. mai_dx/__init__.py +3 -0
  2. mai_dx/main.py +601 -298
  3. requirements.txt +2 -2
mai_dx/__init__.py CHANGED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from mai_dx.main import MaiDxOrchestrator, run_mai_dxo_demo
2
+
3
+ __all__ = ["MaiDxOrchestrator", "run_mai_dxo_demo"]
mai_dx/main.py CHANGED
@@ -1,8 +1,8 @@
1
  """
2
  MAI Diagnostic Orchestrator (MAI-DxO)
3
 
4
- This script provides a complete implementation of the "Sequential Diagnosis with Language Models"
5
- paper, using the `swarms` framework. It simulates a virtual panel of physician-agents to perform
6
  iterative medical diagnosis with cost-effectiveness optimization.
7
 
8
  Based on the paper: "Sequential Diagnosis with Language Models"
@@ -20,20 +20,22 @@ Example Usage:
20
  # Standard MAI-DxO usage
21
  orchestrator = MaiDxOrchestrator(model_name="gemini/gemini-2.5-flash")
22
  result = orchestrator.run(initial_case_info, full_case_details, ground_truth)
23
-
24
  # Budget-constrained variant
25
  budgeted_orchestrator = MaiDxOrchestrator.create_variant("budgeted", budget=5000)
26
-
27
  # Ensemble approach
28
  ensemble_result = orchestrator.run_ensemble(initial_case_info, full_case_details, ground_truth)
29
  """
30
 
 
 
31
  import json
32
  import sys
33
  import time
34
  from dataclasses import dataclass
35
  from enum import Enum
36
- from typing import Any, Dict, List, Optional, Union, Literal
37
 
38
  from loguru import logger
39
  from pydantic import BaseModel, Field
@@ -47,26 +49,27 @@ logger.add(
47
  sys.stdout,
48
  level="INFO",
49
  format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
50
- colorize=True
51
  )
52
 
53
- # Enable debug mode if environment variable is set
54
- import os
55
  if os.getenv("MAIDX_DEBUG", "").lower() in ("1", "true", "yes"):
56
  logger.add(
57
  "logs/maidx_debug_{time:YYYY-MM-DD}.log",
58
  level="DEBUG",
59
  format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
60
  rotation="1 day",
61
- retention="3 days"
 
 
 
62
  )
63
- logger.info("πŸ› Debug logging enabled - logs will be written to logs/ directory")
64
 
65
  # File handler for persistent logging (optional - uncomment if needed)
66
  # logger.add(
67
  # "logs/mai_dxo_{time:YYYY-MM-DD}.log",
68
  # rotation="1 day",
69
- # retention="7 days",
70
  # level="DEBUG",
71
  # format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
72
  # compression="zip"
@@ -74,8 +77,10 @@ if os.getenv("MAIDX_DEBUG", "").lower() in ("1", "true", "yes"):
74
 
75
  # --- Data Structures and Enums ---
76
 
 
77
  class AgentRole(Enum):
78
  """Enumeration of roles for the virtual physician panel."""
 
79
  HYPOTHESIS = "Dr. Hypothesis"
80
  TEST_CHOOSER = "Dr. Test-Chooser"
81
  CHALLENGER = "Dr. Challenger"
@@ -85,9 +90,11 @@ class AgentRole(Enum):
85
  GATEKEEPER = "Gatekeeper"
86
  JUDGE = "Judge"
87
 
 
88
  @dataclass
89
  class DiagnosisResult:
90
  """Stores the final result of a diagnostic session."""
 
91
  final_diagnosis: str
92
  ground_truth: str
93
  accuracy_score: float
@@ -96,20 +103,32 @@ class DiagnosisResult:
96
  iterations: int
97
  conversation_history: str
98
 
 
99
  class Action(BaseModel):
100
  """Pydantic model for a structured action decided by the consensus agent."""
101
- action_type: Literal["ask", "test", "diagnose"] = Field(..., description="The type of action to perform.")
102
- content: Union[str, List[str]] = Field(..., description="The content of the action (question, test name, or diagnosis).")
103
- reasoning: str = Field(..., description="The reasoning behind choosing this action.")
 
 
 
 
 
 
 
 
 
104
 
105
  # --- Main Orchestrator Class ---
106
 
 
107
  class MaiDxOrchestrator:
108
  """
109
  Implements the MAI Diagnostic Orchestrator (MAI-DxO) framework.
110
  This class orchestrates a virtual panel of AI agents to perform sequential medical diagnosis,
111
  evaluates the final diagnosis, and tracks costs.
112
  """
 
113
  def __init__(
114
  self,
115
  model_name: str = "gemini/gemini-2.5-flash",
@@ -136,15 +155,13 @@ class MaiDxOrchestrator:
136
  self.mode = mode
137
  self.physician_visit_cost = physician_visit_cost
138
  self.enable_budget_tracking = enable_budget_tracking
139
-
140
  self.cumulative_cost = 0
141
  self.differential_diagnosis = "Not yet formulated."
142
  self.conversation = Conversation(
143
- time_enabled=True,
144
- autosave=False,
145
- save_enabled=False
146
  )
147
-
148
  # Enhanced cost model based on the paper's methodology
149
  self.test_cost_db = {
150
  "default": 150,
@@ -177,7 +194,9 @@ class MaiDxOrchestrator:
177
  }
178
 
179
  self._init_agents()
180
- logger.info(f"πŸ₯ MAI Diagnostic Orchestrator initialized successfully in '{mode}' mode with budget ${initial_budget:,}")
 
 
181
 
182
  def _init_agents(self):
183
  """Initializes all required agents with their specific roles and prompts."""
@@ -187,16 +206,22 @@ class MaiDxOrchestrator:
187
  system_prompt=self._get_prompt_for_role(role),
188
  model_name=self.model_name,
189
  max_loops=1,
190
- output_type="json" if role == AgentRole.CONSENSUS else "str",
 
 
191
  print_on=True, # Enable printing for all agents to see outputs
192
- ) for role in AgentRole
 
193
  }
194
- logger.info(f"πŸ‘₯ {len(self.agents)} virtual physician agents initialized and ready for consultation")
 
 
195
 
196
  def _get_prompt_for_role(self, role: AgentRole) -> str:
197
  """Returns the system prompt for a given agent role."""
198
  prompts = {
199
- AgentRole.HYPOTHESIS: """
 
200
  You are Dr. Hypothesis, a specialist in maintaining differential diagnoses. Your role is critical to the diagnostic process.
201
 
202
  CORE RESPONSIBILITIES:
@@ -221,9 +246,10 @@ class MaiDxOrchestrator:
221
  - Evidence that contradicts or challenges each hypothesis
222
 
223
  Remember: Your differential drives the entire diagnostic process. Be thorough, evidence-based, and adaptive.
224
- """,
225
-
226
- AgentRole.TEST_CHOOSER: """
 
227
  You are Dr. Test-Chooser, a specialist in diagnostic test selection and information theory.
228
 
229
  CORE RESPONSIBILITIES:
@@ -252,9 +278,10 @@ class MaiDxOrchestrator:
252
  - How results will change management decisions
253
 
254
  Focus on tests that will most efficiently narrow the differential diagnosis.
255
- """,
256
-
257
- AgentRole.CHALLENGER: """
 
258
  You are Dr. Challenger, the critical thinking specialist and devil's advocate.
259
 
260
  CORE RESPONSIBILITIES:
@@ -285,9 +312,10 @@ class MaiDxOrchestrator:
285
  - Red flags or concerning patterns that need attention
286
 
287
  Be constructively critical - your role is to strengthen diagnostic accuracy through rigorous challenge.
288
- """,
289
-
290
- AgentRole.STEWARDSHIP: """
 
291
  You are Dr. Stewardship, the resource optimization and cost-effectiveness specialist.
292
 
293
  CORE RESPONSIBILITIES:
@@ -323,9 +351,10 @@ class MaiDxOrchestrator:
323
  - Cumulative cost considerations
324
 
325
  Your goal: Maximum diagnostic accuracy at minimum necessary cost.
326
- """,
327
-
328
- AgentRole.CHECKLIST: """
 
329
  You are Dr. Checklist, the quality assurance and consistency specialist.
330
 
331
  CORE RESPONSIBILITIES:
@@ -356,9 +385,10 @@ class MaiDxOrchestrator:
356
  - Process improvement suggestions
357
 
358
  Keep your feedback concise but comprehensive. Flag any issues that could compromise diagnostic quality.
359
- """,
360
-
361
- AgentRole.CONSENSUS: """
 
362
  You are the Consensus Coordinator, responsible for synthesizing the virtual panel's expertise into a single, optimal decision.
363
 
364
  CORE RESPONSIBILITIES:
@@ -392,9 +422,10 @@ class MaiDxOrchestrator:
392
  For action_type "diagnose": content should be the complete, specific final diagnosis
393
 
394
  Make the decision that best advances accurate, cost-effective diagnosis.
395
- """,
396
-
397
- AgentRole.GATEKEEPER: """
 
398
  You are the Gatekeeper, the clinical information oracle with complete access to the patient case file.
399
 
400
  CORE RESPONSIBILITIES:
@@ -431,9 +462,10 @@ class MaiDxOrchestrator:
431
  - Professional medical terminology
432
 
433
  Your role is crucial: provide complete, accurate clinical information while maintaining the challenge of the diagnostic process.
434
- """,
435
-
436
- AgentRole.JUDGE: """
 
437
  You are the Judge, the diagnostic accuracy evaluation specialist.
438
 
439
  CORE RESPONSIBILITIES:
@@ -489,9 +521,10 @@ class MaiDxOrchestrator:
489
 
490
  Maintain high standards while recognizing legitimate diagnostic variability in medical practice.
491
  """
 
492
  }
493
  return prompts[role]
494
-
495
  def _parse_json_response(self, response: str) -> Dict[str, Any]:
496
  """Safely parses a JSON string, returning a dictionary."""
497
  try:
@@ -507,74 +540,93 @@ class MaiDxOrchestrator:
507
  start_idx += len(start_marker)
508
  end_idx = response.find(end_marker, start_idx)
509
  if end_idx != -1:
510
- json_content = response[start_idx:end_idx].strip()
 
 
511
  return json.loads(json_content)
512
-
513
  # Try to find JSON-like content in the response
514
- lines = response.split('\n')
515
  json_lines = []
516
  in_json = False
517
  brace_count = 0
518
-
519
  for line in lines:
520
  stripped_line = line.strip()
521
- if stripped_line.startswith('{') and not in_json:
522
  in_json = True
523
  json_lines = [line] # Start fresh
524
- brace_count = line.count('{') - line.count('}')
 
 
525
  elif in_json:
526
  json_lines.append(line)
527
- brace_count += line.count('{') - line.count('}')
528
- if brace_count <= 0: # Balanced braces, end of JSON
 
 
 
 
529
  break
530
-
531
  if json_lines and in_json:
532
- json_content = '\n'.join(json_lines)
533
  return json.loads(json_content)
534
-
535
  # Try to extract JSON from text that might contain other content
536
  import re
 
537
  # Look for JSON pattern in the text
538
- json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
539
- matches = re.findall(json_pattern, response, re.DOTALL)
540
-
 
 
541
  for match in matches:
542
  try:
543
  return json.loads(match)
544
  except json.JSONDecodeError:
545
  continue
546
-
547
  # Direct parsing attempt as fallback
548
  return json.loads(response)
549
-
550
- except (json.JSONDecodeError, IndexError, AttributeError) as e:
 
 
 
 
551
  logger.error(f"Failed to parse JSON response. Error: {e}")
552
- logger.debug(f"Response content: {response[:500]}...") # Log first 500 chars
 
 
553
  # Fallback to a default action if parsing fails
554
  return {
555
  "action_type": "ask",
556
- "content": "Could you please clarify the next best step? The previous analysis was inconclusive.",
557
- "reasoning": "Fallback due to parsing error."
 
 
558
  }
559
-
560
  def _estimate_cost(self, tests: Union[List[str], str]) -> int:
561
  """Estimates the cost of diagnostic tests."""
562
  if isinstance(tests, str):
563
  tests = [tests]
564
-
565
  cost = 0
566
  for test in tests:
567
  test_lower = test.lower().strip()
568
-
569
  # Enhanced cost matching with multiple strategies
570
  cost_found = False
571
-
572
  # Strategy 1: Exact match
573
  if test_lower in self.test_cost_db:
574
  cost += self.test_cost_db[test_lower]
575
  cost_found = True
576
  continue
577
-
578
  # Strategy 2: Partial match (find best matching key)
579
  best_match = None
580
  best_match_length = 0
@@ -583,56 +635,87 @@ class MaiDxOrchestrator:
583
  if len(cost_key) > best_match_length:
584
  best_match = cost_key
585
  best_match_length = len(cost_key)
586
-
587
  if best_match:
588
  cost += self.test_cost_db[best_match]
589
  cost_found = True
590
  continue
591
-
592
  # Strategy 3: Keyword-based matching
593
- if any(keyword in test_lower for keyword in ['biopsy', 'tissue']):
594
- cost += self.test_cost_db.get('biopsy', 800)
 
 
 
595
  cost_found = True
596
- elif any(keyword in test_lower for keyword in ['mri', 'magnetic']):
597
- cost += self.test_cost_db.get('mri', 1500)
 
 
 
598
  cost_found = True
599
- elif any(keyword in test_lower for keyword in ['ct', 'computed tomography']):
600
- cost += self.test_cost_db.get('ct scan', 1200)
 
 
 
601
  cost_found = True
602
- elif any(keyword in test_lower for keyword in ['xray', 'x-ray', 'radiograph']):
603
- cost += self.test_cost_db.get('chest x-ray', 200)
 
 
 
604
  cost_found = True
605
- elif any(keyword in test_lower for keyword in ['blood', 'serum', 'plasma']):
 
 
 
606
  cost += 100 # Basic blood test cost
607
  cost_found = True
608
- elif any(keyword in test_lower for keyword in ['culture', 'sensitivity']):
609
- cost += self.test_cost_db.get('culture', 150)
 
 
 
610
  cost_found = True
611
- elif any(keyword in test_lower for keyword in ['immunohistochemistry', 'ihc']):
612
- cost += self.test_cost_db.get('immunohistochemistry', 400)
 
 
 
 
 
613
  cost_found = True
614
-
615
  # Strategy 4: Default cost for unknown tests
616
  if not cost_found:
617
- cost += self.test_cost_db['default']
618
- logger.debug(f"Using default cost for unknown test: {test}")
619
-
 
 
620
  return cost
621
 
622
  def _run_panel_deliberation(self) -> Action:
623
  """Orchestrates one round of debate among the virtual panel to decide the next action."""
624
- logger.info("🩺 Virtual medical panel deliberation commenced - analyzing patient case")
625
- logger.debug("Panel members: Dr. Hypothesis, Dr. Test-Chooser, Dr. Challenger, Dr. Stewardship, Dr. Checklist")
 
 
 
 
626
  panel_conversation = Conversation(
627
- time_enabled=True,
628
- autosave=False,
629
- save_enabled=False
630
  )
631
-
632
  # Prepare comprehensive panel context
633
  remaining_budget = self.initial_budget - self.cumulative_cost
634
- budget_status = "EXCEEDED" if remaining_budget < 0 else f"${remaining_budget:,}"
635
-
 
 
 
 
636
  panel_context = f"""
637
  DIAGNOSTIC CASE STATUS - ROUND {len(self.conversation.return_history_as_string().split('Gatekeeper:')) - 1}
638
 
@@ -657,11 +740,17 @@ class MaiDxOrchestrator:
657
  # For instant mode, skip deliberation and go straight to diagnosis
658
  action_dict = {
659
  "action_type": "diagnose",
660
- "content": self.differential_diagnosis.split('\n')[0] if '\n' in self.differential_diagnosis else self.differential_diagnosis,
661
- "reasoning": "Instant diagnosis mode - providing immediate assessment based on initial presentation"
 
 
 
 
 
 
662
  }
663
  return Action(**action_dict)
664
-
665
  if self.mode == "question_only":
666
  # For question-only mode, prevent test ordering
667
  panel_context += "\n\nIMPORTANT: This is QUESTION-ONLY mode. You may ONLY ask patient questions, not order diagnostic tests."
@@ -670,67 +759,127 @@ class MaiDxOrchestrator:
670
  # Sequential expert deliberation with enhanced methodology
671
  try:
672
  # Dr. Hypothesis - Differential diagnosis and probability assessment
673
- logger.info("🧠 Dr. Hypothesis analyzing differential diagnosis...")
674
- hypothesis = self.agents[AgentRole.HYPOTHESIS].run(panel_conversation.get_str())
675
- self.differential_diagnosis = hypothesis # Update main state
676
- panel_conversation.add(self.agents[AgentRole.HYPOTHESIS].agent_name, hypothesis)
677
-
 
 
 
 
 
 
 
 
 
678
  # Dr. Test-Chooser - Information value optimization
679
- logger.info("πŸ”¬ Dr. Test-Chooser selecting optimal tests...")
680
- test_choices = self.agents[AgentRole.TEST_CHOOSER].run(panel_conversation.get_str())
681
- panel_conversation.add(self.agents[AgentRole.TEST_CHOOSER].agent_name, test_choices)
 
 
 
 
 
 
 
682
 
683
  # Dr. Challenger - Bias identification and alternative hypotheses
684
- logger.info("πŸ€” Dr. Challenger challenging assumptions...")
685
- challenges = self.agents[AgentRole.CHALLENGER].run(panel_conversation.get_str())
686
- panel_conversation.add(self.agents[AgentRole.CHALLENGER].agent_name, challenges)
 
 
 
 
 
 
 
687
 
688
  # Dr. Stewardship - Cost-effectiveness analysis
689
- logger.info("πŸ’° Dr. Stewardship evaluating cost-effectiveness...")
 
 
690
  stewardship_context = panel_conversation.get_str()
691
  if self.enable_budget_tracking:
692
  stewardship_context += f"\n\nBUDGET TRACKING ENABLED - Current cost: ${self.cumulative_cost}, Remaining: ${remaining_budget}"
693
- stewardship_rec = self.agents[AgentRole.STEWARDSHIP].run(stewardship_context)
694
- panel_conversation.add(self.agents[AgentRole.STEWARDSHIP].agent_name, stewardship_rec)
695
-
 
 
 
 
 
696
  # Dr. Checklist - Quality assurance
697
- logger.info("βœ… Dr. Checklist performing quality control...")
698
- checklist_rep = self.agents[AgentRole.CHECKLIST].run(panel_conversation.get_str())
699
- panel_conversation.add(self.agents[AgentRole.CHECKLIST].agent_name, checklist_rep)
700
-
 
 
 
 
 
 
 
701
  # Consensus Coordinator - Final decision synthesis
702
- logger.info("🀝 Consensus Coordinator synthesizing panel decision...")
 
 
703
  consensus_context = panel_conversation.get_str()
704
-
705
  # Add mode-specific constraints to consensus
706
  if self.mode == "budgeted" and remaining_budget <= 0:
707
  consensus_context += "\n\nBUDGET CONSTRAINT: Budget exceeded - must either ask questions or provide final diagnosis."
708
-
709
- consensus_response = self.agents[AgentRole.CONSENSUS].run(consensus_context)
710
- logger.debug(f"Raw consensus response: {consensus_response}")
711
-
 
 
 
 
712
  # Extract the actual text content from agent response
713
- if hasattr(consensus_response, 'content'):
714
  response_text = consensus_response.content
715
  elif isinstance(consensus_response, str):
716
  response_text = consensus_response
717
  else:
718
  response_text = str(consensus_response)
719
-
720
  action_dict = self._parse_json_response(response_text)
721
 
722
  # Validate action based on mode constraints
723
  action = Action(**action_dict)
724
- if self.mode == "question_only" and action.action_type == "test":
725
- logger.warning("Test ordering attempted in question-only mode, converting to ask action")
 
 
 
 
 
726
  action.action_type = "ask"
727
  action.content = "Can you provide more details about the patient's symptoms and history?"
728
- action.reasoning = "Mode constraint: question-only mode active"
 
 
729
 
730
- if self.mode == "budgeted" and action.action_type == "test" and remaining_budget <= 0:
731
- logger.warning("Test ordering attempted with insufficient budget, converting to diagnose action")
 
 
 
 
 
 
732
  action.action_type = "diagnose"
733
- action.content = self.differential_diagnosis.split('\n')[0] if '\n' in self.differential_diagnosis else self.differential_diagnosis
 
 
 
 
734
  action.reasoning = "Budget constraint: insufficient funds for additional testing"
735
 
736
  return action
@@ -741,13 +890,15 @@ class MaiDxOrchestrator:
741
  return Action(
742
  action_type="ask",
743
  content="Could you please provide more information about the patient's current condition?",
744
- reasoning=f"Fallback due to panel deliberation error: {str(e)}"
745
  )
746
 
747
- def _interact_with_gatekeeper(self, action: Action, full_case_details: str) -> str:
 
 
748
  """Sends the panel's action to the Gatekeeper and returns its response."""
749
  gatekeeper = self.agents[AgentRole.GATEKEEPER]
750
-
751
  if action.action_type == "ask":
752
  request = f"Question: {action.content}"
753
  elif action.action_type == "test":
@@ -765,11 +916,13 @@ class MaiDxOrchestrator:
765
  Request from Diagnostic Agent:
766
  {request}
767
  """
768
-
769
  response = gatekeeper.run(prompt)
770
  return response
771
 
772
- def _judge_diagnosis(self, candidate_diagnosis: str, ground_truth: str) -> Dict[str, Any]:
 
 
773
  """Uses the Judge agent to evaluate the final diagnosis."""
774
  judge = self.agents[AgentRole.JUDGE]
775
  prompt = f"""
@@ -778,18 +931,25 @@ class MaiDxOrchestrator:
778
  Candidate Diagnosis: "{candidate_diagnosis}"
779
  """
780
  response = judge.run(prompt)
781
-
782
  # Simple parsing for demonstration; a more robust solution would use structured output.
783
  try:
784
- score = float(response.split("Score:")[1].split("/")[0].strip())
 
 
785
  reasoning = response.split("Justification:")[1].strip()
786
  except (IndexError, ValueError):
787
  score = 0.0
788
  reasoning = "Could not parse judge's response."
789
-
790
  return {"score": score, "reasoning": reasoning}
791
 
792
- def run(self, initial_case_info: str, full_case_details: str, ground_truth_diagnosis: str) -> DiagnosisResult:
 
 
 
 
 
793
  """
794
  Executes the full sequential diagnostic process.
795
 
@@ -802,90 +962,152 @@ class MaiDxOrchestrator:
802
  DiagnosisResult: An object containing the final diagnosis, evaluation, cost, and history.
803
  """
804
  start_time = time.time()
805
- self.conversation.add("Gatekeeper", f"Initial Case Information: {initial_case_info}")
806
-
 
 
 
807
  # Add initial physician visit cost
808
  self.cumulative_cost += self.physician_visit_cost
809
- logger.info(f"Initial physician visit cost: ${self.physician_visit_cost}")
810
-
 
 
811
  final_diagnosis = None
812
  iteration_count = 0
813
-
814
  for i in range(self.max_iterations):
815
  iteration_count = i + 1
816
- logger.info(f"--- Starting Diagnostic Loop {iteration_count}/{self.max_iterations} ---")
817
- logger.info(f"Current cost: ${self.cumulative_cost:,} | Remaining budget: ${self.initial_budget - self.cumulative_cost:,}")
818
-
 
 
 
 
819
  try:
820
  # Panel deliberates to decide on the next action
821
  action = self._run_panel_deliberation()
822
- logger.info(f"βš•οΈ Panel decision: {action.action_type.upper()} -> {action.content}")
823
- logger.info(f"πŸ’­ Medical reasoning: {action.reasoning}")
824
-
 
 
 
 
825
  if action.action_type == "diagnose":
826
  final_diagnosis = action.content
827
- logger.info(f"Final diagnosis proposed: {final_diagnosis}")
 
 
828
  break
829
 
830
  # Handle mode-specific constraints
831
- if self.mode == "question_only" and action.action_type == "test":
832
- logger.warning("Test ordering blocked in question-only mode")
 
 
 
 
 
833
  continue
834
-
835
- if self.mode == "budgeted" and action.action_type == "test":
 
 
 
836
  # Check if we can afford the tests
837
- estimated_test_cost = self._estimate_cost(action.content)
838
- if self.cumulative_cost + estimated_test_cost > self.initial_budget:
839
- logger.warning(f"Test cost ${estimated_test_cost} would exceed budget. Skipping tests.")
 
 
 
 
 
 
 
840
  continue
841
 
842
  # Interact with the Gatekeeper
843
- response = self._interact_with_gatekeeper(action, full_case_details)
 
 
844
  self.conversation.add("Gatekeeper", response)
845
-
846
  # Update costs based on action type
847
  if action.action_type == "test":
848
  test_cost = self._estimate_cost(action.content)
849
  self.cumulative_cost += test_cost
850
  logger.info(f"Tests ordered: {action.content}")
851
- logger.info(f"Test cost: ${test_cost:,} | Cumulative cost: ${self.cumulative_cost:,}")
 
 
852
  elif action.action_type == "ask":
853
  # Questions are part of the same visit until tests are ordered
854
  logger.info(f"Questions asked: {action.content}")
855
- logger.info(f"No additional cost for questions in same visit")
 
 
856
 
857
  # Check budget constraints for budgeted mode
858
- if self.mode == "budgeted" and self.cumulative_cost >= self.initial_budget:
859
- logger.warning("Budget limit reached. Forcing final diagnosis.")
 
 
 
 
 
860
  # Use current differential diagnosis or make best guess
861
- final_diagnosis = self.differential_diagnosis.split('\n')[0] if '\n' in self.differential_diagnosis else "Diagnosis not reached within budget constraints."
 
 
 
 
862
  break
863
-
864
  except Exception as e:
865
- logger.error(f"Error in diagnostic loop {iteration_count}: {e}")
 
 
866
  # Continue to next iteration or break if critical error
867
  continue
868
-
869
  else:
870
  # Max iterations reached without diagnosis
871
- final_diagnosis = self.differential_diagnosis.split('\n')[0] if '\n' in self.differential_diagnosis else "Diagnosis not reached within maximum iterations."
872
- logger.warning(f"Max iterations ({self.max_iterations}) reached. Using best available diagnosis.")
 
 
 
 
 
 
873
 
874
  # Ensure we have a final diagnosis
875
  if not final_diagnosis or final_diagnosis.strip() == "":
876
- final_diagnosis = "Unable to determine diagnosis within constraints."
877
-
 
 
878
  # Calculate total time
879
  total_time = time.time() - start_time
880
- logger.info(f"Diagnostic session completed in {total_time:.2f} seconds")
 
 
881
 
882
  # Judge the final diagnosis
883
  logger.info("Evaluating final diagnosis...")
884
  try:
885
- judgement = self._judge_diagnosis(final_diagnosis, ground_truth_diagnosis)
 
 
886
  except Exception as e:
887
  logger.error(f"Error in diagnosis evaluation: {e}")
888
- judgement = {"score": 0.0, "reasoning": f"Evaluation error: {str(e)}"}
 
 
 
889
 
890
  # Create comprehensive result
891
  result = DiagnosisResult(
@@ -895,39 +1117,49 @@ class MaiDxOrchestrator:
895
  accuracy_reasoning=judgement["reasoning"],
896
  total_cost=self.cumulative_cost,
897
  iterations=iteration_count,
898
- conversation_history=self.conversation.get_str()
899
  )
900
-
901
- logger.info(f"Diagnostic process completed:")
902
  logger.info(f" Final diagnosis: {final_diagnosis}")
903
  logger.info(f" Ground truth: {ground_truth_diagnosis}")
904
  logger.info(f" Accuracy score: {judgement['score']}/5.0")
905
  logger.info(f" Total cost: ${self.cumulative_cost:,}")
906
  logger.info(f" Iterations: {iteration_count}")
907
-
908
  return result
909
-
910
- def run_ensemble(self, initial_case_info: str, full_case_details: str, ground_truth_diagnosis: str, num_runs: int = 3) -> DiagnosisResult:
 
 
 
 
 
 
911
  """
912
  Runs multiple independent diagnostic sessions and aggregates the results.
913
-
914
  Args:
915
  initial_case_info (str): The initial abstract of the case.
916
  full_case_details (str): The complete case file for the Gatekeeper.
917
  ground_truth_diagnosis (str): The correct final diagnosis for evaluation.
918
  num_runs (int): Number of independent runs to perform.
919
-
920
  Returns:
921
  DiagnosisResult: Aggregated result from ensemble runs.
922
  """
923
- logger.info(f"Starting ensemble run with {num_runs} independent sessions")
924
-
 
 
925
  ensemble_results = []
926
  total_cost = 0
927
-
928
  for run_id in range(num_runs):
929
- logger.info(f"=== Ensemble Run {run_id + 1}/{num_runs} ===")
930
-
 
 
931
  # Create a fresh orchestrator instance for each run
932
  run_orchestrator = MaiDxOrchestrator(
933
  model_name=self.model_name,
@@ -935,38 +1167,52 @@ class MaiDxOrchestrator:
935
  initial_budget=self.initial_budget,
936
  mode="no_budget", # Use no_budget for ensemble runs
937
  physician_visit_cost=self.physician_visit_cost,
938
- enable_budget_tracking=False
939
  )
940
-
941
  # Run the diagnostic session
942
- result = run_orchestrator.run(initial_case_info, full_case_details, ground_truth_diagnosis)
 
 
 
 
943
  ensemble_results.append(result)
944
  total_cost += result.total_cost
945
-
946
- logger.info(f"Run {run_id + 1} completed: {result.final_diagnosis} (Score: {result.accuracy_score})")
947
-
 
 
948
  # Aggregate results using consensus
949
- final_diagnosis = self._aggregate_ensemble_diagnoses([r.final_diagnosis for r in ensemble_results])
950
-
 
 
951
  # Judge the aggregated diagnosis
952
- judgement = self._judge_diagnosis(final_diagnosis, ground_truth_diagnosis)
953
-
 
 
954
  # Calculate average metrics
955
- avg_iterations = sum(r.iterations for r in ensemble_results) / len(ensemble_results)
956
-
 
 
957
  # Combine conversation histories
958
  combined_history = "\n\n=== ENSEMBLE RESULTS ===\n"
959
  for i, result in enumerate(ensemble_results):
960
  combined_history += f"\n--- Run {i+1} ---\n"
961
- combined_history += f"Diagnosis: {result.final_diagnosis}\n"
 
 
962
  combined_history += f"Score: {result.accuracy_score}\n"
963
  combined_history += f"Cost: ${result.total_cost:,}\n"
964
  combined_history += f"Iterations: {result.iterations}\n"
965
-
966
- combined_history += f"\n--- Aggregated Result ---\n"
967
  combined_history += f"Final Diagnosis: {final_diagnosis}\n"
968
  combined_history += f"Reasoning: {judgement['reasoning']}\n"
969
-
970
  ensemble_result = DiagnosisResult(
971
  final_diagnosis=final_diagnosis,
972
  ground_truth=ground_truth_diagnosis,
@@ -974,28 +1220,36 @@ class MaiDxOrchestrator:
974
  accuracy_reasoning=judgement["reasoning"],
975
  total_cost=total_cost, # Sum of all runs
976
  iterations=int(avg_iterations),
977
- conversation_history=combined_history
 
 
 
 
978
  )
979
-
980
- logger.info(f"Ensemble completed: {final_diagnosis} (Score: {judgement['score']})")
981
  return ensemble_result
982
-
983
- def _aggregate_ensemble_diagnoses(self, diagnoses: List[str]) -> str:
 
 
984
  """Aggregates multiple diagnoses from ensemble runs."""
985
  # Simple majority voting or use the most confident diagnosis
986
  if not diagnoses:
987
  return "No diagnosis available"
988
-
989
  # Remove any empty or invalid diagnoses
990
- valid_diagnoses = [d for d in diagnoses if d and d.strip() and "not reached" not in d.lower()]
991
-
 
 
 
 
992
  if not valid_diagnoses:
993
  return diagnoses[0] if diagnoses else "No valid diagnosis"
994
-
995
  # If all diagnoses are the same, return that
996
  if len(set(valid_diagnoses)) == 1:
997
  return valid_diagnoses[0]
998
-
999
  # Use an aggregator agent to select the best diagnosis
1000
  try:
1001
  aggregator_prompt = f"""
@@ -1008,32 +1262,35 @@ class MaiDxOrchestrator:
1008
  Provide the single best diagnosis that represents the medical consensus.
1009
  Consider clinical accuracy, specificity, and completeness.
1010
  """
1011
-
1012
  aggregator = Agent(
1013
  agent_name="Ensemble Aggregator",
1014
  system_prompt=aggregator_prompt,
1015
  model_name=self.model_name,
1016
  max_loops=1,
1017
- print_on=True # Enable printing for aggregator agent
1018
  )
1019
-
1020
  return aggregator.run(aggregator_prompt).strip()
1021
-
1022
  except Exception as e:
1023
  logger.error(f"Error in ensemble aggregation: {e}")
1024
  # Fallback to most common diagnosis
1025
  from collections import Counter
 
1026
  return Counter(valid_diagnoses).most_common(1)[0][0]
1027
-
1028
  @classmethod
1029
- def create_variant(cls, variant: str, **kwargs) -> 'MaiDxOrchestrator':
 
 
1030
  """
1031
  Factory method to create different MAI-DxO variants as described in the paper.
1032
-
1033
  Args:
1034
  variant (str): One of 'instant', 'question_only', 'budgeted', 'no_budget', 'ensemble'
1035
  **kwargs: Additional parameters for the orchestrator
1036
-
1037
  Returns:
1038
  MaiDxOrchestrator: Configured orchestrator instance
1039
  """
@@ -1041,49 +1298,55 @@ class MaiDxOrchestrator:
1041
  "instant": {
1042
  "mode": "instant",
1043
  "max_iterations": 1,
1044
- "enable_budget_tracking": False
1045
  },
1046
  "question_only": {
1047
  "mode": "question_only",
1048
  "max_iterations": 10,
1049
- "enable_budget_tracking": False
1050
  },
1051
  "budgeted": {
1052
  "mode": "budgeted",
1053
  "max_iterations": 10,
1054
  "enable_budget_tracking": True,
1055
- "initial_budget": kwargs.get("budget", 5000)
1056
  },
1057
  "no_budget": {
1058
  "mode": "no_budget",
1059
  "max_iterations": 10,
1060
- "enable_budget_tracking": False
1061
  },
1062
  "ensemble": {
1063
  "mode": "no_budget",
1064
  "max_iterations": 10,
1065
- "enable_budget_tracking": False
1066
- }
1067
  }
1068
-
1069
  if variant not in variant_configs:
1070
- raise ValueError(f"Unknown variant: {variant}. Choose from: {list(variant_configs.keys())}")
1071
-
 
 
1072
  config = variant_configs[variant]
1073
  config.update(kwargs) # Allow overrides
1074
-
1075
  return cls(**config)
1076
 
1077
 
1078
- def run_mai_dxo_demo(case_info: str = None, case_details: str = None, ground_truth: str = None) -> Dict[str, DiagnosisResult]:
 
 
 
 
1079
  """
1080
  Convenience function to run a quick demonstration of MAI-DxO variants.
1081
-
1082
  Args:
1083
  case_info (str): Initial case information. Uses default if None.
1084
  case_details (str): Full case details. Uses default if None.
1085
  ground_truth (str): Ground truth diagnosis. Uses default if None.
1086
-
1087
  Returns:
1088
  Dict[str, DiagnosisResult]: Results from different MAI-DxO variants
1089
  """
@@ -1093,7 +1356,7 @@ def run_mai_dxo_demo(case_info: str = None, case_details: str = None, ground_tru
1093
  "A 29-year-old woman was admitted to the hospital because of sore throat and peritonsillar swelling "
1094
  "and bleeding. Symptoms did not abate with antimicrobial therapy."
1095
  )
1096
-
1097
  if not case_details:
1098
  case_details = """
1099
  Patient: 29-year-old female.
@@ -1107,31 +1370,39 @@ def run_mai_dxo_demo(case_info: str = None, case_details: str = None, ground_tru
1107
  Biopsy (FISH): No FOXO1 (13q14) rearrangements detected.
1108
  Final Diagnosis from Pathology: Embryonal rhabdomyosarcoma of the pharynx.
1109
  """
1110
-
1111
  if not ground_truth:
1112
  ground_truth = "Embryonal rhabdomyosarcoma of the pharynx"
1113
-
1114
  results = {}
1115
-
1116
  # Test key variants
1117
  variants = ["no_budget", "budgeted", "question_only"]
1118
-
1119
  for variant in variants:
1120
  try:
1121
  logger.info(f"Running MAI-DxO variant: {variant}")
1122
-
1123
  if variant == "budgeted":
1124
- orchestrator = MaiDxOrchestrator.create_variant(variant, budget=3000, model_name="gemini/gemini-2.5-flash")
 
 
 
 
1125
  else:
1126
- orchestrator = MaiDxOrchestrator.create_variant(variant, model_name="gemini/gemini-2.5-flash")
1127
-
1128
- result = orchestrator.run(case_info, case_details, ground_truth)
 
 
 
 
1129
  results[variant] = result
1130
-
1131
  except Exception as e:
1132
  logger.error(f"Error running variant {variant}: {e}")
1133
  results[variant] = None
1134
-
1135
  return results
1136
 
1137
 
@@ -1141,7 +1412,7 @@ if __name__ == "__main__":
1141
  "A 29-year-old woman was admitted to the hospital because of sore throat and peritonsillar swelling "
1142
  "and bleeding. Symptoms did not abate with antimicrobial therapy."
1143
  )
1144
-
1145
  full_case = """
1146
  Patient: 29-year-old female.
1147
  History: Onset of sore throat 7 weeks prior to admission. Worsening right-sided pain and swelling.
@@ -1155,55 +1426,65 @@ if __name__ == "__main__":
1155
  Biopsy (FISH): No FOXO1 (13q14) rearrangements detected.
1156
  Final Diagnosis from Pathology: Embryonal rhabdomyosarcoma of the pharynx.
1157
  """
1158
-
1159
  ground_truth = "Embryonal rhabdomyosarcoma of the pharynx"
1160
-
1161
  # --- Demonstrate Different MAI-DxO Variants ---
1162
  try:
1163
- print("\n" + "="*80)
1164
- print(" MAI DIAGNOSTIC ORCHESTRATOR (MAI-DxO) - SEQUENTIAL DIAGNOSIS BENCHMARK")
1165
- print(" Implementation based on the NEJM Research Paper")
1166
- print("="*80)
1167
-
 
 
 
 
1168
  # Test different variants as described in the paper
1169
  variants_to_test = [
1170
- ("no_budget", "Standard MAI-DxO with no budget constraints"),
 
 
 
1171
  ("budgeted", "Budget-constrained MAI-DxO ($3000 limit)"),
1172
- ("question_only", "Question-only variant (no diagnostic tests)"),
 
 
 
1173
  ]
1174
-
1175
  results = {}
1176
-
1177
  for variant_name, description in variants_to_test:
1178
  print(f"\n{'='*60}")
1179
  print(f"Testing Variant: {variant_name.upper()}")
1180
  print(f"Description: {description}")
1181
- print('='*60)
1182
-
1183
  # Create the variant
1184
  if variant_name == "budgeted":
1185
  orchestrator = MaiDxOrchestrator.create_variant(
1186
- variant_name,
1187
  budget=3000,
1188
  model_name="gemini/gemini-2.5-flash",
1189
- max_iterations=5
1190
  )
1191
  else:
1192
  orchestrator = MaiDxOrchestrator.create_variant(
1193
  variant_name,
1194
- model_name="gemini/gemini-2.5-flash",
1195
- max_iterations=5
1196
  )
1197
-
1198
  # Run the diagnostic process
1199
  result = orchestrator.run(
1200
  initial_case_info=initial_info,
1201
  full_case_details=full_case,
1202
- ground_truth_diagnosis=ground_truth
1203
  )
1204
-
1205
  results[variant_name] = result
1206
-
1207
  # Display results
1208
  print(f"\nπŸš€ Final Diagnosis: {result.final_diagnosis}")
1209
  print(f"🎯 Ground Truth: {result.ground_truth}")
@@ -1212,50 +1493,72 @@ if __name__ == "__main__":
1212
  print(f"πŸ’° Total Cost: ${result.total_cost:,}")
1213
  print(f"πŸ”„ Iterations: {result.iterations}")
1214
  print(f"⏱️ Mode: {orchestrator.mode}")
1215
-
1216
  # Demonstrate ensemble approach
1217
  print(f"\n{'='*60}")
1218
  print("Testing Variant: ENSEMBLE")
1219
- print("Description: Multiple independent runs with consensus aggregation")
1220
- print('='*60)
1221
-
 
 
1222
  ensemble_orchestrator = MaiDxOrchestrator.create_variant(
1223
  "ensemble",
1224
  model_name="gemini/gemini-2.5-flash",
1225
- max_iterations=3 # Shorter iterations for ensemble
1226
  )
1227
-
1228
  ensemble_result = ensemble_orchestrator.run_ensemble(
1229
  initial_case_info=initial_info,
1230
  full_case_details=full_case,
1231
  ground_truth_diagnosis=ground_truth,
1232
- num_runs=2 # Reduced for demo
1233
  )
1234
-
1235
  results["ensemble"] = ensemble_result
1236
-
1237
- print(f"\nπŸš€ Ensemble Diagnosis: {ensemble_result.final_diagnosis}")
 
 
1238
  print(f"🎯 Ground Truth: {ensemble_result.ground_truth}")
1239
- print(f"⭐ Ensemble Score: {ensemble_result.accuracy_score}/5.0")
1240
- print(f"πŸ’° Total Ensemble Cost: ${ensemble_result.total_cost:,}")
1241
-
 
 
 
 
1242
  # --- Summary Comparison ---
1243
  print(f"\n{'='*80}")
1244
  print(" RESULTS SUMMARY")
1245
- print('='*80)
1246
- print(f"{'Variant':<15} {'Diagnosis Match':<15} {'Score':<8} {'Cost':<12} {'Iterations':<12}")
1247
- print('-'*80)
1248
-
 
 
1249
  for variant_name, result in results.items():
1250
- match_status = "βœ“ Match" if result.accuracy_score >= 4.0 else "βœ— No Match"
1251
- print(f"{variant_name:<15} {match_status:<15} {result.accuracy_score:<8.1f} ${result.total_cost:<11,} {result.iterations:<12}")
1252
-
 
 
 
 
 
 
1253
  print(f"\n{'='*80}")
1254
- print("Implementation successfully demonstrates the MAI-DxO framework")
1255
- print("as described in 'Sequential Diagnosis with Language Models' paper")
1256
- print('='*80)
 
 
 
 
1257
 
1258
  except Exception as e:
1259
- logger.exception(f"An error occurred during the diagnostic session: {e}")
 
 
1260
  print(f"\n❌ Error occurred: {e}")
1261
- print("Please check your model configuration and API keys.")
 
1
  """
2
  MAI Diagnostic Orchestrator (MAI-DxO)
3
 
4
+ This script provides a complete implementation of the "Sequential Diagnosis with Language Models"
5
+ paper, using the `swarms` framework. It simulates a virtual panel of physician-agents to perform
6
  iterative medical diagnosis with cost-effectiveness optimization.
7
 
8
  Based on the paper: "Sequential Diagnosis with Language Models"
 
20
  # Standard MAI-DxO usage
21
  orchestrator = MaiDxOrchestrator(model_name="gemini/gemini-2.5-flash")
22
  result = orchestrator.run(initial_case_info, full_case_details, ground_truth)
23
+
24
  # Budget-constrained variant
25
  budgeted_orchestrator = MaiDxOrchestrator.create_variant("budgeted", budget=5000)
26
+
27
  # Ensemble approach
28
  ensemble_result = orchestrator.run_ensemble(initial_case_info, full_case_details, ground_truth)
29
  """
30
 
31
+ # Enable debug mode if environment variable is set
32
+ import os
33
  import json
34
  import sys
35
  import time
36
  from dataclasses import dataclass
37
  from enum import Enum
38
+ from typing import Any, Dict, List, Union, Literal
39
 
40
  from loguru import logger
41
  from pydantic import BaseModel, Field
 
49
  sys.stdout,
50
  level="INFO",
51
  format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
52
+ colorize=True,
53
  )
54
 
55
+
 
56
  if os.getenv("MAIDX_DEBUG", "").lower() in ("1", "true", "yes"):
57
  logger.add(
58
  "logs/maidx_debug_{time:YYYY-MM-DD}.log",
59
  level="DEBUG",
60
  format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
61
  rotation="1 day",
62
+ retention="3 days",
63
+ )
64
+ logger.info(
65
+ "πŸ› Debug logging enabled - logs will be written to logs/ directory"
66
  )
 
67
 
68
  # File handler for persistent logging (optional - uncomment if needed)
69
  # logger.add(
70
  # "logs/mai_dxo_{time:YYYY-MM-DD}.log",
71
  # rotation="1 day",
72
+ # retention="7 days",
73
  # level="DEBUG",
74
  # format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
75
  # compression="zip"
 
77
 
78
  # --- Data Structures and Enums ---
79
 
80
+
81
  class AgentRole(Enum):
82
  """Enumeration of roles for the virtual physician panel."""
83
+
84
  HYPOTHESIS = "Dr. Hypothesis"
85
  TEST_CHOOSER = "Dr. Test-Chooser"
86
  CHALLENGER = "Dr. Challenger"
 
90
  GATEKEEPER = "Gatekeeper"
91
  JUDGE = "Judge"
92
 
93
+
94
  @dataclass
95
  class DiagnosisResult:
96
  """Stores the final result of a diagnostic session."""
97
+
98
  final_diagnosis: str
99
  ground_truth: str
100
  accuracy_score: float
 
103
  iterations: int
104
  conversation_history: str
105
 
106
+
107
  class Action(BaseModel):
108
  """Pydantic model for a structured action decided by the consensus agent."""
109
+
110
+ action_type: Literal["ask", "test", "diagnose"] = Field(
111
+ ..., description="The type of action to perform."
112
+ )
113
+ content: Union[str, List[str]] = Field(
114
+ ...,
115
+ description="The content of the action (question, test name, or diagnosis).",
116
+ )
117
+ reasoning: str = Field(
118
+ ..., description="The reasoning behind choosing this action."
119
+ )
120
+
121
 
122
  # --- Main Orchestrator Class ---
123
 
124
+
125
  class MaiDxOrchestrator:
126
  """
127
  Implements the MAI Diagnostic Orchestrator (MAI-DxO) framework.
128
  This class orchestrates a virtual panel of AI agents to perform sequential medical diagnosis,
129
  evaluates the final diagnosis, and tracks costs.
130
  """
131
+
132
  def __init__(
133
  self,
134
  model_name: str = "gemini/gemini-2.5-flash",
 
155
  self.mode = mode
156
  self.physician_visit_cost = physician_visit_cost
157
  self.enable_budget_tracking = enable_budget_tracking
158
+
159
  self.cumulative_cost = 0
160
  self.differential_diagnosis = "Not yet formulated."
161
  self.conversation = Conversation(
162
+ time_enabled=True, autosave=False, save_enabled=False
 
 
163
  )
164
+
165
  # Enhanced cost model based on the paper's methodology
166
  self.test_cost_db = {
167
  "default": 150,
 
194
  }
195
 
196
  self._init_agents()
197
+ logger.info(
198
+ f"πŸ₯ MAI Diagnostic Orchestrator initialized successfully in '{mode}' mode with budget ${initial_budget:,}"
199
+ )
200
 
201
  def _init_agents(self):
202
  """Initializes all required agents with their specific roles and prompts."""
 
206
  system_prompt=self._get_prompt_for_role(role),
207
  model_name=self.model_name,
208
  max_loops=1,
209
+ output_type=(
210
+ "json" if role == AgentRole.CONSENSUS else "str"
211
+ ),
212
  print_on=True, # Enable printing for all agents to see outputs
213
+ )
214
+ for role in AgentRole
215
  }
216
+ logger.info(
217
+ f"πŸ‘₯ {len(self.agents)} virtual physician agents initialized and ready for consultation"
218
+ )
219
 
220
  def _get_prompt_for_role(self, role: AgentRole) -> str:
221
  """Returns the system prompt for a given agent role."""
222
  prompts = {
223
+ AgentRole.HYPOTHESIS: (
224
+ """
225
  You are Dr. Hypothesis, a specialist in maintaining differential diagnoses. Your role is critical to the diagnostic process.
226
 
227
  CORE RESPONSIBILITIES:
 
246
  - Evidence that contradicts or challenges each hypothesis
247
 
248
  Remember: Your differential drives the entire diagnostic process. Be thorough, evidence-based, and adaptive.
249
+ """
250
+ ),
251
+ AgentRole.TEST_CHOOSER: (
252
+ """
253
  You are Dr. Test-Chooser, a specialist in diagnostic test selection and information theory.
254
 
255
  CORE RESPONSIBILITIES:
 
278
  - How results will change management decisions
279
 
280
  Focus on tests that will most efficiently narrow the differential diagnosis.
281
+ """
282
+ ),
283
+ AgentRole.CHALLENGER: (
284
+ """
285
  You are Dr. Challenger, the critical thinking specialist and devil's advocate.
286
 
287
  CORE RESPONSIBILITIES:
 
312
  - Red flags or concerning patterns that need attention
313
 
314
  Be constructively critical - your role is to strengthen diagnostic accuracy through rigorous challenge.
315
+ """
316
+ ),
317
+ AgentRole.STEWARDSHIP: (
318
+ """
319
  You are Dr. Stewardship, the resource optimization and cost-effectiveness specialist.
320
 
321
  CORE RESPONSIBILITIES:
 
351
  - Cumulative cost considerations
352
 
353
  Your goal: Maximum diagnostic accuracy at minimum necessary cost.
354
+ """
355
+ ),
356
+ AgentRole.CHECKLIST: (
357
+ """
358
  You are Dr. Checklist, the quality assurance and consistency specialist.
359
 
360
  CORE RESPONSIBILITIES:
 
385
  - Process improvement suggestions
386
 
387
  Keep your feedback concise but comprehensive. Flag any issues that could compromise diagnostic quality.
388
+ """
389
+ ),
390
+ AgentRole.CONSENSUS: (
391
+ """
392
  You are the Consensus Coordinator, responsible for synthesizing the virtual panel's expertise into a single, optimal decision.
393
 
394
  CORE RESPONSIBILITIES:
 
422
  For action_type "diagnose": content should be the complete, specific final diagnosis
423
 
424
  Make the decision that best advances accurate, cost-effective diagnosis.
425
+ """
426
+ ),
427
+ AgentRole.GATEKEEPER: (
428
+ """
429
  You are the Gatekeeper, the clinical information oracle with complete access to the patient case file.
430
 
431
  CORE RESPONSIBILITIES:
 
462
  - Professional medical terminology
463
 
464
  Your role is crucial: provide complete, accurate clinical information while maintaining the challenge of the diagnostic process.
465
+ """
466
+ ),
467
+ AgentRole.JUDGE: (
468
+ """
469
  You are the Judge, the diagnostic accuracy evaluation specialist.
470
 
471
  CORE RESPONSIBILITIES:
 
521
 
522
  Maintain high standards while recognizing legitimate diagnostic variability in medical practice.
523
  """
524
+ ),
525
  }
526
  return prompts[role]
527
+
528
  def _parse_json_response(self, response: str) -> Dict[str, Any]:
529
  """Safely parses a JSON string, returning a dictionary."""
530
  try:
 
540
  start_idx += len(start_marker)
541
  end_idx = response.find(end_marker, start_idx)
542
  if end_idx != -1:
543
+ json_content = response[
544
+ start_idx:end_idx
545
+ ].strip()
546
  return json.loads(json_content)
547
+
548
  # Try to find JSON-like content in the response
549
+ lines = response.split("\n")
550
  json_lines = []
551
  in_json = False
552
  brace_count = 0
553
+
554
  for line in lines:
555
  stripped_line = line.strip()
556
+ if stripped_line.startswith("{") and not in_json:
557
  in_json = True
558
  json_lines = [line] # Start fresh
559
+ brace_count = line.count("{") - line.count(
560
+ "}"
561
+ )
562
  elif in_json:
563
  json_lines.append(line)
564
+ brace_count += line.count("{") - line.count(
565
+ "}"
566
+ )
567
+ if (
568
+ brace_count <= 0
569
+ ): # Balanced braces, end of JSON
570
  break
571
+
572
  if json_lines and in_json:
573
+ json_content = "\n".join(json_lines)
574
  return json.loads(json_content)
575
+
576
  # Try to extract JSON from text that might contain other content
577
  import re
578
+
579
  # Look for JSON pattern in the text
580
+ json_pattern = r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}"
581
+ matches = re.findall(
582
+ json_pattern, response, re.DOTALL
583
+ )
584
+
585
  for match in matches:
586
  try:
587
  return json.loads(match)
588
  except json.JSONDecodeError:
589
  continue
590
+
591
  # Direct parsing attempt as fallback
592
  return json.loads(response)
593
+
594
+ except (
595
+ json.JSONDecodeError,
596
+ IndexError,
597
+ AttributeError,
598
+ ) as e:
599
  logger.error(f"Failed to parse JSON response. Error: {e}")
600
+ logger.debug(
601
+ f"Response content: {response[:500]}..."
602
+ ) # Log first 500 chars
603
  # Fallback to a default action if parsing fails
604
  return {
605
  "action_type": "ask",
606
+ "content": (
607
+ "Could you please clarify the next best step? The previous analysis was inconclusive."
608
+ ),
609
+ "reasoning": "Fallback due to parsing error.",
610
  }
611
+
612
  def _estimate_cost(self, tests: Union[List[str], str]) -> int:
613
  """Estimates the cost of diagnostic tests."""
614
  if isinstance(tests, str):
615
  tests = [tests]
616
+
617
  cost = 0
618
  for test in tests:
619
  test_lower = test.lower().strip()
620
+
621
  # Enhanced cost matching with multiple strategies
622
  cost_found = False
623
+
624
  # Strategy 1: Exact match
625
  if test_lower in self.test_cost_db:
626
  cost += self.test_cost_db[test_lower]
627
  cost_found = True
628
  continue
629
+
630
  # Strategy 2: Partial match (find best matching key)
631
  best_match = None
632
  best_match_length = 0
 
635
  if len(cost_key) > best_match_length:
636
  best_match = cost_key
637
  best_match_length = len(cost_key)
638
+
639
  if best_match:
640
  cost += self.test_cost_db[best_match]
641
  cost_found = True
642
  continue
643
+
644
  # Strategy 3: Keyword-based matching
645
+ if any(
646
+ keyword in test_lower
647
+ for keyword in ["biopsy", "tissue"]
648
+ ):
649
+ cost += self.test_cost_db.get("biopsy", 800)
650
  cost_found = True
651
+ elif any(
652
+ keyword in test_lower
653
+ for keyword in ["mri", "magnetic"]
654
+ ):
655
+ cost += self.test_cost_db.get("mri", 1500)
656
  cost_found = True
657
+ elif any(
658
+ keyword in test_lower
659
+ for keyword in ["ct", "computed tomography"]
660
+ ):
661
+ cost += self.test_cost_db.get("ct scan", 1200)
662
  cost_found = True
663
+ elif any(
664
+ keyword in test_lower
665
+ for keyword in ["xray", "x-ray", "radiograph"]
666
+ ):
667
+ cost += self.test_cost_db.get("chest x-ray", 200)
668
  cost_found = True
669
+ elif any(
670
+ keyword in test_lower
671
+ for keyword in ["blood", "serum", "plasma"]
672
+ ):
673
  cost += 100 # Basic blood test cost
674
  cost_found = True
675
+ elif any(
676
+ keyword in test_lower
677
+ for keyword in ["culture", "sensitivity"]
678
+ ):
679
+ cost += self.test_cost_db.get("culture", 150)
680
  cost_found = True
681
+ elif any(
682
+ keyword in test_lower
683
+ for keyword in ["immunohistochemistry", "ihc"]
684
+ ):
685
+ cost += self.test_cost_db.get(
686
+ "immunohistochemistry", 400
687
+ )
688
  cost_found = True
689
+
690
  # Strategy 4: Default cost for unknown tests
691
  if not cost_found:
692
+ cost += self.test_cost_db["default"]
693
+ logger.debug(
694
+ f"Using default cost for unknown test: {test}"
695
+ )
696
+
697
  return cost
698
 
699
  def _run_panel_deliberation(self) -> Action:
700
  """Orchestrates one round of debate among the virtual panel to decide the next action."""
701
+ logger.info(
702
+ "🩺 Virtual medical panel deliberation commenced - analyzing patient case"
703
+ )
704
+ logger.debug(
705
+ "Panel members: Dr. Hypothesis, Dr. Test-Chooser, Dr. Challenger, Dr. Stewardship, Dr. Checklist"
706
+ )
707
  panel_conversation = Conversation(
708
+ time_enabled=True, autosave=False, save_enabled=False
 
 
709
  )
710
+
711
  # Prepare comprehensive panel context
712
  remaining_budget = self.initial_budget - self.cumulative_cost
713
+ budget_status = (
714
+ "EXCEEDED"
715
+ if remaining_budget < 0
716
+ else f"${remaining_budget:,}"
717
+ )
718
+
719
  panel_context = f"""
720
  DIAGNOSTIC CASE STATUS - ROUND {len(self.conversation.return_history_as_string().split('Gatekeeper:')) - 1}
721
 
 
740
  # For instant mode, skip deliberation and go straight to diagnosis
741
  action_dict = {
742
  "action_type": "diagnose",
743
+ "content": (
744
+ self.differential_diagnosis.split("\n")[0]
745
+ if "\n" in self.differential_diagnosis
746
+ else self.differential_diagnosis
747
+ ),
748
+ "reasoning": (
749
+ "Instant diagnosis mode - providing immediate assessment based on initial presentation"
750
+ ),
751
  }
752
  return Action(**action_dict)
753
+
754
  if self.mode == "question_only":
755
  # For question-only mode, prevent test ordering
756
  panel_context += "\n\nIMPORTANT: This is QUESTION-ONLY mode. You may ONLY ask patient questions, not order diagnostic tests."
 
759
  # Sequential expert deliberation with enhanced methodology
760
  try:
761
  # Dr. Hypothesis - Differential diagnosis and probability assessment
762
+ logger.info(
763
+ "🧠 Dr. Hypothesis analyzing differential diagnosis..."
764
+ )
765
+ hypothesis = self.agents[AgentRole.HYPOTHESIS].run(
766
+ panel_conversation.get_str()
767
+ )
768
+ self.differential_diagnosis = (
769
+ hypothesis # Update main state
770
+ )
771
+ panel_conversation.add(
772
+ self.agents[AgentRole.HYPOTHESIS].agent_name,
773
+ hypothesis,
774
+ )
775
+
776
  # Dr. Test-Chooser - Information value optimization
777
+ logger.info(
778
+ "πŸ”¬ Dr. Test-Chooser selecting optimal tests..."
779
+ )
780
+ test_choices = self.agents[AgentRole.TEST_CHOOSER].run(
781
+ panel_conversation.get_str()
782
+ )
783
+ panel_conversation.add(
784
+ self.agents[AgentRole.TEST_CHOOSER].agent_name,
785
+ test_choices,
786
+ )
787
 
788
  # Dr. Challenger - Bias identification and alternative hypotheses
789
+ logger.info(
790
+ "πŸ€” Dr. Challenger challenging assumptions..."
791
+ )
792
+ challenges = self.agents[AgentRole.CHALLENGER].run(
793
+ panel_conversation.get_str()
794
+ )
795
+ panel_conversation.add(
796
+ self.agents[AgentRole.CHALLENGER].agent_name,
797
+ challenges,
798
+ )
799
 
800
  # Dr. Stewardship - Cost-effectiveness analysis
801
+ logger.info(
802
+ "πŸ’° Dr. Stewardship evaluating cost-effectiveness..."
803
+ )
804
  stewardship_context = panel_conversation.get_str()
805
  if self.enable_budget_tracking:
806
  stewardship_context += f"\n\nBUDGET TRACKING ENABLED - Current cost: ${self.cumulative_cost}, Remaining: ${remaining_budget}"
807
+ stewardship_rec = self.agents[AgentRole.STEWARDSHIP].run(
808
+ stewardship_context
809
+ )
810
+ panel_conversation.add(
811
+ self.agents[AgentRole.STEWARDSHIP].agent_name,
812
+ stewardship_rec,
813
+ )
814
+
815
  # Dr. Checklist - Quality assurance
816
+ logger.info(
817
+ "βœ… Dr. Checklist performing quality control..."
818
+ )
819
+ checklist_rep = self.agents[AgentRole.CHECKLIST].run(
820
+ panel_conversation.get_str()
821
+ )
822
+ panel_conversation.add(
823
+ self.agents[AgentRole.CHECKLIST].agent_name,
824
+ checklist_rep,
825
+ )
826
+
827
  # Consensus Coordinator - Final decision synthesis
828
+ logger.info(
829
+ "🀝 Consensus Coordinator synthesizing panel decision..."
830
+ )
831
  consensus_context = panel_conversation.get_str()
832
+
833
  # Add mode-specific constraints to consensus
834
  if self.mode == "budgeted" and remaining_budget <= 0:
835
  consensus_context += "\n\nBUDGET CONSTRAINT: Budget exceeded - must either ask questions or provide final diagnosis."
836
+
837
+ consensus_response = self.agents[AgentRole.CONSENSUS].run(
838
+ consensus_context
839
+ )
840
+ logger.debug(
841
+ f"Raw consensus response: {consensus_response}"
842
+ )
843
+
844
  # Extract the actual text content from agent response
845
+ if hasattr(consensus_response, "content"):
846
  response_text = consensus_response.content
847
  elif isinstance(consensus_response, str):
848
  response_text = consensus_response
849
  else:
850
  response_text = str(consensus_response)
851
+
852
  action_dict = self._parse_json_response(response_text)
853
 
854
  # Validate action based on mode constraints
855
  action = Action(**action_dict)
856
+ if (
857
+ self.mode == "question_only"
858
+ and action.action_type == "test"
859
+ ):
860
+ logger.warning(
861
+ "Test ordering attempted in question-only mode, converting to ask action"
862
+ )
863
  action.action_type = "ask"
864
  action.content = "Can you provide more details about the patient's symptoms and history?"
865
+ action.reasoning = (
866
+ "Mode constraint: question-only mode active"
867
+ )
868
 
869
+ if (
870
+ self.mode == "budgeted"
871
+ and action.action_type == "test"
872
+ and remaining_budget <= 0
873
+ ):
874
+ logger.warning(
875
+ "Test ordering attempted with insufficient budget, converting to diagnose action"
876
+ )
877
  action.action_type = "diagnose"
878
+ action.content = (
879
+ self.differential_diagnosis.split("\n")[0]
880
+ if "\n" in self.differential_diagnosis
881
+ else self.differential_diagnosis
882
+ )
883
  action.reasoning = "Budget constraint: insufficient funds for additional testing"
884
 
885
  return action
 
890
  return Action(
891
  action_type="ask",
892
  content="Could you please provide more information about the patient's current condition?",
893
+ reasoning=f"Fallback due to panel deliberation error: {str(e)}",
894
  )
895
 
896
+ def _interact_with_gatekeeper(
897
+ self, action: Action, full_case_details: str
898
+ ) -> str:
899
  """Sends the panel's action to the Gatekeeper and returns its response."""
900
  gatekeeper = self.agents[AgentRole.GATEKEEPER]
901
+
902
  if action.action_type == "ask":
903
  request = f"Question: {action.content}"
904
  elif action.action_type == "test":
 
916
  Request from Diagnostic Agent:
917
  {request}
918
  """
919
+
920
  response = gatekeeper.run(prompt)
921
  return response
922
 
923
+ def _judge_diagnosis(
924
+ self, candidate_diagnosis: str, ground_truth: str
925
+ ) -> Dict[str, Any]:
926
  """Uses the Judge agent to evaluate the final diagnosis."""
927
  judge = self.agents[AgentRole.JUDGE]
928
  prompt = f"""
 
931
  Candidate Diagnosis: "{candidate_diagnosis}"
932
  """
933
  response = judge.run(prompt)
934
+
935
  # Simple parsing for demonstration; a more robust solution would use structured output.
936
  try:
937
+ score = float(
938
+ response.split("Score:")[1].split("/")[0].strip()
939
+ )
940
  reasoning = response.split("Justification:")[1].strip()
941
  except (IndexError, ValueError):
942
  score = 0.0
943
  reasoning = "Could not parse judge's response."
944
+
945
  return {"score": score, "reasoning": reasoning}
946
 
947
+ def run(
948
+ self,
949
+ initial_case_info: str,
950
+ full_case_details: str,
951
+ ground_truth_diagnosis: str,
952
+ ) -> DiagnosisResult:
953
  """
954
  Executes the full sequential diagnostic process.
955
 
 
962
  DiagnosisResult: An object containing the final diagnosis, evaluation, cost, and history.
963
  """
964
  start_time = time.time()
965
+ self.conversation.add(
966
+ "Gatekeeper",
967
+ f"Initial Case Information: {initial_case_info}",
968
+ )
969
+
970
  # Add initial physician visit cost
971
  self.cumulative_cost += self.physician_visit_cost
972
+ logger.info(
973
+ f"Initial physician visit cost: ${self.physician_visit_cost}"
974
+ )
975
+
976
  final_diagnosis = None
977
  iteration_count = 0
978
+
979
  for i in range(self.max_iterations):
980
  iteration_count = i + 1
981
+ logger.info(
982
+ f"--- Starting Diagnostic Loop {iteration_count}/{self.max_iterations} ---"
983
+ )
984
+ logger.info(
985
+ f"Current cost: ${self.cumulative_cost:,} | Remaining budget: ${self.initial_budget - self.cumulative_cost:,}"
986
+ )
987
+
988
  try:
989
  # Panel deliberates to decide on the next action
990
  action = self._run_panel_deliberation()
991
+ logger.info(
992
+ f"βš•οΈ Panel decision: {action.action_type.upper()} -> {action.content}"
993
+ )
994
+ logger.info(
995
+ f"πŸ’­ Medical reasoning: {action.reasoning}"
996
+ )
997
+
998
  if action.action_type == "diagnose":
999
  final_diagnosis = action.content
1000
+ logger.info(
1001
+ f"Final diagnosis proposed: {final_diagnosis}"
1002
+ )
1003
  break
1004
 
1005
  # Handle mode-specific constraints
1006
+ if (
1007
+ self.mode == "question_only"
1008
+ and action.action_type == "test"
1009
+ ):
1010
+ logger.warning(
1011
+ "Test ordering blocked in question-only mode"
1012
+ )
1013
  continue
1014
+
1015
+ if (
1016
+ self.mode == "budgeted"
1017
+ and action.action_type == "test"
1018
+ ):
1019
  # Check if we can afford the tests
1020
+ estimated_test_cost = self._estimate_cost(
1021
+ action.content
1022
+ )
1023
+ if (
1024
+ self.cumulative_cost + estimated_test_cost
1025
+ > self.initial_budget
1026
+ ):
1027
+ logger.warning(
1028
+ f"Test cost ${estimated_test_cost} would exceed budget. Skipping tests."
1029
+ )
1030
  continue
1031
 
1032
  # Interact with the Gatekeeper
1033
+ response = self._interact_with_gatekeeper(
1034
+ action, full_case_details
1035
+ )
1036
  self.conversation.add("Gatekeeper", response)
1037
+
1038
  # Update costs based on action type
1039
  if action.action_type == "test":
1040
  test_cost = self._estimate_cost(action.content)
1041
  self.cumulative_cost += test_cost
1042
  logger.info(f"Tests ordered: {action.content}")
1043
+ logger.info(
1044
+ f"Test cost: ${test_cost:,} | Cumulative cost: ${self.cumulative_cost:,}"
1045
+ )
1046
  elif action.action_type == "ask":
1047
  # Questions are part of the same visit until tests are ordered
1048
  logger.info(f"Questions asked: {action.content}")
1049
+ logger.info(
1050
+ "No additional cost for questions in same visit"
1051
+ )
1052
 
1053
  # Check budget constraints for budgeted mode
1054
+ if (
1055
+ self.mode == "budgeted"
1056
+ and self.cumulative_cost >= self.initial_budget
1057
+ ):
1058
+ logger.warning(
1059
+ "Budget limit reached. Forcing final diagnosis."
1060
+ )
1061
  # Use current differential diagnosis or make best guess
1062
+ final_diagnosis = (
1063
+ self.differential_diagnosis.split("\n")[0]
1064
+ if "\n" in self.differential_diagnosis
1065
+ else "Diagnosis not reached within budget constraints."
1066
+ )
1067
  break
1068
+
1069
  except Exception as e:
1070
+ logger.error(
1071
+ f"Error in diagnostic loop {iteration_count}: {e}"
1072
+ )
1073
  # Continue to next iteration or break if critical error
1074
  continue
1075
+
1076
  else:
1077
  # Max iterations reached without diagnosis
1078
+ final_diagnosis = (
1079
+ self.differential_diagnosis.split("\n")[0]
1080
+ if "\n" in self.differential_diagnosis
1081
+ else "Diagnosis not reached within maximum iterations."
1082
+ )
1083
+ logger.warning(
1084
+ f"Max iterations ({self.max_iterations}) reached. Using best available diagnosis."
1085
+ )
1086
 
1087
  # Ensure we have a final diagnosis
1088
  if not final_diagnosis or final_diagnosis.strip() == "":
1089
+ final_diagnosis = (
1090
+ "Unable to determine diagnosis within constraints."
1091
+ )
1092
+
1093
  # Calculate total time
1094
  total_time = time.time() - start_time
1095
+ logger.info(
1096
+ f"Diagnostic session completed in {total_time:.2f} seconds"
1097
+ )
1098
 
1099
  # Judge the final diagnosis
1100
  logger.info("Evaluating final diagnosis...")
1101
  try:
1102
+ judgement = self._judge_diagnosis(
1103
+ final_diagnosis, ground_truth_diagnosis
1104
+ )
1105
  except Exception as e:
1106
  logger.error(f"Error in diagnosis evaluation: {e}")
1107
+ judgement = {
1108
+ "score": 0.0,
1109
+ "reasoning": f"Evaluation error: {str(e)}",
1110
+ }
1111
 
1112
  # Create comprehensive result
1113
  result = DiagnosisResult(
 
1117
  accuracy_reasoning=judgement["reasoning"],
1118
  total_cost=self.cumulative_cost,
1119
  iterations=iteration_count,
1120
+ conversation_history=self.conversation.get_str(),
1121
  )
1122
+
1123
+ logger.info("Diagnostic process completed:")
1124
  logger.info(f" Final diagnosis: {final_diagnosis}")
1125
  logger.info(f" Ground truth: {ground_truth_diagnosis}")
1126
  logger.info(f" Accuracy score: {judgement['score']}/5.0")
1127
  logger.info(f" Total cost: ${self.cumulative_cost:,}")
1128
  logger.info(f" Iterations: {iteration_count}")
1129
+
1130
  return result
1131
+
1132
+ def run_ensemble(
1133
+ self,
1134
+ initial_case_info: str,
1135
+ full_case_details: str,
1136
+ ground_truth_diagnosis: str,
1137
+ num_runs: int = 3,
1138
+ ) -> DiagnosisResult:
1139
  """
1140
  Runs multiple independent diagnostic sessions and aggregates the results.
1141
+
1142
  Args:
1143
  initial_case_info (str): The initial abstract of the case.
1144
  full_case_details (str): The complete case file for the Gatekeeper.
1145
  ground_truth_diagnosis (str): The correct final diagnosis for evaluation.
1146
  num_runs (int): Number of independent runs to perform.
1147
+
1148
  Returns:
1149
  DiagnosisResult: Aggregated result from ensemble runs.
1150
  """
1151
+ logger.info(
1152
+ f"Starting ensemble run with {num_runs} independent sessions"
1153
+ )
1154
+
1155
  ensemble_results = []
1156
  total_cost = 0
1157
+
1158
  for run_id in range(num_runs):
1159
+ logger.info(
1160
+ f"=== Ensemble Run {run_id + 1}/{num_runs} ==="
1161
+ )
1162
+
1163
  # Create a fresh orchestrator instance for each run
1164
  run_orchestrator = MaiDxOrchestrator(
1165
  model_name=self.model_name,
 
1167
  initial_budget=self.initial_budget,
1168
  mode="no_budget", # Use no_budget for ensemble runs
1169
  physician_visit_cost=self.physician_visit_cost,
1170
+ enable_budget_tracking=False,
1171
  )
1172
+
1173
  # Run the diagnostic session
1174
+ result = run_orchestrator.run(
1175
+ initial_case_info,
1176
+ full_case_details,
1177
+ ground_truth_diagnosis,
1178
+ )
1179
  ensemble_results.append(result)
1180
  total_cost += result.total_cost
1181
+
1182
+ logger.info(
1183
+ f"Run {run_id + 1} completed: {result.final_diagnosis} (Score: {result.accuracy_score})"
1184
+ )
1185
+
1186
  # Aggregate results using consensus
1187
+ final_diagnosis = self._aggregate_ensemble_diagnoses(
1188
+ [r.final_diagnosis for r in ensemble_results]
1189
+ )
1190
+
1191
  # Judge the aggregated diagnosis
1192
+ judgement = self._judge_diagnosis(
1193
+ final_diagnosis, ground_truth_diagnosis
1194
+ )
1195
+
1196
  # Calculate average metrics
1197
+ avg_iterations = sum(
1198
+ r.iterations for r in ensemble_results
1199
+ ) / len(ensemble_results)
1200
+
1201
  # Combine conversation histories
1202
  combined_history = "\n\n=== ENSEMBLE RESULTS ===\n"
1203
  for i, result in enumerate(ensemble_results):
1204
  combined_history += f"\n--- Run {i+1} ---\n"
1205
+ combined_history += (
1206
+ f"Diagnosis: {result.final_diagnosis}\n"
1207
+ )
1208
  combined_history += f"Score: {result.accuracy_score}\n"
1209
  combined_history += f"Cost: ${result.total_cost:,}\n"
1210
  combined_history += f"Iterations: {result.iterations}\n"
1211
+
1212
+ combined_history += "\n--- Aggregated Result ---\n"
1213
  combined_history += f"Final Diagnosis: {final_diagnosis}\n"
1214
  combined_history += f"Reasoning: {judgement['reasoning']}\n"
1215
+
1216
  ensemble_result = DiagnosisResult(
1217
  final_diagnosis=final_diagnosis,
1218
  ground_truth=ground_truth_diagnosis,
 
1220
  accuracy_reasoning=judgement["reasoning"],
1221
  total_cost=total_cost, # Sum of all runs
1222
  iterations=int(avg_iterations),
1223
+ conversation_history=combined_history,
1224
+ )
1225
+
1226
+ logger.info(
1227
+ f"Ensemble completed: {final_diagnosis} (Score: {judgement['score']})"
1228
  )
 
 
1229
  return ensemble_result
1230
+
1231
+ def _aggregate_ensemble_diagnoses(
1232
+ self, diagnoses: List[str]
1233
+ ) -> str:
1234
  """Aggregates multiple diagnoses from ensemble runs."""
1235
  # Simple majority voting or use the most confident diagnosis
1236
  if not diagnoses:
1237
  return "No diagnosis available"
1238
+
1239
  # Remove any empty or invalid diagnoses
1240
+ valid_diagnoses = [
1241
+ d
1242
+ for d in diagnoses
1243
+ if d and d.strip() and "not reached" not in d.lower()
1244
+ ]
1245
+
1246
  if not valid_diagnoses:
1247
  return diagnoses[0] if diagnoses else "No valid diagnosis"
1248
+
1249
  # If all diagnoses are the same, return that
1250
  if len(set(valid_diagnoses)) == 1:
1251
  return valid_diagnoses[0]
1252
+
1253
  # Use an aggregator agent to select the best diagnosis
1254
  try:
1255
  aggregator_prompt = f"""
 
1262
  Provide the single best diagnosis that represents the medical consensus.
1263
  Consider clinical accuracy, specificity, and completeness.
1264
  """
1265
+
1266
  aggregator = Agent(
1267
  agent_name="Ensemble Aggregator",
1268
  system_prompt=aggregator_prompt,
1269
  model_name=self.model_name,
1270
  max_loops=1,
1271
+ print_on=True, # Enable printing for aggregator agent
1272
  )
1273
+
1274
  return aggregator.run(aggregator_prompt).strip()
1275
+
1276
  except Exception as e:
1277
  logger.error(f"Error in ensemble aggregation: {e}")
1278
  # Fallback to most common diagnosis
1279
  from collections import Counter
1280
+
1281
  return Counter(valid_diagnoses).most_common(1)[0][0]
1282
+
1283
  @classmethod
1284
+ def create_variant(
1285
+ cls, variant: str, **kwargs
1286
+ ) -> "MaiDxOrchestrator":
1287
  """
1288
  Factory method to create different MAI-DxO variants as described in the paper.
1289
+
1290
  Args:
1291
  variant (str): One of 'instant', 'question_only', 'budgeted', 'no_budget', 'ensemble'
1292
  **kwargs: Additional parameters for the orchestrator
1293
+
1294
  Returns:
1295
  MaiDxOrchestrator: Configured orchestrator instance
1296
  """
 
1298
  "instant": {
1299
  "mode": "instant",
1300
  "max_iterations": 1,
1301
+ "enable_budget_tracking": False,
1302
  },
1303
  "question_only": {
1304
  "mode": "question_only",
1305
  "max_iterations": 10,
1306
+ "enable_budget_tracking": False,
1307
  },
1308
  "budgeted": {
1309
  "mode": "budgeted",
1310
  "max_iterations": 10,
1311
  "enable_budget_tracking": True,
1312
+ "initial_budget": kwargs.get("budget", 5000),
1313
  },
1314
  "no_budget": {
1315
  "mode": "no_budget",
1316
  "max_iterations": 10,
1317
+ "enable_budget_tracking": False,
1318
  },
1319
  "ensemble": {
1320
  "mode": "no_budget",
1321
  "max_iterations": 10,
1322
+ "enable_budget_tracking": False,
1323
+ },
1324
  }
1325
+
1326
  if variant not in variant_configs:
1327
+ raise ValueError(
1328
+ f"Unknown variant: {variant}. Choose from: {list(variant_configs.keys())}"
1329
+ )
1330
+
1331
  config = variant_configs[variant]
1332
  config.update(kwargs) # Allow overrides
1333
+
1334
  return cls(**config)
1335
 
1336
 
1337
+ def run_mai_dxo_demo(
1338
+ case_info: str = None,
1339
+ case_details: str = None,
1340
+ ground_truth: str = None,
1341
+ ) -> Dict[str, DiagnosisResult]:
1342
  """
1343
  Convenience function to run a quick demonstration of MAI-DxO variants.
1344
+
1345
  Args:
1346
  case_info (str): Initial case information. Uses default if None.
1347
  case_details (str): Full case details. Uses default if None.
1348
  ground_truth (str): Ground truth diagnosis. Uses default if None.
1349
+
1350
  Returns:
1351
  Dict[str, DiagnosisResult]: Results from different MAI-DxO variants
1352
  """
 
1356
  "A 29-year-old woman was admitted to the hospital because of sore throat and peritonsillar swelling "
1357
  "and bleeding. Symptoms did not abate with antimicrobial therapy."
1358
  )
1359
+
1360
  if not case_details:
1361
  case_details = """
1362
  Patient: 29-year-old female.
 
1370
  Biopsy (FISH): No FOXO1 (13q14) rearrangements detected.
1371
  Final Diagnosis from Pathology: Embryonal rhabdomyosarcoma of the pharynx.
1372
  """
1373
+
1374
  if not ground_truth:
1375
  ground_truth = "Embryonal rhabdomyosarcoma of the pharynx"
1376
+
1377
  results = {}
1378
+
1379
  # Test key variants
1380
  variants = ["no_budget", "budgeted", "question_only"]
1381
+
1382
  for variant in variants:
1383
  try:
1384
  logger.info(f"Running MAI-DxO variant: {variant}")
1385
+
1386
  if variant == "budgeted":
1387
+ orchestrator = MaiDxOrchestrator.create_variant(
1388
+ variant,
1389
+ budget=3000,
1390
+ model_name="gemini/gemini-2.5-flash",
1391
+ )
1392
  else:
1393
+ orchestrator = MaiDxOrchestrator.create_variant(
1394
+ variant, model_name="gemini/gemini-2.5-flash"
1395
+ )
1396
+
1397
+ result = orchestrator.run(
1398
+ case_info, case_details, ground_truth
1399
+ )
1400
  results[variant] = result
1401
+
1402
  except Exception as e:
1403
  logger.error(f"Error running variant {variant}: {e}")
1404
  results[variant] = None
1405
+
1406
  return results
1407
 
1408
 
 
1412
  "A 29-year-old woman was admitted to the hospital because of sore throat and peritonsillar swelling "
1413
  "and bleeding. Symptoms did not abate with antimicrobial therapy."
1414
  )
1415
+
1416
  full_case = """
1417
  Patient: 29-year-old female.
1418
  History: Onset of sore throat 7 weeks prior to admission. Worsening right-sided pain and swelling.
 
1426
  Biopsy (FISH): No FOXO1 (13q14) rearrangements detected.
1427
  Final Diagnosis from Pathology: Embryonal rhabdomyosarcoma of the pharynx.
1428
  """
1429
+
1430
  ground_truth = "Embryonal rhabdomyosarcoma of the pharynx"
1431
+
1432
  # --- Demonstrate Different MAI-DxO Variants ---
1433
  try:
1434
+ print("\n" + "=" * 80)
1435
+ print(
1436
+ " MAI DIAGNOSTIC ORCHESTRATOR (MAI-DxO) - SEQUENTIAL DIAGNOSIS BENCHMARK"
1437
+ )
1438
+ print(
1439
+ " Implementation based on the NEJM Research Paper"
1440
+ )
1441
+ print("=" * 80)
1442
+
1443
  # Test different variants as described in the paper
1444
  variants_to_test = [
1445
+ (
1446
+ "no_budget",
1447
+ "Standard MAI-DxO with no budget constraints",
1448
+ ),
1449
  ("budgeted", "Budget-constrained MAI-DxO ($3000 limit)"),
1450
+ (
1451
+ "question_only",
1452
+ "Question-only variant (no diagnostic tests)",
1453
+ ),
1454
  ]
1455
+
1456
  results = {}
1457
+
1458
  for variant_name, description in variants_to_test:
1459
  print(f"\n{'='*60}")
1460
  print(f"Testing Variant: {variant_name.upper()}")
1461
  print(f"Description: {description}")
1462
+ print("=" * 60)
1463
+
1464
  # Create the variant
1465
  if variant_name == "budgeted":
1466
  orchestrator = MaiDxOrchestrator.create_variant(
1467
+ variant_name,
1468
  budget=3000,
1469
  model_name="gemini/gemini-2.5-flash",
1470
+ max_iterations=5,
1471
  )
1472
  else:
1473
  orchestrator = MaiDxOrchestrator.create_variant(
1474
  variant_name,
1475
+ model_name="gemini/gemini-2.5-flash",
1476
+ max_iterations=5,
1477
  )
1478
+
1479
  # Run the diagnostic process
1480
  result = orchestrator.run(
1481
  initial_case_info=initial_info,
1482
  full_case_details=full_case,
1483
+ ground_truth_diagnosis=ground_truth,
1484
  )
1485
+
1486
  results[variant_name] = result
1487
+
1488
  # Display results
1489
  print(f"\nπŸš€ Final Diagnosis: {result.final_diagnosis}")
1490
  print(f"🎯 Ground Truth: {result.ground_truth}")
 
1493
  print(f"πŸ’° Total Cost: ${result.total_cost:,}")
1494
  print(f"πŸ”„ Iterations: {result.iterations}")
1495
  print(f"⏱️ Mode: {orchestrator.mode}")
1496
+
1497
  # Demonstrate ensemble approach
1498
  print(f"\n{'='*60}")
1499
  print("Testing Variant: ENSEMBLE")
1500
+ print(
1501
+ "Description: Multiple independent runs with consensus aggregation"
1502
+ )
1503
+ print("=" * 60)
1504
+
1505
  ensemble_orchestrator = MaiDxOrchestrator.create_variant(
1506
  "ensemble",
1507
  model_name="gemini/gemini-2.5-flash",
1508
+ max_iterations=3, # Shorter iterations for ensemble
1509
  )
1510
+
1511
  ensemble_result = ensemble_orchestrator.run_ensemble(
1512
  initial_case_info=initial_info,
1513
  full_case_details=full_case,
1514
  ground_truth_diagnosis=ground_truth,
1515
+ num_runs=2, # Reduced for demo
1516
  )
1517
+
1518
  results["ensemble"] = ensemble_result
1519
+
1520
+ print(
1521
+ f"\nπŸš€ Ensemble Diagnosis: {ensemble_result.final_diagnosis}"
1522
+ )
1523
  print(f"🎯 Ground Truth: {ensemble_result.ground_truth}")
1524
+ print(
1525
+ f"⭐ Ensemble Score: {ensemble_result.accuracy_score}/5.0"
1526
+ )
1527
+ print(
1528
+ f"πŸ’° Total Ensemble Cost: ${ensemble_result.total_cost:,}"
1529
+ )
1530
+
1531
  # --- Summary Comparison ---
1532
  print(f"\n{'='*80}")
1533
  print(" RESULTS SUMMARY")
1534
+ print("=" * 80)
1535
+ print(
1536
+ f"{'Variant':<15} {'Diagnosis Match':<15} {'Score':<8} {'Cost':<12} {'Iterations':<12}"
1537
+ )
1538
+ print("-" * 80)
1539
+
1540
  for variant_name, result in results.items():
1541
+ match_status = (
1542
+ "βœ“ Match"
1543
+ if result.accuracy_score >= 4.0
1544
+ else "βœ— No Match"
1545
+ )
1546
+ print(
1547
+ f"{variant_name:<15} {match_status:<15} {result.accuracy_score:<8.1f} ${result.total_cost:<11,} {result.iterations:<12}"
1548
+ )
1549
+
1550
  print(f"\n{'='*80}")
1551
+ print(
1552
+ "Implementation successfully demonstrates the MAI-DxO framework"
1553
+ )
1554
+ print(
1555
+ "as described in 'Sequential Diagnosis with Language Models' paper"
1556
+ )
1557
+ print("=" * 80)
1558
 
1559
  except Exception as e:
1560
+ logger.exception(
1561
+ f"An error occurred during the diagnostic session: {e}"
1562
+ )
1563
  print(f"\n❌ Error occurred: {e}")
1564
+ print("Please check your model configuration and API keys.")
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- torch
2
- zetascale
3
  swarms
 
 
1
+ loguru
 
2
  swarms
3
+ pydantic