harshalmore31 committed on
Commit
5ebc627
·
1 Parent(s): 3acc06b

Add structured function-calling for the Consensus and Hypothesis agents, dynamic token-budget guidance, and request throttling to mitigate rate limits

Browse files
Files changed (1) hide show
  1. mai_dx/main.py +740 -265
mai_dx/main.py CHANGED
@@ -182,40 +182,34 @@ class DeliberationState:
182
  retry_count: int = 0
183
 
184
  def to_consensus_prompt(self) -> str:
185
- """Generate a structured prompt for the consensus coordinator"""
 
186
  prompt = f"""
187
- You are the Consensus Coordinator. Here is the summary of the panel's deliberation for this turn:
188
 
189
- **Current Differential Diagnosis (from Dr. Hypothesis):**
190
- {self.hypothesis_analysis}
191
 
192
- **Recommended Tests (from Dr. Test-Chooser):**
193
- {self.test_chooser_analysis}
194
 
195
- **Identified Biases & Challenges (from Dr. Challenger):**
196
- {self.challenger_analysis}
197
 
198
- **Cost & Stewardship Concerns (from Dr. Stewardship):**
199
- {self.stewardship_analysis}
200
 
201
- **Quality Control Assessment (from Dr. Checklist):**
202
- {self.checklist_analysis}
203
  """
204
 
205
  if self.stagnation_detected:
206
- prompt += f"""
207
- **CRITICAL INTERVENTION: STAGNATION DETECTED**
208
- The panel is stalled. You MUST propose a different and more decisive action.
209
- If you cannot find a new test or question, you must move to a final diagnosis.
210
- """
211
 
212
  if self.situational_context:
213
- prompt += f"""
214
- **SITUATIONAL CONTEXT:**
215
- {self.situational_context}
216
- """
217
 
218
- prompt += "\nBased on this synthesized input, provide your single best action in the required JSON format."
219
  return prompt
220
 
221
 
@@ -261,12 +255,13 @@ class MaiDxOrchestrator:
261
 
262
  def __init__(
263
  self,
264
- model_name: str = "gpt-4-1106-preview", # Fixed: Use valid GPT-4 Turbo model name
265
  max_iterations: int = 10,
266
  initial_budget: int = 10000,
267
  mode: str = "no_budget", # "instant", "question_only", "budgeted", "no_budget", "ensemble"
268
  physician_visit_cost: int = 300,
269
  enable_budget_tracking: bool = False,
 
270
  ):
271
  """
272
  Initializes the MAI-DxO system with improved architecture.
@@ -278,6 +273,7 @@ class MaiDxOrchestrator:
278
  mode (str): The operational mode of MAI-DxO.
279
  physician_visit_cost (int): Cost per physician visit.
280
  enable_budget_tracking (bool): Whether to enable budget tracking.
 
281
  """
282
  self.model_name = model_name
283
  self.max_iterations = max_iterations
@@ -286,6 +282,12 @@ class MaiDxOrchestrator:
286
  self.physician_visit_cost = physician_visit_cost
287
  self.enable_budget_tracking = enable_budget_tracking
288
 
 
 
 
 
 
 
289
  self.cumulative_cost = 0
290
  self.differential_diagnosis = "Not yet formulated."
291
  self.conversation = Conversation(
@@ -332,35 +334,175 @@ class MaiDxOrchestrator:
332
  )
333
 
334
  def _get_agent_max_tokens(self, role: AgentRole) -> int:
335
- """Get max_tokens for each agent based on their role - significantly increased limits"""
336
  token_limits = {
337
- AgentRole.HYPOTHESIS: 2000, # Increased for comprehensive differential analysis
338
- AgentRole.TEST_CHOOSER: 1500, # Increased for detailed test recommendations
339
- AgentRole.CHALLENGER: 1800, # Increased for thorough bias analysis
340
- AgentRole.STEWARDSHIP: 1200, # Increased for detailed cost analysis
341
- AgentRole.CHECKLIST: 1000, # Increased for comprehensive validation
342
- AgentRole.CONSENSUS: 800, # Increased for detailed reasoning + JSON
343
- AgentRole.GATEKEEPER: 2500, # Increased for detailed clinical findings
344
- AgentRole.JUDGE: 1500, # Increased for comprehensive evaluation
 
345
  }
