mgbam committed on
Commit
140d9d4
·
verified ·
1 Parent(s): 51d35a8

Update modules/orchestrator.py

Browse files
Files changed (1) hide show
  1. modules/orchestrator.py +114 -69
modules/orchestrator.py CHANGED
@@ -1,103 +1,146 @@
1
  # modules/orchestrator.py
2
  """
3
- The main conductor. This module sequences the calls to APIs and the AI model.
4
- It contains the core application logic for each feature tab, orchestrating
5
- data fetching, processing, and AI synthesis.
 
 
6
  """
 
7
  import asyncio
8
  import aiohttp
9
- import ast
10
  from itertools import chain
11
  from PIL import Image
12
 
13
- # Import all our tools
14
- from . import gemini_handler, prompts
15
  from .api_clients import (
16
- umls_client,
17
  pubmed_client,
18
  clinicaltrials_client,
19
  openfda_client,
20
  rxnorm_client
 
21
  )
22
 
23
- # --- Helper function for formatting data for prompts ---
24
- def _format_data_for_prompt(data: list | dict, source_name: str) -> str:
25
- """Converts API result lists/dicts into a clean string for Gemini prompts."""
26
- if not data:
27
- return f"No data found from {source_name}."
28
-
29
- report_lines = [f"--- Data from {source_name} ---"]
30
- if isinstance(data, list):
31
- for item in data:
32
- report_lines.append(str(item))
33
- elif isinstance(data, dict):
34
- for key, value in data.items():
35
- report_lines.append(f"{key}: {value}")
36
-
37
- return "\n".join(report_lines)
38
-
39
-
40
- # --- Main Orchestrator for the Symptom Synthesizer ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  async def run_symptom_synthesis(user_query: str, image_input: Image.Image | None) -> str:
42
  """The complete, asynchronous pipeline for the Symptom Synthesizer tab."""
43
  if not user_query:
44
  return "Please enter a symptom description or a medical question to begin."
45
 
46
- # 1. Extract concepts with Gemini
47
- term_extraction_prompt = prompts.get_term_extraction_prompt(user_query)
48
- concepts_str = await gemini_handler.generate_text_response(term_extraction_prompt)
49
- try:
50
- concepts = ast.literal_eval(concepts_str)
51
- if not isinstance(concepts, list) or not concepts:
52
- concepts = [user_query] # Fallback
53
- except (ValueError, SyntaxError):
54
- concepts = [user_query] # Fallback
55
 
56
- search_query = " OR ".join(concepts)
 
57
 
58
- # 2. Gather all evidence concurrently
 
59
  async with aiohttp.ClientSession() as session:
 
60
  tasks = {
61
  "pubmed": pubmed_client.search_pubmed(session, search_query, max_results=3),
62
  "trials": clinicaltrials_client.find_trials(session, search_query, max_results=3),
63
- "openfda": asyncio.gather(*(openfda_client.get_adverse_events(session, c, top_n=3) for c in concepts))
64
  }
 
65
  if image_input:
66
  tasks["vision"] = gemini_handler.analyze_image_with_text(
67
- "Analyze this image in a medical context. Describe what you see objectively. Do not diagnose.", image_input
68
  )
69
 
70
- results = await asyncio.gather(*tasks.values(), return_exceptions=True)
71
- api_data = dict(zip(tasks.keys(), results))
72
-
73
- # 3. Format all gathered data for the final prompt
74
- pubmed_formatted = _format_data_for_prompt(api_data.get('pubmed'), "PubMed")
75
- trials_formatted = _format_data_for_prompt(api_data.get('trials'), "ClinicalTrials.gov")
76
-
77
- # Flatten the list of lists from the OpenFDA gather call
78
- fda_results = list(chain.from_iterable(api_data.get('openfda', [])))
79
- fda_formatted = _format_data_for_prompt(fda_results, "OpenFDA Adverse Events")
80
-
81
- vision_formatted = api_data.get('vision', "")
82
- if isinstance(vision_formatted, Exception):
83
- vision_formatted = "Error analyzing image."
84
 
85
- # 4. The Grand Synthesis with Gemini
 
86
  synthesis_prompt = prompts.get_synthesis_prompt(
87
  user_query=user_query,
88
  concepts=concepts,
89
- pubmed_data=pubmed_formatted,
90
- trials_data=trials_formatted,
91
- fda_data=fda_formatted,
92
- vision_analysis=vision_formatted
93
  )
