mgbam commited on
Commit
495a355
·
verified ·
1 Parent(s): a97faaa

Update modules/orchestrator.py

Browse files
Files changed (1) hide show
  1. modules/orchestrator.py +30 -64
modules/orchestrator.py CHANGED
@@ -4,7 +4,7 @@ The Central Nervous System of Project Asclepius.
4
  This module is the master conductor, orchestrating high-performance, asynchronous
5
  workflows for each of the application's features. It intelligently sequences
6
  calls to API clients and the Gemini handler to transform user queries into
7
- comprehensive, synthesized reports.
8
  """
9
 
10
  import asyncio
@@ -14,55 +14,31 @@ from PIL import Image
14
 
15
  # Import all our specialized tools
16
  from . import gemini_handler, prompts, utils
17
-
18
- # ==============================================================================
19
- # CORRECTED LINES: The import path is now an absolute import from the project root.
20
- # The leading dot '.' has been removed.
21
  from api_clients import (
22
  pubmed_client,
23
  clinicaltrials_client,
24
  openfda_client,
25
  rxnorm_client
26
  )
27
- # ==============================================================================
28
-
29
 
30
  # --- Internal Helper for Data Formatting ---
31
-
32
  def _format_api_data_for_prompt(api_results: dict) -> dict[str, str]:
33
- """
34
- Takes the raw dictionary of API results and formats each entry into a
35
- clean, readable string suitable for injection into a Gemini prompt.
36
-
37
- Args:
38
- api_results (dict): The dictionary of results from asyncio.gather.
39
-
40
- Returns:
41
- dict[str, str]: A dictionary with the same keys but formatted string values.
42
- """
43
  formatted_strings = {}
44
-
45
- # Format PubMed data
46
  pubmed_data = api_results.get('pubmed', [])
47
  if isinstance(pubmed_data, list) and pubmed_data:
48
  lines = [f"- Title: {a.get('title', 'N/A')} (Journal: {a.get('journal', 'N/A')}, URL: {a.get('url')})" for a in pubmed_data]
49
  formatted_strings['pubmed'] = "\n".join(lines)
50
  else:
51
  formatted_strings['pubmed'] = "No relevant review articles were found on PubMed for this query."
52
-
53
- # Format Clinical Trials data
54
  trials_data = api_results.get('trials', [])
55
  if isinstance(trials_data, list) and trials_data:
56
  lines = [f"- Title: {t.get('title', 'N/A')} (Status: {t.get('status', 'N/A')}, URL: {t.get('url')})" for t in trials_data]
57
  formatted_strings['trials'] = "\n".join(lines)
58
  else:
59
  formatted_strings['trials'] = "No actively recruiting clinical trials were found matching this query."
60
-
61
- # Format OpenFDA Adverse Events data
62
- # This data often comes from multiple queries, so we flatten it.
63
  fda_data = api_results.get('openfda', [])
64
  if isinstance(fda_data, list):
65
- # The result is a list of lists, so we flatten it
66
  all_events = list(chain.from_iterable(filter(None, fda_data)))
67
  if all_events:
68
  lines = [f"- {evt['term']} (Reported {evt['count']} times)" for evt in all_events]
@@ -71,8 +47,6 @@ def _format_api_data_for_prompt(api_results: dict) -> dict[str, str]:
71
  formatted_strings['openfda'] = "No specific adverse event data was found for this query."
72
  else:
73
  formatted_strings['openfda'] = "No specific adverse event data was found for this query."
74
-
75
- # Format Vision analysis
76
  vision_data = api_results.get('vision', "")
77
  if isinstance(vision_data, str) and vision_data:
78
  formatted_strings['vision'] = vision_data
@@ -80,81 +54,85 @@ def _format_api_data_for_prompt(api_results: dict) -> dict[str, str]:
80
  formatted_strings['vision'] = f"An error occurred during image analysis: {vision_data}"
81
  else:
82
  formatted_strings['vision'] = ""
83
-
84
  return formatted_strings
85
 
86
 
87
- # --- FEATURE 1: Symptom Synthesizer Pipeline ---
88
 
89
  async def run_symptom_synthesis(user_query: str, image_input: Image.Image | None) -> str:
90
  """The complete, asynchronous pipeline for the Symptom Synthesizer tab."""
91
  if not user_query:
92
  return "Please enter a symptom description or a medical question to begin."
93
 
94
- # STEP 1: AI-Powered Concept Extraction
95
- # Use Gemini to find the core medical terms in the user's natural language query.
96
- term_prompt = prompts.get_term_extraction_prompt(user_query)
 
 
 
 
 
 
 
 
 
97
  concepts_str = await gemini_handler.generate_text_response(term_prompt)
98
  concepts = utils.safe_literal_eval(concepts_str)
99
  if not isinstance(concepts, list) or not concepts:
100
- concepts = [user_query] # Fallback to the raw query if parsing fails
101
 
102
  # Use "OR" for a broader, more inclusive search across APIs
103
  search_query = " OR ".join(f'"{c}"' for c in concepts)
104
 
105
- # STEP 2: Massively Parallel Evidence Gathering
106
- # Launch all API calls concurrently for maximum performance.
 
107
  async with aiohttp.ClientSession() as session:
108
- # Define the portfolio of data we need to collect
109
  tasks = {
110
  "pubmed": pubmed_client.search_pubmed(session, search_query, max_results=3),
111
  "trials": clinicaltrials_client.find_trials(session, search_query, max_results=3),
112
  "openfda": asyncio.gather(*(openfda_client.get_adverse_events(session, c, top_n=3) for c in concepts)),
113
  }
114
- # If an image is provided, add the vision analysis to our task portfolio
115
  if image_input:
116
  tasks["vision"] = gemini_handler.analyze_image_with_text(
117
  "In the context of the user query, analyze this image objectively. Describe visual features like color, shape, texture, and patterns. Do not diagnose or offer medical advice.", image_input
118
  )
119
-
120
- # Execute all tasks and wait for them all to complete
121
  raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True)
122
  api_data = dict(zip(tasks.keys(), raw_results))
123
 
124
- # STEP 3: Data Formatting
125
- # Convert the raw JSON/list results into clean, prompt-ready strings.
 
126
  formatted_data = _format_api_data_for_prompt(api_data)
127
 
128
- # STEP 4: The Grand Synthesis
129
- # Feed all the structured, evidence-based data into Gemini for the final report generation.
 
130
  synthesis_prompt = prompts.get_synthesis_prompt(
131
- user_query=user_query,
132
  concepts=concepts,
133
  pubmed_data=formatted_data['pubmed'],
134
  trials_data=formatted_data['trials'],
135
  fda_data=formatted_data['openfda'],
136
  vision_analysis=formatted_data['vision']
137
  )
138
-
139
  final_report = await gemini_handler.generate_text_response(synthesis_prompt)
140
 
141
- # STEP 5: Final Delivery
142
- # Prepend the mandatory disclaimer to the AI-generated report.
 
143
  return f"{prompts.DISCLAIMER}\n\n{final_report}"
144
 
145
 
146
  # --- FEATURE 2: Drug Interaction & Safety Analyzer Pipeline ---
147
-
148
  async def run_drug_interaction_analysis(drug_list_str: str) -> str:
149
  """The complete, asynchronous pipeline for the Drug Interaction Analyzer tab."""
150
  if not drug_list_str:
151
  return "Please enter a comma-separated list of medications."
152
-
153
  drug_names = [name.strip() for name in drug_list_str.split(',') if name.strip()]
154
  if len(drug_names) < 2:
155
  return "Please enter at least two medications to check for interactions."
156
-
157
- # STEP 1: Concurrent Drug Data Gathering
158
  async with aiohttp.ClientSession() as session:
159
  tasks = {
160
  "interactions": rxnorm_client.run_interaction_check(drug_names),
@@ -162,31 +140,19 @@ async def run_drug_interaction_analysis(drug_list_str: str) -> str:
162
  }
163
  raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True)
164
  api_data = dict(zip(tasks.keys(), raw_results))
165
-
166
- # STEP 2: Data Formatting for AI Synthesis
167
  interaction_data = api_data.get('interactions', [])
168
  if isinstance(interaction_data, Exception):
169
  interaction_data = [{"error": str(interaction_data)}]
170
-
171
  safety_profiles = api_data.get('safety_profiles', [])
172
  if isinstance(safety_profiles, Exception):
173
  safety_profiles = [{"error": str(safety_profiles)}]
174
-
175
- # Combine safety profiles with their drug names for clarity in the prompt
176
  safety_data_dict = dict(zip(drug_names, safety_profiles))
177
-
178
- # Format the complex data into clean strings
179
  interaction_formatted = utils.format_list_as_markdown([str(i) for i in interaction_data]) if interaction_data else "No interactions found."
180
  safety_formatted = "\n".join([f"Profile for {drug}: {profile}" for drug, profile in safety_data_dict.items()])
181
-
182
- # STEP 3: AI-Powered Safety Briefing
183
  synthesis_prompt = prompts.get_drug_interaction_synthesis_prompt(
184
  drug_names=drug_names,
185
  interaction_data=interaction_formatted,
186
  safety_data=safety_formatted
187
  )
188
-
189
  final_report = await gemini_handler.generate_text_response(synthesis_prompt)
190
-
191
- # STEP 4: Final Delivery
192
  return f"{prompts.DISCLAIMER}\n\n{final_report}"
 
4
  This module is the master conductor, orchestrating high-performance, asynchronous
5
  workflows for each of the application's features. It intelligently sequences
6
  calls to API clients and the Gemini handler to transform user queries into
7
+ comprehensive, synthesized reports. (v1.1)
8
  """