346
- return token_limits.get(role, 1000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  def _init_agents(self) -> None:
349
  """Initializes all required agents with their specific roles and prompts."""
350
- self.agents = {
351
- role: Agent(
352
- agent_name=role.value,
353
- system_prompt=self._get_prompt_for_role(role),
354
- model_name=self.model_name,
355
- max_loops=1,
356
- output_type=(
357
- "json" if role == AgentRole.CONSENSUS else "str"
358
- ),
359
- print_on=True, # Enable printing for all agents to see outputs
360
- max_tokens=self._get_agent_max_tokens(role), # Role-specific token limits
361
- )
362
- for role in AgentRole
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  logger.info(
365
  f"πŸ‘₯ {len(self.agents)} virtual physician agents initialized and ready for consultation"
366
  )
@@ -424,16 +566,16 @@ This case has gone through {case_state.iteration} iterations. Focus on decisive
424
  3. Always explain your Bayesian reasoning clearly
425
  4. Consider epidemiology, pathophysiology, and clinical patterns
426
 
427
- OUTPUT FORMAT (Use full token allocation for comprehensive analysis):
428
- Provide your updated differential diagnosis with:
429
- - Top 3-5 diagnoses with probability estimates (percentages)
430
- - Detailed rationale for each diagnosis
431
- - Key evidence supporting each hypothesis
432
- - Evidence that contradicts or challenges each hypothesis
433
- - Pathophysiological reasoning for each diagnosis
434
- - Risk stratification and urgency considerations
435
-
436
- Remember: Your differential drives the entire diagnostic process. Be thorough, evidence-based, and adaptive. Use your full token allocation to provide comprehensive clinical reasoning.
437
  """,
438
 
439
  AgentRole.TEST_CHOOSER: f"""
@@ -462,17 +604,20 @@ This case has gone through {case_state.iteration} iterations. Focus on decisive
462
  - Avoid redundant tests that won't add new information
463
  - Consider pre-test probability and post-test probability calculations
464
 
465
- OUTPUT FORMAT (Use full token allocation for detailed analysis):
466
- For each recommended test:
 
 
 
 
467
  - Test name (be specific and accurate)
468
  - Primary hypotheses it will help evaluate
469
- - Expected information gain and likelihood ratios
470
- - How results will change management decisions
471
- - Cost considerations and alternatives
472
- - Sequence rationale (why this test now vs. later)
473
- - Expected sensitivity/specificity for the clinical context
474
 
475
- Focus on tests that will most efficiently narrow the differential diagnosis while considering practical constraints.
476
  """,
477
 
478
  AgentRole.CHALLENGER: f"""
@@ -503,16 +648,19 @@ This case has gone through {case_state.iteration} iterations. Focus on decisive
503
  - Advocate for considering multiple conditions simultaneously
504
  - Look for inconsistencies in the clinical presentation
505
 
506
- OUTPUT FORMAT (Use full token allocation for thorough analysis):
507
- - Specific biases you've identified in the current reasoning
508
- - Evidence that contradicts the leading hypotheses
509
- - Alternative diagnoses to consider with reasoning
510
- - Tests that could falsify current assumptions
511
- - Red flags or concerning patterns that need attention
512
- - Analysis of what might be missing from the current approach
513
- - Systematic review of differential diagnosis completeness
514
-
515
- Be constructively critical - your role is to strengthen diagnostic accuracy through rigorous challenge and comprehensive analysis.
 
 
 
516
  """,
517
 
518
  AgentRole.STEWARDSHIP: f"""
@@ -620,18 +768,13 @@ This case has gone through {case_state.iteration} iterations. Focus on decisive
620
  4. **Cost Optimization:** Before finalizing a test, check Dr. Stewardship's input. If a diagnostically equivalent but cheaper alternative is available, select it.
621
  5. **Default to Questions:** If no test meets the criteria or the budget is a major concern, select the most pertinent question to ask.
622
 
623
- **CRITICAL: YOUR RESPONSE MUST BE EXACTLY THIS JSON FORMAT:**
624
- {{
625
- "action_type": "ask" | "test" | "diagnose",
626
- "content": "specific question(s), test name(s), or final diagnosis",
627
- "reasoning": "clear justification synthesizing panel input and citing decision framework step"
628
- }}
629
-
630
  For action_type "ask": content should be specific patient history or physical exam questions
631
  For action_type "test": content should be properly named diagnostic tests (up to 3)
632
  For action_type "diagnose": content should be the complete, specific final diagnosis
633
 
634
- Make the decision that best advances accurate, cost-effective diagnosis. Use your full token allocation for comprehensive reasoning in the reasoning field.
635
  """,
636
 
637
  AgentRole.GATEKEEPER: f"""
@@ -1350,7 +1493,7 @@ NO SYSTEM MESSAGES, NO WRAPPER FORMAT. JUST THE JSON.
1350
  # Initialize structured deliberation state instead of conversational chaining
1351
  deliberation_state = DeliberationState()
1352
 
1353
- # Prepare comprehensive but concise case context for each agent
1354
  remaining_budget = self.initial_budget - case_state.cumulative_cost
1355
  budget_status = (
1356
  "EXCEEDED"
@@ -1358,7 +1501,7 @@ NO SYSTEM MESSAGES, NO WRAPPER FORMAT. JUST THE JSON.
1358
  else f"${remaining_budget:,}"
1359
  )
1360
 
1361
- # Base context for all agents (token-efficient)
1362
  base_context = f"""
1363
  === DIAGNOSTIC CASE STATUS - ROUND {case_state.iteration} ===
1364
 
@@ -1406,34 +1549,50 @@ CURRENT STATE:
1406
  # Dr. Hypothesis - Differential diagnosis and probability assessment
1407
  logger.info("🧠 Dr. Hypothesis analyzing differential diagnosis...")
1408
  hypothesis_prompt = self._get_prompt_for_role(AgentRole.HYPOTHESIS, case_state) + "\n\n" + base_context
1409
- deliberation_state.hypothesis_analysis = self.agents[AgentRole.HYPOTHESIS].run(hypothesis_prompt)
 
 
 
 
 
1410
 
1411
- # Update case state with new differential
1412
- self._update_differential_from_hypothesis(case_state, deliberation_state.hypothesis_analysis)
 
 
 
1413
 
1414
  # Dr. Test-Chooser - Information value optimization
1415
  logger.info("πŸ”¬ Dr. Test-Chooser selecting optimal tests...")
1416
  test_chooser_prompt = self._get_prompt_for_role(AgentRole.TEST_CHOOSER, case_state) + "\n\n" + base_context
1417
  if self.mode == "question_only":
1418
  test_chooser_prompt += "\n\nIMPORTANT: This is QUESTION-ONLY mode. You may ONLY recommend patient questions, not diagnostic tests."
1419
- deliberation_state.test_chooser_analysis = self.agents[AgentRole.TEST_CHOOSER].run(test_chooser_prompt)
 
 
1420
 
1421
  # Dr. Challenger - Bias identification and alternative hypotheses
1422
  logger.info("πŸ€” Dr. Challenger challenging assumptions...")
1423
  challenger_prompt = self._get_prompt_for_role(AgentRole.CHALLENGER, case_state) + "\n\n" + base_context
1424
- deliberation_state.challenger_analysis = self.agents[AgentRole.CHALLENGER].run(challenger_prompt)
 
 
1425
 
1426
  # Dr. Stewardship - Cost-effectiveness analysis
1427
  logger.info("πŸ’° Dr. Stewardship evaluating cost-effectiveness...")
1428
  stewardship_prompt = self._get_prompt_for_role(AgentRole.STEWARDSHIP, case_state) + "\n\n" + base_context
1429
  if self.enable_budget_tracking:
1430
  stewardship_prompt += f"\n\nBUDGET TRACKING ENABLED - Current cost: ${case_state.cumulative_cost}, Remaining: ${remaining_budget}"
1431
- deliberation_state.stewardship_analysis = self.agents[AgentRole.STEWARDSHIP].run(stewardship_prompt)
 
 
1432
 
1433
  # Dr. Checklist - Quality assurance
1434
  logger.info("βœ… Dr. Checklist performing quality control...")
1435
  checklist_prompt = self._get_prompt_for_role(AgentRole.CHECKLIST, case_state) + "\n\n" + base_context
1436
- deliberation_state.checklist_analysis = self.agents[AgentRole.CHECKLIST].run(checklist_prompt)
 
 
1437
 
1438
  # Consensus Coordinator - Final decision synthesis using structured state
1439
  logger.info("🀝 Consensus Coordinator synthesizing panel decision...")
@@ -1445,11 +1604,8 @@ CURRENT STATE:
1445
  if self.mode == "budgeted" and remaining_budget <= 0:
1446
  consensus_prompt += "\n\nBUDGET CONSTRAINT: Budget exceeded - must either ask questions or provide final diagnosis."
1447
 
1448
- # Use improved JSON parsing with retry logic
1449
- action_dict = self._parse_json_with_retry(
1450
- self.agents[AgentRole.CONSENSUS],
1451
- consensus_prompt
1452
- )
1453
 
1454
  # Validate action based on mode constraints
1455
  action = Action(**action_dict)
@@ -1497,19 +1653,52 @@ CURRENT STATE:
1497
 
1498
  return " | ".join(context_parts) if context_parts else ""
1499
 
1500
- def _update_differential_from_hypothesis(self, case_state: CaseState, hypothesis_analysis: str):
1501
- """Extract and update differential diagnosis from Dr. Hypothesis analysis"""
1502
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1503
  # Simple extraction - look for percentage patterns in the text
1504
  import re
1505
 
1506
  # Update the main differential diagnosis for backward compatibility
1507
- self.differential_diagnosis = hypothesis_analysis
1508
 
1509
  # Try to extract structured probabilities
1510
  # Look for patterns like "Diagnosis: 85%" or "Disease (70%)"
1511
  percentage_pattern = r'([A-Za-z][^:(\n]*?)[\s:]*[\(]?(\d{1,3})%[\)]?'
1512
- matches = re.findall(percentage_pattern, hypothesis_analysis)
1513
 
1514
  new_differential = {}
1515
  for match in matches:
@@ -1520,12 +1709,15 @@ CURRENT STATE:
1520
 
1521
  if new_differential:
1522
  case_state.update_differential(new_differential)
1523
- logger.debug(f"Updated differential: {new_differential}")
1524
 
1525
  except Exception as e:
1526
  logger.debug(f"Could not extract structured differential: {e}")
1527
  # Still update the text version for display
1528
- self.differential_diagnosis = hypothesis_analysis
 
 
 
1529
 
1530
  def _validate_and_correct_action(self, action: Action, case_state: CaseState, remaining_budget: int) -> Action:
1531
  """Validate and correct actions based on mode constraints and context"""
@@ -1543,11 +1735,17 @@ CURRENT STATE:
1543
  action.content = case_state.get_leading_diagnosis()
1544
  action.reasoning = "Budget constraint: insufficient funds for additional testing"
1545
 
1546
- # Stagnation handling
1547
  if case_state.is_stagnating(action):
1548
  logger.warning("Stagnation detected, forcing diagnostic decision")
1549
  action.action_type = "diagnose"
1550
- action.content = case_state.get_leading_diagnosis()
 
 
 
 
 
 
1551
  action.reasoning = "Forced diagnosis due to detected stagnation in diagnostic process"
1552
 
1553
  # High confidence threshold
@@ -1583,7 +1781,7 @@ CURRENT STATE:
1583
  {request}
1584
  """
1585
 
1586
- response = gatekeeper.run(prompt)
1587
  return response
1588
 
1589
  def _judge_diagnosis(
@@ -1600,7 +1798,7 @@ CURRENT STATE:
1600
  Score: [number from 1-5]
1601
  Justification: [detailed reasoning for the score]
1602
  """
1603
- response = judge.run(prompt)
1604
 
1605
  # Handle different response types from swarms Agent
1606
  response_text = ""
@@ -2013,7 +2211,10 @@ CURRENT STATE:
2013
  print_on=True, # Enable printing for aggregator agent
2014
  )
2015
 
2016
- return aggregator.run(aggregator_prompt).strip()
 
 
 
2017
 
2018
  except Exception as e:
2019
  logger.error(f"Error in ensemble aggregation: {e}")
@@ -2072,9 +2273,283 @@ CURRENT STATE:
2072
 
2073
  config = variant_configs[variant]
2074
  config.update(kwargs) # Allow overrides
 
 
 
2075
 
2076
  return cls(**config)
2077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2078
 
2079
  def run_mai_dxo_demo(
2080
  case_info: str = None,
@@ -2129,13 +2604,13 @@ def run_mai_dxo_demo(
2129
  orchestrator = MaiDxOrchestrator.create_variant(
2130
  variant,
2131
  budget=3000,
2132
- model_name="gpt-4-1106-preview", # Fixed: Use valid model name
2133
  max_iterations=5,
2134
  )
2135
  else:
2136
  orchestrator = MaiDxOrchestrator.create_variant(
2137
  variant,
2138
- model_name="gpt-4-1106-preview", # Fixed: Use valid model name
2139
  max_iterations=5,
2140
  )
2141
 
@@ -2151,159 +2626,159 @@ def run_mai_dxo_demo(
2151
  return results
2152
 
2153
 
2154
- if __name__ == "__main__":
2155
- # Example case inspired by the paper's Figure 1
2156
- initial_info = (
2157
- "A 29-year-old woman was admitted to the hospital because of sore throat and peritonsillar swelling "
2158
- "and bleeding. Symptoms did not abate with antimicrobial therapy."
2159
- )
2160
-
2161
- full_case = """
2162
- Patient: 29-year-old female.
2163
- History: Onset of sore throat 7 weeks prior to admission. Worsening right-sided pain and swelling.
2164
- No fevers, headaches, or gastrointestinal symptoms. Past medical history is unremarkable. No history of smoking or significant alcohol use.
2165
- Physical Exam: Right peritonsillar mass, displacing the uvula. No other significant findings.
2166
- Initial Labs: FBC, clotting studies normal.
2167
- MRI Neck: Showed a large, enhancing mass in the right peritonsillar space.
2168
- Biopsy (H&E): Infiltrative round-cell neoplasm with high nuclear-to-cytoplasmic ratio and frequent mitotic figures.
2169
- Biopsy (Immunohistochemistry for Carcinoma): CD31, D2-40, CD34, ERG, GLUT-1, pan-cytokeratin, CD45, CD20, CD3 all negative. Ki-67: 60% nuclear positivity.
2170
- Biopsy (Immunohistochemistry for Rhabdomyosarcoma): Desmin and MyoD1 diffusely positive. Myogenin multifocally positive.
2171
- Biopsy (FISH): No FOXO1 (13q14) rearrangements detected.
2172
- Final Diagnosis from Pathology: Embryonal rhabdomyosarcoma of the pharynx.
2173
- """
2174
-
2175
- ground_truth = "Embryonal rhabdomyosarcoma of the pharynx"
2176
-
2177
- # --- Demonstrate Different MAI-DxO Variants ---
2178
- try:
2179
- print("\n" + "=" * 80)
2180
- print(
2181
- " MAI DIAGNOSTIC ORCHESTRATOR (MAI-DxO) - SEQUENTIAL DIAGNOSIS BENCHMARK"
2182
- )
2183
- print(
2184
- " Implementation based on the NEJM Research Paper"
2185
- )
2186
- print("=" * 80)
2187
-
2188
- # Test different variants as described in the paper
2189
- variants_to_test = [
2190
- (
2191
- "no_budget",
2192
- "Standard MAI-DxO with no budget constraints",
2193
- ),
2194
- ("budgeted", "Budget-constrained MAI-DxO ($3000 limit)"),
2195
- (
2196
- "question_only",
2197
- "Question-only variant (no diagnostic tests)",
2198
- ),
2199
- ]
2200
-
2201
- results = {}
2202
-
2203
- for variant_name, description in variants_to_test:
2204
- print(f"\n{'='*60}")
2205
- print(f"Testing Variant: {variant_name.upper()}")
2206
- print(f"Description: {description}")
2207
- print("=" * 60)
2208
-
2209
- # Create the variant
2210
- if variant_name == "budgeted":
2211
- orchestrator = MaiDxOrchestrator.create_variant(
2212
- variant_name,
2213
- budget=3000,
2214
- model_name="gpt-4-1106-preview", # Fixed: Use valid model name
2215
- max_iterations=5,
2216
- )
2217
- else:
2218
- orchestrator = MaiDxOrchestrator.create_variant(
2219
- variant_name,
2220
- model_name="gpt-4-1106-preview", # Fixed: Use valid model name
2221
- max_iterations=5,
2222
- )
2223
-
2224
- # Run the diagnostic process
2225
- result = orchestrator.run(
2226
- initial_case_info=initial_info,
2227
- full_case_details=full_case,
2228
- ground_truth_diagnosis=ground_truth,
2229
- )
2230
-
2231
- results[variant_name] = result
2232
-
2233
- # Display results
2234
- print(f"\nπŸš€ Final Diagnosis: {result.final_diagnosis}")
2235
- print(f"🎯 Ground Truth: {result.ground_truth}")
2236
- print(f"⭐ Accuracy Score: {result.accuracy_score}/5.0")
2237
- print(f" Reasoning: {result.accuracy_reasoning}")
2238
- print(f"πŸ’° Total Cost: ${result.total_cost:,}")
2239
- print(f"πŸ”„ Iterations: {result.iterations}")
2240
- print(f"⏱️ Mode: {orchestrator.mode}")
2241
-
2242
- # Demonstrate ensemble approach
2243
- print(f"\n{'='*60}")
2244
- print("Testing Variant: ENSEMBLE")
2245
- print(
2246
- "Description: Multiple independent runs with consensus aggregation"
2247
- )
2248
- print("=" * 60)
2249
-
2250
- ensemble_orchestrator = MaiDxOrchestrator.create_variant(
2251
- "ensemble",
2252
- model_name="gpt-4-1106-preview", # Fixed: Use valid model name
2253
- max_iterations=3, # Shorter iterations for ensemble
2254
- )
2255
-
2256
- ensemble_result = ensemble_orchestrator.run_ensemble(
2257
- initial_case_info=initial_info,
2258
- full_case_details=full_case,
2259
- ground_truth_diagnosis=ground_truth,
2260
- num_runs=2, # Reduced for demo
2261
- )
2262
-
2263
- results["ensemble"] = ensemble_result
2264
-
2265
- print(
2266
- f"\nπŸš€ Ensemble Diagnosis: {ensemble_result.final_diagnosis}"
2267
- )
2268
- print(f"🎯 Ground Truth: {ensemble_result.ground_truth}")
2269
- print(
2270
- f"⭐ Ensemble Score: {ensemble_result.accuracy_score}/5.0"
2271
- )
2272
- print(
2273
- f"πŸ’° Total Ensemble Cost: ${ensemble_result.total_cost:,}"
2274
- )
2275
-
2276
- # --- Summary Comparison ---
2277
- print(f"\n{'='*80}")
2278
- print(" RESULTS SUMMARY")
2279
- print("=" * 80)
2280
- print(
2281
- f"{'Variant':<15} {'Diagnosis Match':<15} {'Score':<8} {'Cost':<12} {'Iterations':<12}"
2282
- )
2283
- print("-" * 80)
2284
-
2285
- for variant_name, result in results.items():
2286
- match_status = (
2287
- "βœ“ Match"
2288
- if result.accuracy_score >= 4.0
2289
- else "βœ— No Match"
2290
- )
2291
- print(
2292
- f"{variant_name:<15} {match_status:<15} {result.accuracy_score:<8.1f} ${result.total_cost:<11,} {result.iterations:<12}"
2293
- )
2294
-
2295
- print(f"\n{'='*80}")
2296
- print(
2297
- "Implementation successfully demonstrates the MAI-DxO framework"
2298
- )
2299
- print(
2300
- "as described in 'Sequential Diagnosis with Language Models' paper"
2301
- )
2302
- print("=" * 80)
2303
-
2304
- except Exception as e:
2305
- logger.exception(
2306
- f"An error occurred during the diagnostic session: {e}"
2307
- )
2308
- print(f"\n❌ Error occurred: {e}")
2309
- print("Please check your model configuration and API keys.")
 
182
  retry_count: int = 0
183
 
184
  def to_consensus_prompt(self) -> str:
185
+ """Generate a structured prompt for the consensus coordinator - no truncation, let agent self-regulate"""
186
+
187
  prompt = f"""
188
+ You are the Consensus Coordinator. Here is the panel's analysis:
189
 
190
+ **Differential Diagnosis (Dr. Hypothesis):**
191
+ {self.hypothesis_analysis or 'Not yet formulated'}
192
 
193
+ **Test Recommendations (Dr. Test-Chooser):**
194
+ {self.test_chooser_analysis or 'None provided'}
195
 
196
+ **Critical Challenges (Dr. Challenger):**
197
+ {self.challenger_analysis or 'None identified'}
198
 
199
+ **Cost Assessment (Dr. Stewardship):**
200
+ {self.stewardship_analysis or 'Not evaluated'}
201
 
202
+ **Quality Control (Dr. Checklist):**
203
+ {self.checklist_analysis or 'No issues noted'}
204
  """
205
 
206
  if self.stagnation_detected:
207
+ prompt += "\n**STAGNATION DETECTED** - The panel is repeating actions. You MUST make a decisive choice or provide final diagnosis."
 
 
 
 
208
 
209
  if self.situational_context:
210
+ prompt += f"\n**Situational Context:** {self.situational_context}"
 
 
 
211
 
212
+ prompt += "\n\nBased on this comprehensive panel input, use the make_consensus_decision function to provide your structured action."
213
  return prompt
214
 
215
 
 
255
 
256
  def __init__(
257
  self,
258
+ model_name: str = "gpt-4.1", # Fixed: Use valid GPT-4 Turbo model name
259
  max_iterations: int = 10,
260
  initial_budget: int = 10000,
261
  mode: str = "no_budget", # "instant", "question_only", "budgeted", "no_budget", "ensemble"
262
  physician_visit_cost: int = 300,
263
  enable_budget_tracking: bool = False,
264
+ request_delay: float = 8.0, # seconds to wait between model calls to mitigate rate-limits
265
  ):
266
  """
267
  Initializes the MAI-DxO system with improved architecture.
 
273
  mode (str): The operational mode of MAI-DxO.
274
  physician_visit_cost (int): Cost per physician visit.
275
  enable_budget_tracking (bool): Whether to enable budget tracking.
276
+ request_delay (float): Seconds to wait between model calls to mitigate rate-limits.
277
  """
278
  self.model_name = model_name
279
  self.max_iterations = max_iterations
 
282
  self.physician_visit_cost = physician_visit_cost
283
  self.enable_budget_tracking = enable_budget_tracking
284
 
285
+ # Throttle settings to avoid OpenAI TPM rate-limits
286
+ self.request_delay = max(request_delay, 0)
287
+
288
+ # Token management
289
+ self.max_total_tokens_per_request = 25000 # Safety margin below 30k limit
290
+
291
  self.cumulative_cost = 0
292
  self.differential_diagnosis = "Not yet formulated."
293
  self.conversation = Conversation(
 
334
  )
335
 
336
  def _get_agent_max_tokens(self, role: AgentRole) -> int:
337
+ """Get max_tokens for each agent based on their role - agents will self-regulate based on token guidance"""
338
  token_limits = {
339
+ # Reasonable limits - agents will adjust their verbosity based on token guidance
340
+ AgentRole.HYPOTHESIS: 1200, # Function calling keeps this structured, but allow room for quality
341
+ AgentRole.TEST_CHOOSER: 800, # Need space for test rationale
342
+ AgentRole.CHALLENGER: 800, # Need space for critical analysis
343
+ AgentRole.STEWARDSHIP: 600,
344
+ AgentRole.CHECKLIST: 400,
345
+ AgentRole.CONSENSUS: 500, # Function calling is efficient
346
+ AgentRole.GATEKEEPER: 1000, # Needs to provide detailed clinical findings
347
+ AgentRole.JUDGE: 700,
348
  }
349
+ return token_limits.get(role, 600)
350
+
351
+ def _estimate_tokens(self, text: str) -> int:
352
+ """Rough token estimation (1 token β‰ˆ 4 characters for English)"""
353
+ return len(text) // 4
354
+
355
+ def _generate_token_guidance(self, input_tokens: int, max_output_tokens: int, total_tokens: int, agent_role: AgentRole) -> str:
356
+ """Generate dynamic token guidance for agents to self-regulate their responses"""
357
+
358
+ # Determine urgency level based on token usage
359
+ if total_tokens > self.max_total_tokens_per_request:
360
+ urgency = "CRITICAL"
361
+ strategy = "Be extremely concise. Prioritize only the most essential information."
362
+ elif total_tokens > self.max_total_tokens_per_request * 0.8:
363
+ urgency = "HIGH"
364
+ strategy = "Be concise and focus on key points. Avoid elaborate explanations."
365
+ elif total_tokens > self.max_total_tokens_per_request * 0.6:
366
+ urgency = "MODERATE"
367
+ strategy = "Be reasonably concise while maintaining necessary detail."
368
+ else:
369
+ urgency = "LOW"
370
+ strategy = "You can provide detailed analysis within your allocated tokens."
371
+
372
+ # Role-specific guidance
373
+ role_specific_guidance = {
374
+ AgentRole.HYPOTHESIS: "Focus on top 2-3 diagnoses with probabilities. Prioritize summary over detailed pathophysiology.",
375
+ AgentRole.TEST_CHOOSER: "Recommend 1-2 highest-yield tests. Focus on which hypotheses they'll help differentiate.",
376
+ AgentRole.CHALLENGER: "Identify 1-2 most critical biases or alternative diagnoses. Be direct and specific.",
377
+ AgentRole.STEWARDSHIP: "Focus on cost-effectiveness assessment. Recommend cheaper alternatives where applicable.",
378
+ AgentRole.CHECKLIST: "Provide concise quality check. Flag critical issues only.",
379
+ AgentRole.CONSENSUS: "Function calling enforces structure. Focus on clear reasoning.",
380
+ AgentRole.GATEKEEPER: "Provide specific clinical findings. Be factual and complete but not verbose.",
381
+ AgentRole.JUDGE: "Provide score and focused justification. Be systematic but concise."
382
+ }.get(agent_role, "Be concise and focused.")
383
+
384
+ guidance = f"""
385
+ [TOKEN MANAGEMENT - {urgency} PRIORITY]
386
+ Input: {input_tokens} tokens | Your Output Limit: {max_output_tokens} tokens | Total: {total_tokens} tokens
387
+ Strategy: {strategy}
388
+ Role Focus: {role_specific_guidance}
389
+
390
+ IMPORTANT: Adjust your response length and detail level based on this guidance. Prioritize the most critical information for your role.
391
+ """
392
+
393
+ return guidance
394
 
395
  def _init_agents(self) -> None:
396
  """Initializes all required agents with their specific roles and prompts."""
397
+
398
+ # Define the structured output tool for consensus decisions
399
+ consensus_tool = {
400
+ "type": "function",
401
+ "function": {
402
+ "name": "make_consensus_decision",
403
+ "description": "Make a structured consensus decision for the next diagnostic action",
404
+ "parameters": {
405
+ "type": "object",
406
+ "properties": {
407
+ "action_type": {
408
+ "type": "string",
409
+ "enum": ["ask", "test", "diagnose"],
410
+ "description": "The type of action to perform"
411
+ },
412
+ "content": {
413
+ "type": "string",
414
+ "description": "The specific content of the action (question, test name, or diagnosis)"
415
+ },
416
+ "reasoning": {
417
+ "type": "string",
418
+ "description": "The detailed reasoning behind this decision, synthesizing panel input"
419
+ }
420
+ },
421
+ "required": ["action_type", "content", "reasoning"]
422
+ }
423
+ }
424
+ }
425
+
426
+ # Define structured output tool for differential diagnosis
427
+ hypothesis_tool = {
428
+ "type": "function",
429
+ "function": {
430
+ "name": "update_differential_diagnosis",
431
+ "description": "Update the differential diagnosis with structured probabilities and reasoning",
432
+ "parameters": {
433
+ "type": "object",
434
+ "properties": {
435
+ "summary": {
436
+ "type": "string",
437
+ "description": "One-sentence summary of primary diagnostic conclusion and confidence"
438
+ },
439
+ "differential_diagnoses": {
440
+ "type": "array",
441
+ "items": {
442
+ "type": "object",
443
+ "properties": {
444
+ "diagnosis": {"type": "string", "description": "The diagnosis name"},
445
+ "probability": {"type": "number", "minimum": 0, "maximum": 1, "description": "Probability as decimal (0.0-1.0)"},
446
+ "rationale": {"type": "string", "description": "Brief rationale for this diagnosis"}
447
+ },
448
+ "required": ["diagnosis", "probability", "rationale"]
449
+ },
450
+ "minItems": 2,
451
+ "maxItems": 5,
452
+ "description": "Top 2-5 differential diagnoses with probabilities"
453
+ },
454
+ "key_evidence": {
455
+ "type": "string",
456
+ "description": "Key supporting evidence for leading hypotheses"
457
+ },
458
+ "contradictory_evidence": {
459
+ "type": "string",
460
+ "description": "Critical contradictory evidence that must be addressed"
461
+ }
462
+ },
463
+ "required": ["summary", "differential_diagnoses", "key_evidence"]
464
+ }
465
+ }
466
  }
467
+
468
+ self.agents = {}
469
+ for role in AgentRole:
470
+ if role == AgentRole.CONSENSUS:
471
+ # Use function calling for consensus agent to ensure structured output
472
+ self.agents[role] = Agent(
473
+ agent_name=role.value,
474
+ system_prompt=self._get_prompt_for_role(role),
475
+ model_name=self.model_name,
476
+ max_loops=1,
477
+ tools_list_dictionary=[consensus_tool], # swarms expects tools_list_dictionary
478
+ tool_choice="auto", # Let the model choose to use the tool
479
+ print_on=True,
480
+ max_tokens=self._get_agent_max_tokens(role),
481
+ )
482
+ elif role == AgentRole.HYPOTHESIS:
483
+ # Use function calling for hypothesis agent to ensure structured differential
484
+ self.agents[role] = Agent(
485
+ agent_name=role.value,
486
+ system_prompt=self._get_prompt_for_role(role),
487
+ model_name=self.model_name,
488
+ max_loops=1,
489
+ tools_list_dictionary=[hypothesis_tool],
490
+ tool_choice="auto",
491
+ print_on=True,
492
+ max_tokens=self._get_agent_max_tokens(role),
493
+ )
494
+ else:
495
+ # Regular agents without function calling
496
+ self.agents[role] = Agent(
497
+ agent_name=role.value,
498
+ system_prompt=self._get_prompt_for_role(role),
499
+ model_name=self.model_name,
500
+ max_loops=1,
501
+ output_type="str",
502
+ print_on=True,
503
+ max_tokens=self._get_agent_max_tokens(role),
504
+ )
505
+
506
  logger.info(
507
  f"πŸ‘₯ {len(self.agents)} virtual physician agents initialized and ready for consultation"
508
  )
 
566
  3. Always explain your Bayesian reasoning clearly
567
  4. Consider epidemiology, pathophysiology, and clinical patterns
568
 
569
+ **IMPORTANT: You MUST use the update_differential_diagnosis function to provide your structured analysis.**
570
+
571
+ Use the function to provide:
572
+ - A one-sentence summary of your primary diagnostic conclusion and confidence level
573
+ - Your top 2-5 differential diagnoses with probability estimates (as decimals: 0.0-1.0)
574
+ - Brief rationale for each diagnosis
575
+ - Key supporting evidence for leading hypotheses
576
+ - Critical contradictory evidence that must be addressed
577
+
578
+ Remember: Your differential drives the entire diagnostic process. Provide clear probabilities and reasoning.
579
  """,
580
 
581
  AgentRole.TEST_CHOOSER: f"""
 
604
  - Avoid redundant tests that won't add new information
605
  - Consider pre-test probability and post-test probability calculations
606
 
607
+ OUTPUT FORMAT (You have a response limit of {self._get_agent_max_tokens(AgentRole.TEST_CHOOSER)} tokens - prioritize actionable recommendations):
608
+
609
+ **SUMMARY FIRST:** Lead with your single most recommended test and why it's the highest priority.
610
+
611
+ **DETAILED RECOMMENDATIONS (up to 3 tests):**
612
+ For each test:
613
  - Test name (be specific and accurate)
614
  - Primary hypotheses it will help evaluate
615
+ - Expected information gain
616
+ - How results will change management
617
+ - Cost-effectiveness assessment
618
+ - Timing rationale
 
619
 
620
+ Focus on tests that will most efficiently narrow the differential diagnosis.
621
  """,
622
 
623
  AgentRole.CHALLENGER: f"""
 
648
  - Advocate for considering multiple conditions simultaneously
649
  - Look for inconsistencies in the clinical presentation
650
 
651
+ OUTPUT FORMAT (You have a response limit of {self._get_agent_max_tokens(AgentRole.CHALLENGER)} tokens - focus on the most critical challenges):
652
+
653
+ **SUMMARY FIRST:** State your primary concern with the current diagnostic approach in one sentence.
654
+
655
+ **CRITICAL CHALLENGES:**
656
+ - Most significant bias identified in current reasoning
657
+ - Key evidence that contradicts leading hypotheses
658
+ - Most important alternative diagnosis to consider
659
+ - Essential test to falsify current assumptions
660
+ - Highest priority red flag or safety concern
661
+ - Most critical gap in current approach
662
+
663
+ Be constructively critical - focus on the challenges that most impact diagnostic accuracy.
664
  """,
665
 
666
  AgentRole.STEWARDSHIP: f"""
 
768
  4. **Cost Optimization:** Before finalizing a test, check Dr. Stewardship's input. If a diagnostically equivalent but cheaper alternative is available, select it.
769
  5. **Default to Questions:** If no test meets the criteria or the budget is a major concern, select the most pertinent question to ask.
770
 
771
+ **IMPORTANT: You MUST use the make_consensus_decision function to provide your structured response. Call this function with the appropriate action_type, content, and reasoning parameters.**
772
+
 
 
 
 
 
773
  For action_type "ask": content should be specific patient history or physical exam questions
774
  For action_type "test": content should be properly named diagnostic tests (up to 3)
775
  For action_type "diagnose": content should be the complete, specific final diagnosis
776
 
777
+ Make the decision that best advances accurate, cost-effective diagnosis. Use comprehensive reasoning that synthesizes all panel input and cites the specific decision framework step you're following.
778
  """,
779
 
780
  AgentRole.GATEKEEPER: f"""
 
1493
  # Initialize structured deliberation state instead of conversational chaining
1494
  deliberation_state = DeliberationState()
1495
 
1496
+ # Prepare concise case context for each agent (token-optimized)
1497
  remaining_budget = self.initial_budget - case_state.cumulative_cost
1498
  budget_status = (
1499
  "EXCEEDED"
 
1501
  else f"${remaining_budget:,}"
1502
  )
1503
 
1504
+ # Full context - let agents self-regulate based on token guidance
1505
  base_context = f"""
1506
  === DIAGNOSTIC CASE STATUS - ROUND {case_state.iteration} ===
1507
 
 
1549
  # Dr. Hypothesis - Differential diagnosis and probability assessment
1550
  logger.info("🧠 Dr. Hypothesis analyzing differential diagnosis...")
1551
  hypothesis_prompt = self._get_prompt_for_role(AgentRole.HYPOTHESIS, case_state) + "\n\n" + base_context
1552
+ hypothesis_response = self._safe_agent_run(
1553
+ self.agents[AgentRole.HYPOTHESIS], hypothesis_prompt, agent_role=AgentRole.HYPOTHESIS
1554
+ )
1555
+
1556
+ # Update case state with new differential (supports both function calls and text)
1557
+ self._update_differential_from_hypothesis(case_state, hypothesis_response)
1558
 
1559
+ # Store the analysis for deliberation state (convert to text format for other agents)
1560
+ if hasattr(hypothesis_response, 'content'):
1561
+ deliberation_state.hypothesis_analysis = hypothesis_response.content
1562
+ else:
1563
+ deliberation_state.hypothesis_analysis = str(hypothesis_response)
1564
 
1565
  # Dr. Test-Chooser - Information value optimization
1566
  logger.info("πŸ”¬ Dr. Test-Chooser selecting optimal tests...")
1567
  test_chooser_prompt = self._get_prompt_for_role(AgentRole.TEST_CHOOSER, case_state) + "\n\n" + base_context
1568
  if self.mode == "question_only":
1569
  test_chooser_prompt += "\n\nIMPORTANT: This is QUESTION-ONLY mode. You may ONLY recommend patient questions, not diagnostic tests."
1570
+ deliberation_state.test_chooser_analysis = self._safe_agent_run(
1571
+ self.agents[AgentRole.TEST_CHOOSER], test_chooser_prompt, agent_role=AgentRole.TEST_CHOOSER
1572
+ )
1573
 
1574
  # Dr. Challenger - Bias identification and alternative hypotheses
1575
  logger.info("πŸ€” Dr. Challenger challenging assumptions...")
1576
  challenger_prompt = self._get_prompt_for_role(AgentRole.CHALLENGER, case_state) + "\n\n" + base_context
1577
+ deliberation_state.challenger_analysis = self._safe_agent_run(
1578
+ self.agents[AgentRole.CHALLENGER], challenger_prompt, agent_role=AgentRole.CHALLENGER
1579
+ )
1580
 
1581
  # Dr. Stewardship - Cost-effectiveness analysis
1582
  logger.info("πŸ’° Dr. Stewardship evaluating cost-effectiveness...")
1583
  stewardship_prompt = self._get_prompt_for_role(AgentRole.STEWARDSHIP, case_state) + "\n\n" + base_context
1584
  if self.enable_budget_tracking:
1585
  stewardship_prompt += f"\n\nBUDGET TRACKING ENABLED - Current cost: ${case_state.cumulative_cost}, Remaining: ${remaining_budget}"
1586
+ deliberation_state.stewardship_analysis = self._safe_agent_run(
1587
+ self.agents[AgentRole.STEWARDSHIP], stewardship_prompt, agent_role=AgentRole.STEWARDSHIP
1588
+ )
1589
 
1590
  # Dr. Checklist - Quality assurance
1591
  logger.info("βœ… Dr. Checklist performing quality control...")
1592
  checklist_prompt = self._get_prompt_for_role(AgentRole.CHECKLIST, case_state) + "\n\n" + base_context
1593
+ deliberation_state.checklist_analysis = self._safe_agent_run(
1594
+ self.agents[AgentRole.CHECKLIST], checklist_prompt, agent_role=AgentRole.CHECKLIST
1595
+ )
1596
 
1597
  # Consensus Coordinator - Final decision synthesis using structured state
1598
  logger.info("🀝 Consensus Coordinator synthesizing panel decision...")
 
1604
  if self.mode == "budgeted" and remaining_budget <= 0:
1605
  consensus_prompt += "\n\nBUDGET CONSTRAINT: Budget exceeded - must either ask questions or provide final diagnosis."
1606
 
1607
+ # Use function calling with retry logic for robust structured output
1608
+ action_dict = self._get_consensus_with_retry(consensus_prompt)
 
 
 
1609
 
1610
  # Validate action based on mode constraints
1611
  action = Action(**action_dict)
 
1653
 
1654
  return " | ".join(context_parts) if context_parts else ""
1655
 
1656
+ def _update_differential_from_hypothesis(self, case_state: CaseState, hypothesis_response):
1657
+ """Extract and update differential diagnosis from Dr. Hypothesis analysis - now supports both function calls and text"""
1658
  try:
1659
+ # Try to extract structured data from function call first
1660
+ if hasattr(hypothesis_response, '__dict__') or isinstance(hypothesis_response, dict):
1661
+ structured_data = self._extract_function_call_output(hypothesis_response)
1662
+
1663
+ # Check if we got structured differential data
1664
+ if "differential_diagnoses" in structured_data:
1665
+ # Update case state with structured data
1666
+ new_differential = {}
1667
+ for dx in structured_data["differential_diagnoses"]:
1668
+ new_differential[dx["diagnosis"]] = dx["probability"]
1669
+
1670
+ case_state.update_differential(new_differential)
1671
+
1672
+ # Update the main differential for backward compatibility
1673
+ summary = structured_data.get("summary", "Differential diagnosis updated")
1674
+ dx_text = f"{summary}\n\nTop Diagnoses:\n"
1675
+ for dx in structured_data["differential_diagnoses"]:
1676
+ dx_text += f"- {dx['diagnosis']}: {dx['probability']:.0%} - {dx['rationale']}\n"
1677
+
1678
+ if "key_evidence" in structured_data:
1679
+ dx_text += f"\nKey Evidence: {structured_data['key_evidence']}"
1680
+ if "contradictory_evidence" in structured_data:
1681
+ dx_text += f"\nContradictory Evidence: {structured_data['contradictory_evidence']}"
1682
+
1683
+ self.differential_diagnosis = dx_text
1684
+ logger.debug(f"Updated differential from function call: {new_differential}")
1685
+ return
1686
+
1687
+ # Fallback to text-based extraction
1688
+ hypothesis_text = str(hypothesis_response)
1689
+ if hasattr(hypothesis_response, 'content'):
1690
+ hypothesis_text = hypothesis_response.content
1691
+
1692
  # Simple extraction - look for percentage patterns in the text
1693
  import re
1694
 
1695
  # Update the main differential diagnosis for backward compatibility
1696
+ self.differential_diagnosis = hypothesis_text
1697
 
1698
  # Try to extract structured probabilities
1699
  # Look for patterns like "Diagnosis: 85%" or "Disease (70%)"
1700
  percentage_pattern = r'([A-Za-z][^:(\n]*?)[\s:]*[\(]?(\d{1,3})%[\)]?'
1701
+ matches = re.findall(percentage_pattern, hypothesis_text)
1702
 
1703
  new_differential = {}
1704
  for match in matches:
 
1709
 
1710
  if new_differential:
1711
  case_state.update_differential(new_differential)
1712
+ logger.debug(f"Updated differential from text parsing: {new_differential}")
1713
 
1714
  except Exception as e:
1715
  logger.debug(f"Could not extract structured differential: {e}")
1716
  # Still update the text version for display
1717
+ hypothesis_text = str(hypothesis_response)
1718
+ if hasattr(hypothesis_response, 'content'):
1719
+ hypothesis_text = hypothesis_response.content
1720
+ self.differential_diagnosis = hypothesis_text
1721
 
1722
  def _validate_and_correct_action(self, action: Action, case_state: CaseState, remaining_budget: int) -> Action:
1723
  """Validate and correct actions based on mode constraints and context"""
 
1735
  action.content = case_state.get_leading_diagnosis()
1736
  action.reasoning = "Budget constraint: insufficient funds for additional testing"
1737
 
1738
+ # Stagnation handling - ensure we have a valid diagnosis
1739
  if case_state.is_stagnating(action):
1740
  logger.warning("Stagnation detected, forcing diagnostic decision")
1741
  action.action_type = "diagnose"
1742
+ leading_diagnosis = case_state.get_leading_diagnosis()
1743
+ # Ensure the diagnosis is meaningful, not corrupted
1744
+ if leading_diagnosis == "No diagnosis formulated" or len(leading_diagnosis) < 10 or any(char in leading_diagnosis for char in ['x10^9', '–', '40–']):
1745
+ # Use a fallback diagnosis based on the case context
1746
+ action.content = "Unable to establish definitive diagnosis - further evaluation needed"
1747
+ else:
1748
+ action.content = leading_diagnosis
1749
  action.reasoning = "Forced diagnosis due to detected stagnation in diagnostic process"
1750
 
1751
  # High confidence threshold
 
1781
  {request}
1782
  """
1783
 
1784
+ response = self._safe_agent_run(gatekeeper, prompt, agent_role=AgentRole.GATEKEEPER)
1785
  return response
1786
 
1787
  def _judge_diagnosis(
 
1798
  Score: [number from 1-5]
1799
  Justification: [detailed reasoning for the score]
1800
  """
1801
+ response = self._safe_agent_run(judge, prompt, agent_role=AgentRole.JUDGE)
1802
 
1803
  # Handle different response types from swarms Agent
1804
  response_text = ""
 
2211
  print_on=True, # Enable printing for aggregator agent
2212
  )
2213
 
2214
+ agg_resp = self._safe_agent_run(aggregator, aggregator_prompt)
2215
+ if hasattr(agg_resp, "content"):
2216
+ return agg_resp.content.strip()
2217
+ return str(agg_resp).strip()
2218
 
2219
  except Exception as e:
2220
  logger.error(f"Error in ensemble aggregation: {e}")
 
2273
 
2274
  config = variant_configs[variant]
2275
  config.update(kwargs) # Allow overrides
2276
+
2277
+ # Remove 'budget' parameter if present, as it's mapped to 'initial_budget'
2278
+ config.pop('budget', None)
2279
 
2280
  return cls(**config)
2281
 
2282
+ # ------------------------------------------------------------------
2283
+ # Helper utilities – throttling & robust JSON parsing
2284
+ # ------------------------------------------------------------------
2285
+
2286
def _safe_agent_run(
    self,
    agent: "Agent",  # type: ignore – forward reference
    prompt: str,
    retries: int = 3,
    agent_role: "AgentRole | None" = None,
) -> Any:
    """Safely call ``agent.run`` while respecting OpenAI rate limits.

    Features:
    1. Estimates token usage and prepends guidance so the agent can
       self-regulate its response length instead of being truncated.
    2. Applies a progressive (exponential, base-3) delay schedule to
       respect rate limits: base, base*3, base*9, base*27 seconds.
    3. Retries only on rate-limit errors; any other exception propagates
       immediately to the caller.

    Args:
        agent: The swarms ``Agent`` to invoke.
        prompt: The prompt to send (token guidance is prepended).
        retries: Number of additional attempts after the first one.
        agent_role: Role used for token budgeting; defaults to CONSENSUS.

    Raises:
        RuntimeError: When every attempt was blocked by a rate limit.
    """
    # Default role fallback for token calculations.
    if agent_role is None:
        agent_role = AgentRole.CONSENSUS

    # Estimate total tokens in the request so guidance can be tailored.
    estimated_input_tokens = self._estimate_tokens(prompt)
    max_output_tokens = self._get_agent_max_tokens(agent_role)
    total_estimated_tokens = estimated_input_tokens + max_output_tokens

    # Add dynamic token guidance to the prompt instead of truncating it.
    token_guidance = self._generate_token_guidance(
        estimated_input_tokens, max_output_tokens, total_estimated_tokens, agent_role
    )
    enhanced_prompt = f"{token_guidance}\n\n{prompt}"

    logger.debug(
        f"Agent {agent_role.value}: Input={estimated_input_tokens}, "
        f"Output={max_output_tokens}, Total={total_estimated_tokens}"
    )

    # Increased base delay for better rate-limit compliance
    # (minimum 5 seconds between requests).
    base_delay = max(self.request_delay, 5.0)

    for attempt in range(retries + 1):
        # Progressive delay: base on the first attempt, then base*3^attempt.
        current_delay = base_delay * (3 ** attempt) if attempt > 0 else base_delay

        logger.info(
            f"Request attempt {attempt + 1}/{retries + 1}, waiting {current_delay:.1f}s..."
        )
        time.sleep(current_delay)

        try:
            return agent.run(enhanced_prompt)
        except Exception as e:
            err_msg = str(e).lower()
            is_rate_limit = (
                "rate_limit" in err_msg
                or "ratelimiterror" in err_msg
                or "429" in str(e)
            )
            if not is_rate_limit:
                # For non-rate-limit errors, propagate immediately.
                raise
            if attempt < retries:
                logger.warning(
                    f"Rate-limit encountered (attempt {attempt + 1}/{retries + 1}). "
                    f"Will retry after {base_delay * (3 ** (attempt + 1)):.1f}s..."
                )
            else:
                # Fixed: previously logged "will retry" even when no
                # retries remained.
                logger.warning(
                    f"Rate-limit encountered on final attempt {attempt + 1}/{retries + 1}; giving up."
                )

    # All retries exhausted
    raise RuntimeError("Maximum retries exceeded for agent.run – aborting call")
2345
+
2346
+ def _robust_parse_action(self, raw_response: str) -> Dict[str, Any]:
2347
+ """Extract a JSON *action* object from `raw_response`.
2348
+
2349
+ The function tries multiple strategies and finally returns a default
2350
+ *ask* action if no valid JSON can be located.
2351
+ """
2352
+
2353
+ import json, re
2354
+
2355
+ # Strip common markdown fences
2356
+ if raw_response.strip().startswith("```"):
2357
+ segments = raw_response.split("```")
2358
+ for seg in segments:
2359
+ seg = seg.strip()
2360
+ if seg.startswith("{") and seg.endswith("}"):
2361
+ raw_response = seg
2362
+ break
2363
+
2364
+ # 1) Fast path – direct JSON decode
2365
+ try:
2366
+ data = json.loads(raw_response)
2367
+ if isinstance(data, dict) and "action_type" in data:
2368
+ return data
2369
+ except Exception:
2370
+ pass
2371
+
2372
+ # 2) Regex search for the first balanced curly block
2373
+ match = re.search(r"\{[\s\S]*?\}", raw_response)
2374
+ if match:
2375
+ candidate = match.group(0)
2376
+ # Remove leading drawing characters (e.g., table borders)
2377
+ candidate = "\n".join(line.lstrip("β”‚| ").rstrip("β”‚| ") for line in candidate.splitlines())
2378
+ try:
2379
+ data = json.loads(candidate)
2380
+ if isinstance(data, dict) and "action_type" in data:
2381
+ return data
2382
+ except Exception:
2383
+ pass
2384
+
2385
+ logger.error("Failed to parse a valid action JSON. Falling back to default ask action")
2386
+ return {
2387
+ "action_type": "ask",
2388
+ "content": "Could you please clarify the next best step? The previous analysis was inconclusive.",
2389
+ "reasoning": "Fallback generated due to JSON parsing failure.",
2390
+ }
2391
+
2392
+ def _extract_function_call_output(self, agent_response) -> Dict[str, Any]:
2393
+ """Extract structured output from agent function call response.
2394
+
2395
+ This method handles the swarms Agent response format when using function calling.
2396
+ The response should contain tool calls with the structured data.
2397
+ """
2398
+ try:
2399
+ # Handle different response formats from swarms Agent
2400
+ if isinstance(agent_response, dict):
2401
+ # Check for tool calls in the response
2402
+ if "tool_calls" in agent_response and agent_response["tool_calls"]:
2403
+ tool_call = agent_response["tool_calls"][0] # Get first tool call
2404
+ if "function" in tool_call and "arguments" in tool_call["function"]:
2405
+ arguments = tool_call["function"]["arguments"]
2406
+ if isinstance(arguments, str):
2407
+ # Parse JSON string arguments
2408
+ import json
2409
+ arguments = json.loads(arguments)
2410
+ return arguments
2411
+
2412
+ # Check for direct arguments in response
2413
+ if "arguments" in agent_response:
2414
+ arguments = agent_response["arguments"]
2415
+ if isinstance(arguments, str):
2416
+ import json
2417
+ arguments = json.loads(arguments)
2418
+ return arguments
2419
+
2420
+ # Check if response itself has the expected structure
2421
+ if all(key in agent_response for key in ["action_type", "content", "reasoning"]):
2422
+ return {
2423
+ "action_type": agent_response["action_type"],
2424
+ "content": agent_response["content"],
2425
+ "reasoning": agent_response["reasoning"]
2426
+ }
2427
+
2428
+ # Handle Agent object response
2429
+ elif hasattr(agent_response, "__dict__"):
2430
+ # Check for tool_calls attribute
2431
+ if hasattr(agent_response, "tool_calls") and agent_response.tool_calls:
2432
+ tool_call = agent_response.tool_calls[0]
2433
+ if hasattr(tool_call, "function") and hasattr(tool_call.function, "arguments"):
2434
+ arguments = tool_call.function.arguments
2435
+ if isinstance(arguments, str):
2436
+ import json
2437
+ arguments = json.loads(arguments)
2438
+ return arguments
2439
+
2440
+ # Check for direct function call response
2441
+ if hasattr(agent_response, "function_call"):
2442
+ function_call = agent_response.function_call
2443
+ if hasattr(function_call, "arguments"):
2444
+ arguments = function_call.arguments
2445
+ if isinstance(arguments, str):
2446
+ import json
2447
+ arguments = json.loads(arguments)
2448
+ return arguments
2449
+
2450
+ # Try to extract from response content
2451
+ if hasattr(agent_response, "content"):
2452
+ content = agent_response.content
2453
+ if isinstance(content, dict) and all(key in content for key in ["action_type", "content", "reasoning"]):
2454
+ return content
2455
+
2456
+ # Handle string response (fallback to regex parsing)
2457
+ elif isinstance(agent_response, str):
2458
+ # Try to parse as JSON first
2459
+ try:
2460
+ import json
2461
+ parsed = json.loads(agent_response)
2462
+ if isinstance(parsed, dict) and all(key in parsed for key in ["action_type", "content", "reasoning"]):
2463
+ return parsed
2464
+ except:
2465
+ pass
2466
+
2467
+ # Fallback to regex extraction
2468
+ import re
2469
+ action_type_match = re.search(r'"action_type":\s*"(ask|test|diagnose)"', agent_response, re.IGNORECASE)
2470
+ content_match = re.search(r'"content":\s*"([^"]+)"', agent_response, re.IGNORECASE | re.DOTALL)
2471
+ reasoning_match = re.search(r'"reasoning":\s*"([^"]+)"', agent_response, re.IGNORECASE | re.DOTALL)
2472
+
2473
+ if action_type_match and content_match and reasoning_match:
2474
+ return {
2475
+ "action_type": action_type_match.group(1).lower(),
2476
+ "content": content_match.group(1).strip(),
2477
+ "reasoning": reasoning_match.group(1).strip()
2478
+ }
2479
+
2480
+ logger.warning(f"Could not extract function call output from response type: {type(agent_response)}")
2481
+ logger.debug(f"Response content: {str(agent_response)[:500]}...")
2482
+
2483
+ except Exception as e:
2484
+ logger.error(f"Error extracting function call output: {e}")
2485
+ logger.debug(f"Response: {str(agent_response)[:500]}...")
2486
+
2487
+ # Final fallback
2488
+ return {
2489
+ "action_type": "ask",
2490
+ "content": "Could you please provide more information to help guide the next diagnostic step?",
2491
+ "reasoning": "Fallback action due to function call parsing error."
2492
+ }
2493
+
2494
def _get_consensus_with_retry(self, consensus_prompt: str, max_retries: int = 2) -> Dict[str, Any]:
    """Get the consensus decision, retrying the function call if needed.

    Each attempt runs the consensus agent and tries to extract a
    structured action from its function call. If every attempt fails, the
    last raw response is parsed as JSON; if even that fails (or no
    response was ever received), a safe default *ask* action is returned.
    """
    # Fixed: `response` was previously unbound in the final fallback when
    # every attempt raised inside `_safe_agent_run` (NameError).
    response = None

    for attempt in range(max_retries + 1):
        try:
            if attempt == 0:
                # First attempt – use the original prompt.
                response = self._safe_agent_run(
                    self.agents[AgentRole.CONSENSUS], consensus_prompt, agent_role=AgentRole.CONSENSUS
                )
            else:
                # Retry with an explicit function-call instruction.
                retry_prompt = f"""
{consensus_prompt}

**CRITICAL: RETRY ATTEMPT {attempt}**
Your previous response failed to use the required `make_consensus_decision` function.
You MUST call the make_consensus_decision function with the appropriate parameters:
- action_type: "ask", "test", or "diagnose"
- content: specific question, test name, or diagnosis
- reasoning: your detailed reasoning

Please try again and ensure you call the function correctly.
"""
                response = self._safe_agent_run(
                    self.agents[AgentRole.CONSENSUS], retry_prompt, agent_role=AgentRole.CONSENSUS
                )

            logger.debug(f"Consensus attempt {attempt + 1}, response type: {type(response)}")

            # Try to extract function call output.
            action_dict = self._extract_function_call_output(response)

            # Check whether we got a real decision, not the parser fallback.
            if not action_dict.get("reasoning", "").startswith("Fallback action due to function call parsing error"):
                logger.debug(f"Consensus function call successful on attempt {attempt + 1}")
                return action_dict

            logger.warning(f"Function call failed on attempt {attempt + 1}, will retry")

        except Exception as e:
            logger.error(f"Error in consensus attempt {attempt + 1}: {e}")

    # Final fallback to JSON parsing if all function call attempts failed.
    logger.warning("All function call attempts failed, falling back to JSON parsing")
    try:
        if response is None:
            # Every attempt raised before producing a response.
            raise RuntimeError("No consensus response was ever received")
        consensus_text = (
            response.content if hasattr(response, "content") else str(response)
        )
        return self._robust_parse_action(consensus_text)
    except Exception as e:
        logger.error(f"Both function calling and JSON parsing failed: {e}")
        return {
            "action_type": "ask",
            "content": "Could you please provide more information to guide the diagnostic process?",
            "reasoning": f"Final fallback after {max_retries + 1} function call attempts and JSON parsing failure."
        }
2552
+
2553
 
2554
  def run_mai_dxo_demo(
2555
  case_info: str = None,
 
2604
  orchestrator = MaiDxOrchestrator.create_variant(
2605
  variant,
2606
  budget=3000,
2607
+ model_name="gpt-4.1", # Fixed: Use valid model name
2608
  max_iterations=5,
2609
  )
2610
  else:
2611
  orchestrator = MaiDxOrchestrator.create_variant(
2612
  variant,
2613
+ model_name="gpt-4.1", # Fixed: Use valid model name
2614
  max_iterations=5,
2615
  )
2616
 
 
2626
  return results
2627
 
2628
 
2629
+ # if __name__ == "__main__":
2630
+ # # Example case inspired by the paper's Figure 1
2631
+ # initial_info = (
2632
+ # "A 29-year-old woman was admitted to the hospital because of sore throat and peritonsillar swelling "
2633
+ # "and bleeding. Symptoms did not abate with antimicrobial therapy."
2634
+ # )
2635
+
2636
+ # full_case = """
2637
+ # Patient: 29-year-old female.
2638
+ # History: Onset of sore throat 7 weeks prior to admission. Worsening right-sided pain and swelling.
2639
+ # No fevers, headaches, or gastrointestinal symptoms. Past medical history is unremarkable. No history of smoking or significant alcohol use.
2640
+ # Physical Exam: Right peritonsillar mass, displacing the uvula. No other significant findings.
2641
+ # Initial Labs: FBC, clotting studies normal.
2642
+ # MRI Neck: Showed a large, enhancing mass in the right peritonsillar space.
2643
+ # Biopsy (H&E): Infiltrative round-cell neoplasm with high nuclear-to-cytoplasmic ratio and frequent mitotic figures.
2644
+ # Biopsy (Immunohistochemistry for Carcinoma): CD31, D2-40, CD34, ERG, GLUT-1, pan-cytokeratin, CD45, CD20, CD3 all negative. Ki-67: 60% nuclear positivity.
2645
+ # Biopsy (Immunohistochemistry for Rhabdomyosarcoma): Desmin and MyoD1 diffusely positive. Myogenin multifocally positive.
2646
+ # Biopsy (FISH): No FOXO1 (13q14) rearrangements detected.
2647
+ # Final Diagnosis from Pathology: Embryonal rhabdomyosarcoma of the pharynx.
2648
+ # """
2649
+
2650
+ # ground_truth = "Embryonal rhabdomyosarcoma of the pharynx"
2651
+
2652
+ # # --- Demonstrate Different MAI-DxO Variants ---
2653
+ # try:
2654
+ # print("\n" + "=" * 80)
2655
+ # print(
2656
+ # " MAI DIAGNOSTIC ORCHESTRATOR (MAI-DxO) - SEQUENTIAL DIAGNOSIS BENCHMARK"
2657
+ # )
2658
+ # print(
2659
+ # " Implementation based on the NEJM Research Paper"
2660
+ # )
2661
+ # print("=" * 80)
2662
+
2663
+ # # Test different variants as described in the paper
2664
+ # variants_to_test = [
2665
+ # (
2666
+ # "no_budget",
2667
+ # "Standard MAI-DxO with no budget constraints",
2668
+ # ),
2669
+ # ("budgeted", "Budget-constrained MAI-DxO ($3000 limit)"),
2670
+ # (
2671
+ # "question_only",
2672
+ # "Question-only variant (no diagnostic tests)",
2673
+ # ),
2674
+ # ]
2675
+
2676
+ # results = {}
2677
+
2678
+ # for variant_name, description in variants_to_test:
2679
+ # print(f"\n{'='*60}")
2680
+ # print(f"Testing Variant: {variant_name.upper()}")
2681
+ # print(f"Description: {description}")
2682
+ # print("=" * 60)
2683
+
2684
+ # # Create the variant
2685
+ # if variant_name == "budgeted":
2686
+ # orchestrator = MaiDxOrchestrator.create_variant(
2687
+ # variant_name,
2688
+ # budget=3000,
2689
+ # model_name="gpt-4.1", # Fixed: Use valid model name
2690
+ # max_iterations=5,
2691
+ # )
2692
+ # else:
2693
+ # orchestrator = MaiDxOrchestrator.create_variant(
2694
+ # variant_name,
2695
+ # model_name="gpt-4.1", # Fixed: Use valid model name
2696
+ # max_iterations=5,
2697
+ # )
2698
+
2699
+ # # Run the diagnostic process
2700
+ # result = orchestrator.run(
2701
+ # initial_case_info=initial_info,
2702
+ # full_case_details=full_case,
2703
+ # ground_truth_diagnosis=ground_truth,
2704
+ # )
2705
+
2706
+ # results[variant_name] = result
2707
+
2708
+ # # Display results
2709
+ # print(f"\nπŸš€ Final Diagnosis: {result.final_diagnosis}")
2710
+ # print(f"🎯 Ground Truth: {result.ground_truth}")
2711
+ # print(f"⭐ Accuracy Score: {result.accuracy_score}/5.0")
2712
+ # print(f" Reasoning: {result.accuracy_reasoning}")
2713
+ # print(f"πŸ’° Total Cost: ${result.total_cost:,}")
2714
+ # print(f"πŸ”„ Iterations: {result.iterations}")
2715
+ # print(f"⏱️ Mode: {orchestrator.mode}")
2716
+
2717
+ # # Demonstrate ensemble approach
2718
+ # print(f"\n{'='*60}")
2719
+ # print("Testing Variant: ENSEMBLE")
2720
+ # print(
2721
+ # "Description: Multiple independent runs with consensus aggregation"
2722
+ # )
2723
+ # print("=" * 60)
2724
+
2725
+ # ensemble_orchestrator = MaiDxOrchestrator.create_variant(
2726
+ # "ensemble",
2727
+ # model_name="gpt-4.1", # Fixed: Use valid model name
2728
+ # max_iterations=3, # Shorter iterations for ensemble
2729
+ # )
2730
+
2731
+ # ensemble_result = ensemble_orchestrator.run_ensemble(
2732
+ # initial_case_info=initial_info,
2733
+ # full_case_details=full_case,
2734
+ # ground_truth_diagnosis=ground_truth,
2735
+ # num_runs=2, # Reduced for demo
2736
+ # )
2737
+
2738
+ # results["ensemble"] = ensemble_result
2739
+
2740
+ # print(
2741
+ # f"\nπŸš€ Ensemble Diagnosis: {ensemble_result.final_diagnosis}"
2742
+ # )
2743
+ # print(f"🎯 Ground Truth: {ensemble_result.ground_truth}")
2744
+ # print(
2745
+ # f"⭐ Ensemble Score: {ensemble_result.accuracy_score}/5.0"
2746
+ # )
2747
+ # print(
2748
+ # f"πŸ’° Total Ensemble Cost: ${ensemble_result.total_cost:,}"
2749
+ # )
2750
+
2751
+ # # --- Summary Comparison ---
2752
+ # print(f"\n{'='*80}")
2753
+ # print(" RESULTS SUMMARY")
2754
+ # print("=" * 80)
2755
+ # print(
2756
+ # f"{'Variant':<15} {'Diagnosis Match':<15} {'Score':<8} {'Cost':<12} {'Iterations':<12}"
2757
+ # )
2758
+ # print("-" * 80)
2759
+
2760
+ # for variant_name, result in results.items():
2761
+ # match_status = (
2762
+ # "βœ“ Match"
2763
+ # if result.accuracy_score >= 4.0
2764
+ # else "βœ— No Match"
2765
+ # )
2766
+ # print(
2767
+ # f"{variant_name:<15} {match_status:<15} {result.accuracy_score:<8.1f} ${result.total_cost:<11,} {result.iterations:<12}"
2768
+ # )
2769
+
2770
+ # print(f"\n{'='*80}")
2771
+ # print(
2772
+ # "Implementation successfully demonstrates the MAI-DxO framework"
2773
+ # )
2774
+ # print(
2775
+ # "as described in 'Sequential Diagnosis with Language Models' paper"
2776
+ # )
2777
+ # print("=" * 80)
2778
+
2779
+ # except Exception as e:
2780
+ # logger.exception(
2781
+ # f"An error occurred during the diagnostic session: {e}"
2782
+ # )
2783
+ # print(f"\n❌ Error occurred: {e}")
2784
+ # print("Please check your model configuration and API keys.")