94
 
95
  final_report = await gemini_handler.generate_text_response(synthesis_prompt)
96
 
 
 
97
  return f"{prompts.DISCLAIMER}\n\n{final_report}"
98
 
99
 
100
- # --- Main Orchestrator for the Drug Interaction Analyzer ---
 
101
  async def run_drug_interaction_analysis(drug_list_str: str) -> str:
102
  """The complete, asynchronous pipeline for the Drug Interaction Analyzer tab."""
103
  if not drug_list_str:
@@ -107,16 +150,16 @@ async def run_drug_interaction_analysis(drug_list_str: str) -> str:
107
  if len(drug_names) < 2:
108
  return "Please enter at least two medications to check for interactions."
109
 
110
- # 1. Gather all drug data concurrently
111
  async with aiohttp.ClientSession() as session:
112
  tasks = {
113
  "interactions": rxnorm_client.run_interaction_check(drug_names),
114
  "safety_profiles": asyncio.gather(*(openfda_client.get_safety_profile(session, name) for name in drug_names))
115
  }
116
- results = await asyncio.gather(*tasks.values(), return_exceptions=True)
117
- api_data = dict(zip(tasks.keys(), results))
118
 
119
- # 2. Format data for the final prompt
120
  interaction_data = api_data.get('interactions', [])
121
  if isinstance(interaction_data, Exception):
122
  interaction_data = [{"error": str(interaction_data)}]
@@ -124,20 +167,22 @@ async def run_drug_interaction_analysis(drug_list_str: str) -> str:
124
  safety_profiles = api_data.get('safety_profiles', [])
125
  if isinstance(safety_profiles, Exception):
126
  safety_profiles = [{"error": str(safety_profiles)}]
127
-
128
- # Combine safety profiles with their drug names
129
  safety_data_dict = dict(zip(drug_names, safety_profiles))
130
 
131
- interaction_formatted = _format_data_for_prompt(interaction_data, "RxNorm Interactions")
132
- safety_formatted = _format_data_for_prompt(safety_data_dict, "OpenFDA Safety Profiles")
133
-
134
- # 3. Synthesize the safety report with Gemini
 
135
  synthesis_prompt = prompts.get_drug_interaction_synthesis_prompt(
136
  drug_names=drug_names,
137
  interaction_data=interaction_formatted,
138
  safety_data=safety_formatted
139
  )
140
-
141
  final_report = await gemini_handler.generate_text_response(synthesis_prompt)
142
 
 
143
  return f"{prompts.DISCLAIMER}\n\n{final_report}"
 
1
  # modules/orchestrator.py
2
  """
3
+ The Central Nervous System of Project Asclepius.
4
+ This module is the master conductor, orchestrating high-performance, asynchronous
5
+ workflows for each of the application's features. It intelligently sequences
6
+ calls to API clients and the Gemini handler to transform user queries into
7
+ comprehensive, synthesized reports.
8
  """
9
+
10
  import asyncio
11
  import aiohttp
 
12
  from itertools import chain
13
  from PIL import Image
14
 
15
+ # Import all our specialized tools
16
+ from . import gemini_handler, prompts, utils
17
  from .api_clients import (
 
18
  pubmed_client,
19
  clinicaltrials_client,
20
  openfda_client,
21
  rxnorm_client
22
+ # umls_client is intentionally not imported: term extraction is delegated to Gemini. Re-add it here if UMLS lookups are needed.
23
  )
24
 