9
 
10
  import asyncio
 
14
 
15
  # Import all our specialized tools
16
  from . import gemini_handler, prompts, utils
 
 
 
 
17
  from api_clients import (
18
  pubmed_client,
19
  clinicaltrials_client,
20
  openfda_client,
21
  rxnorm_client
22
  )
 
 
23
 
24
  # --- Internal Helper for Data Formatting ---
25
+ # (This helper function remains unchanged)
26
  def _format_api_data_for_prompt(api_results: dict) -> dict[str, str]:
 
 
 
 
 
 
 
 
 
 
27
  formatted_strings = {}
 
 
28
  pubmed_data = api_results.get('pubmed', [])
29
  if isinstance(pubmed_data, list) and pubmed_data:
30
  lines = [f"- Title: {a.get('title', 'N/A')} (Journal: {a.get('journal', 'N/A')}, URL: {a.get('url')})" for a in pubmed_data]
31
  formatted_strings['pubmed'] = "\n".join(lines)
32
  else:
33
  formatted_strings['pubmed'] = "No relevant review articles were found on PubMed for this query."
 
 
34
  trials_data = api_results.get('trials', [])
35
  if isinstance(trials_data, list) and trials_data:
36
  lines = [f"- Title: {t.get('title', 'N/A')} (Status: {t.get('status', 'N/A')}, URL: {t.get('url')})" for t in trials_data]
37
  formatted_strings['trials'] = "\n".join(lines)
38
  else:
39
  formatted_strings['trials'] = "No actively recruiting clinical trials were found matching this query."
 
 
 
40
  fda_data = api_results.get('openfda', [])
41
  if isinstance(fda_data, list):
 
42
  all_events = list(chain.from_iterable(filter(None, fda_data)))
43
  if all_events:
44
  lines = [f"- {evt['term']} (Reported {evt['count']} times)" for evt in all_events]
 
47
  formatted_strings['openfda'] = "No specific adverse event data was found for this query."
48
  else:
49
  formatted_strings['openfda'] = "No specific adverse event data was found for this query."
 
 
50
  vision_data = api_results.get('vision', "")
51
  if isinstance(vision_data, str) and vision_data:
52
  formatted_strings['vision'] = vision_data
 
54
  formatted_strings['vision'] = f"An error occurred during image analysis: {vision_data}"
55
  else:
56
  formatted_strings['vision'] = ""
 
57
  return formatted_strings
58
 
59
 
60
+ # --- FEATURE 1: Symptom Synthesizer Pipeline (v1.1) ---
61
 
62
  async def run_symptom_synthesis(user_query: str, image_input: Image.Image | None) -> str:
63
  """The complete, asynchronous pipeline for the Symptom Synthesizer tab."""
64
  if not user_query:
65
  return "Please enter a symptom description or a medical question to begin."
66
 
67
+ # ==============================================================================
68
+ # STEP 1 (V1.1 UPGRADE): AI-Powered Query Correction (The "Medical Translator")
69
+ # ==============================================================================
70
+ correction_prompt = prompts.get_query_correction_prompt(user_query)
71
+ corrected_query = await gemini_handler.generate_text_response(correction_prompt)
72
+ if not corrected_query:
73
+ corrected_query = user_query # Fallback to original query if correction fails
74
+
75
+ # ==============================================================================
76
+ # STEP 2: AI-Powered Concept Extraction (now on the CLEANED query)
77
+ # ==============================================================================
78
+ term_prompt = prompts.get_term_extraction_prompt(corrected_query)
79
  concepts_str = await gemini_handler.generate_text_response(term_prompt)