25
+
26
+ # --- Internal Helper for Data Formatting ---
27
+
28
+ def _format_api_data_for_prompt(api_results: dict) -> dict[str, str]:
29
+ """
30
+ Takes the raw dictionary of API results and formats each entry into a
31
+ clean, readable string suitable for injection into a Gemini prompt.
32
+
33
+ Args:
34
+ api_results (dict): The dictionary of results from asyncio.gather.
35
+
36
+ Returns:
37
+ dict[str, str]: A dictionary with the same keys but formatted string values.
38
+ """
39
+ formatted_strings = {}
40
+
41
+ # Format PubMed data
42
+ pubmed_data = api_results.get('pubmed', [])
43
+ if isinstance(pubmed_data, list) and pubmed_data:
44
+ lines = [f"- Title: {a.get('title', 'N/A')} (Journal: {a.get('journal', 'N/A')}, URL: {a.get('url')})" for a in pubmed_data]
45
+ formatted_strings['pubmed'] = "\n".join(lines)
46
+ else:
47
+ formatted_strings['pubmed'] = "No relevant review articles were found on PubMed for this query."
48
+
49
+ # Format Clinical Trials data
50
+ trials_data = api_results.get('trials', [])
51
+ if isinstance(trials_data, list) and trials_data:
52
+ lines = [f"- Title: {t.get('title', 'N/A')} (Status: {t.get('status', 'N/A')}, URL: {t.get('url')})" for t in trials_data]
53
+ formatted_strings['trials'] = "\n".join(lines)
54
+ else:
55
+ formatted_strings['trials'] = "No actively recruiting clinical trials were found matching this query."
56
+
57
+ # Format OpenFDA Adverse Events data
58
+ # This data often comes from multiple queries, so we flatten it.
59
+ fda_data = api_results.get('openfda', [])
60
+ if isinstance(fda_data, list):
61
+ # The result is a list of lists, so we flatten it
62
+ all_events = list(chain.from_iterable(filter(None, fda_data)))
63
+ if all_events:
64
+ lines = [f"- {evt['term']} (Reported {evt['count']} times)" for evt in all_events]
65
+ formatted_strings['openfda'] = "\n".join(lines)
66
+ else:
67
+ formatted_strings['openfda'] = "No specific adverse event data was found for this query."
68
+ else:
69
+ formatted_strings['openfda'] = "No specific adverse event data was found for this query."
70
+
71
+ # Format Vision analysis
72
+ vision_data = api_results.get('vision', "")
73
+ if isinstance(vision_data, str) and vision_data:
74
+ formatted_strings['vision'] = vision_data
75
+ elif isinstance(vision_data, Exception):
76
+ formatted_strings['vision'] = f"An error occurred during image analysis: {vision_data}"
77
+ else:
78
+ formatted_strings['vision'] = ""
79
+
80
+ return formatted_strings
81
+
82
+
83
# --- FEATURE 1: Symptom Synthesizer Pipeline ---

async def _gather_adverse_events(session: aiohttp.ClientSession, concepts: list[str]) -> list[list]:
    """Fetch OpenFDA adverse events for every concept, isolating per-concept failures.

    Each concept is queried concurrently. A lookup that raises (or returns
    something other than a list) contributes an empty list, so one failing
    concept no longer discards the results of all the others.

    Returns:
        One list per concept, in the same order as ``concepts``.
    """
    outcomes = await asyncio.gather(
        *(openfda_client.get_adverse_events(session, concept, top_n=3) for concept in concepts),
        return_exceptions=True,
    )
    # NOTE(review): assumes get_adverse_events returns a list of event dicts
    # on success (the formatter flattens a list of lists) - confirm.
    return [outcome if isinstance(outcome, list) else [] for outcome in outcomes]