80
  concepts = utils.safe_literal_eval(concepts_str)
81
  if not isinstance(concepts, list) or not concepts:
82
+ concepts = [corrected_query] # Fallback if parsing fails
83
 
84
  # Use "OR" for a broader, more inclusive search across APIs
85
  search_query = " OR ".join(f'"{c}"' for c in concepts)
86
 
87
+ # ==============================================================================
88
+ # STEP 3: Massively Parallel Evidence Gathering
89
+ # ==============================================================================
90
  async with aiohttp.ClientSession() as session:
 
91
  tasks = {
92
  "pubmed": pubmed_client.search_pubmed(session, search_query, max_results=3),
93
  "trials": clinicaltrials_client.find_trials(session, search_query, max_results=3),
94
  "openfda": asyncio.gather(*(openfda_client.get_adverse_events(session, c, top_n=3) for c in concepts)),
95
  }
 
96
  if image_input:
97
  tasks["vision"] = gemini_handler.analyze_image_with_text(
98
  "In the context of the user query, analyze this image objectively. Describe visual features like color, shape, texture, and patterns. Do not diagnose or offer medical advice.", image_input
99
  )
 
 
100
  raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True)
101
  api_data = dict(zip(tasks.keys(), raw_results))
102
 
103
+ # ==============================================================================
104
+ # STEP 4: Data Formatting
105
+ # ==============================================================================
106
  formatted_data = _format_api_data_for_prompt(api_data)
107
 
108
+ # ==============================================================================
109
+ # STEP 5: The Grand Synthesis
110
+ # ==============================================================================
111
  synthesis_prompt = prompts.get_synthesis_prompt(
112
+ user_query=user_query, # Pass original query for context
113
  concepts=concepts,
114
  pubmed_data=formatted_data['pubmed'],
115
  trials_data=formatted_data['trials'],
116
  fda_data=formatted_data['openfda'],
117
  vision_analysis=formatted_data['vision']
118
  )
 
119
  final_report = await gemini_handler.generate_text_response(synthesis_prompt)
120
 
121
+ # ==============================================================================
122
+ # STEP 6: Final Delivery
123
+ # ==============================================================================
124
  return f"{prompts.DISCLAIMER}\n\n{final_report}"
125
 
126
 
127
  # --- FEATURE 2: Drug Interaction & Safety Analyzer Pipeline ---
128
+ # (This function remains unchanged)
129
  async def run_drug_interaction_analysis(drug_list_str: str) -> str:
130
  """The complete, asynchronous pipeline for the Drug Interaction Analyzer tab."""
131
  if not drug_list_str:
132
  return "Please enter a comma-separated list of medications."
 
133
  drug_names = [name.strip() for name in drug_list_str.split(',') if name.strip()]
134
  if len(drug_names) < 2:
135
  return "Please enter at least two medications to check for interactions."
 
 
136
  async with aiohttp.ClientSession() as session:
137
  tasks = {
138
  "interactions": rxnorm_client.run_interaction_check(drug_names),
 
140
  }
141
  raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True)
142
  api_data = dict(zip(tasks.keys(), raw_results))
 
 
143
  interaction_data = api_data.get('interactions', [])
144
  if isinstance(interaction_data, Exception):
145
  interaction_data = [{"error": str(interaction_data)}]
 
146
  safety_profiles = api_data.get('safety_profiles', [])
147
  if isinstance(safety_profiles, Exception):
148
  safety_profiles = [{"error": str(safety_profiles)}]
 
 
149
  safety_data_dict = dict(zip(drug_names, safety_profiles))
 
 
150
  interaction_formatted = utils.format_list_as_markdown([str(i) for i in interaction_data]) if interaction_data else "No interactions found."
151
  safety_formatted = "\n".join([f"Profile for {drug}: {profile}" for drug, profile in safety_data_dict.items()])
 
 
152
  synthesis_prompt = prompts.get_drug_interaction_synthesis_prompt(
153
  drug_names=drug_names,
154
  interaction_data=interaction_formatted,
155
  safety_data=safety_formatted
156
  )
 
157
  final_report = await gemini_handler.generate_text_response(synthesis_prompt)
 
 
158
  return f"{prompts.DISCLAIMER}\n\n{final_report}"