async def run_symptom_synthesis(user_query: str, image_input: Image.Image | None) -> str:
    """Run the full asynchronous pipeline for the Symptom Synthesizer tab.

    Steps: extract medical concepts from the query with Gemini, gather
    evidence from PubMed, ClinicalTrials.gov and OpenFDA (plus an optional
    image analysis) concurrently, format the results, and ask Gemini for the
    final synthesized report.

    Args:
        user_query: The user's free-text symptom description or question.
        image_input: Optional image to analyze alongside the query.

    Returns:
        The disclaimer followed by the generated report, or a short prompt
        asking for input when the query is empty or whitespace-only.
    """
    if not user_query or not user_query.strip():
        return "Please enter a symptom description or a medical question to begin."

    # STEP 1: AI-powered concept extraction.
    term_prompt = prompts.get_term_extraction_prompt(user_query)
    concepts_str = await gemini_handler.generate_text_response(term_prompt)
    parsed_concepts = utils.safe_literal_eval(concepts_str)

    # The model output is untrusted: keep only non-blank strings and fall
    # back to the raw query when nothing usable was parsed.
    concepts: list[str] = []
    if isinstance(parsed_concepts, list):
        concepts = [
            item.strip()
            for item in parsed_concepts
            if isinstance(item, str) and item.strip()
        ]
    if not concepts:
        concepts = [user_query.strip()]

    # "OR" gives a broader, more inclusive search. Embedded double quotes are
    # replaced so they cannot terminate the quoted phrase early.
    search_query = " OR ".join(
        '"{}"'.format(concept.replace('"', ' ')) for concept in concepts
    )

    # STEP 2: gather all evidence concurrently.
    async with aiohttp.ClientSession() as session:
        tasks = {
            "pubmed": pubmed_client.search_pubmed(session, search_query, max_results=3),
            "trials": clinicaltrials_client.find_trials(session, search_query, max_results=3),
            "openfda": _gather_adverse_events(session, concepts),
        }
        # Explicit None check: do not rely on the truthiness of an image object.
        if image_input is not None:
            tasks["vision"] = gemini_handler.analyze_image_with_text(
                "In the context of the user query, analyze this image objectively. Describe visual features like color, shape, texture, and patterns. Do not diagnose or offer medical advice.", image_input
            )

        # return_exceptions=True keeps one failing source from aborting the rest;
        # failures arrive as exception objects under the corresponding key.
        raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True)
        api_data = dict(zip(tasks.keys(), raw_results))

    # STEP 3: convert the raw results into prompt-ready strings.
    formatted_data = _format_api_data_for_prompt(api_data)

    # STEP 4: feed the gathered evidence to Gemini for the final report.
    synthesis_prompt = prompts.get_synthesis_prompt(
        user_query=user_query,
        concepts=concepts,
        pubmed_data=formatted_data['pubmed'],
        trials_data=formatted_data['trials'],
        fda_data=formatted_data['openfda'],
        vision_analysis=formatted_data['vision']
    )

    final_report = await gemini_handler.generate_text_response(synthesis_prompt)

    # STEP 5: prepend the mandatory disclaimer to the AI-generated report.
    return f"{prompts.DISCLAIMER}\n\n{final_report}"
140
 
141
 
142
+ # --- FEATURE 2: Drug Interaction & Safety Analyzer Pipeline ---
143
+
144
  async def run_drug_interaction_analysis(drug_list_str: str) -> str:
145
  """The complete, asynchronous pipeline for the Drug Interaction Analyzer tab."""
146
  if not drug_list_str:
 
150
  if len(drug_names) < 2:
151
  return "Please enter at least two medications to check for interactions."
152
 
153
+ # STEP 1: Concurrent Drug Data Gathering
154
  async with aiohttp.ClientSession() as session:
155
  tasks = {
156
  "interactions": rxnorm_client.run_interaction_check(drug_names),
157
  "safety_profiles": asyncio.gather(*(openfda_client.get_safety_profile(session, name) for name in drug_names))
158
  }
159
+ raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True)
160
+ api_data = dict(zip(tasks.keys(), raw_results))
161
 
162
+ # STEP 2: Data Formatting for AI Synthesis
163
  interaction_data = api_data.get('interactions', [])
164
  if isinstance(interaction_data, Exception):
165
  interaction_data = [{"error": str(interaction_data)}]
 
167
  safety_profiles = api_data.get('safety_profiles', [])
168
  if isinstance(safety_profiles, Exception):
169
  safety_profiles = [{"error": str(safety_profiles)}]
170
+
171
+ # Combine safety profiles with their drug names for clarity in the prompt
172
  safety_data_dict = dict(zip(drug_names, safety_profiles))
173
 
174
+ # Format the complex data into clean strings
175
+ interaction_formatted = utils.format_list_as_markdown([str(i) for i in interaction_data]) if interaction_data else "No interactions found."
176
+ safety_formatted = "\n".join([f"Profile for {drug}: {profile}" for drug, profile in safety_data_dict.items()])
177
+
178
+ # STEP 3: AI-Powered Safety Briefing
179
  synthesis_prompt = prompts.get_drug_interaction_synthesis_prompt(
180
  drug_names=drug_names,
181
  interaction_data=interaction_formatted,
182
  safety_data=safety_formatted
183
  )
184
+
185
  final_report = await gemini_handler.generate_text_response(synthesis_prompt)
186
 
187
+ # STEP 4: Final Delivery
188
  return f"{prompts.DISCLAIMER}\n\n{final_report}"