mgbam commited on
Commit
89db8f5
·
verified ·
1 Parent(s): aa3b805

Update modules/orchestrator.py

Browse files
Files changed (1) hide show
  1. modules/orchestrator.py +109 -46
modules/orchestrator.py CHANGED
@@ -1,80 +1,143 @@
1
  # modules/orchestrator.py
2
  """
3
  The main conductor. This module sequences the calls to APIs and the AI model.
4
- It's the heart of the application's logic.
 
5
  """
6
  import asyncio
7
  import aiohttp
8
  import ast
 
 
 
 
9
  from . import gemini_handler, prompts
10
- from .api_clients import umls_client, pubmed_client
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- async def run_symptom_synthesis(user_query: str, image_input=None):
13
- """
14
- The complete pipeline for the Symptom Synthesizer tab.
15
- """
16
  if not user_query:
17
- return "Please enter your symptoms or query."
18
 
19
- # --- Step 1: Extract Key Concepts with Gemini ---
20
  term_extraction_prompt = prompts.get_term_extraction_prompt(user_query)
21
- concepts_str = await gemini_handler.generate_gemini_response(term_extraction_prompt)
22
  try:
23
- # Safely evaluate the string representation of the list
24
  concepts = ast.literal_eval(concepts_str)
25
- if not isinstance(concepts, list):
26
  concepts = [user_query] # Fallback
27
  except (ValueError, SyntaxError):
28
- concepts = [user_query] # Fallback if Gemini doesn't return a perfect list
29
 
30
- search_query = " AND ".join(concepts)
31
 
32
- # --- Step 2: Gather Evidence Asynchronously ---
33
  async with aiohttp.ClientSession() as session:
34
- # Create a UMLS client instance for this session
35
- umls = umls_client.UMLSClient(session)
36
-
37
- # Define all async tasks
38
  tasks = {
39
  "pubmed": pubmed_client.search_pubmed(session, search_query, max_results=3),
40
- "umls_cui": umls.get_cui_for_term(concepts[0] if concepts else user_query),
41
- # Add other clients here as they are built e.g.,
42
- # "trials": clinicaltrials_client.find_trials(session, search_query),
43
- # "fda": openfda_client.get_adverse_events(session, concepts)
44
  }
 
 
 
 
45
 
46
- # Run all tasks concurrently
47
  results = await asyncio.gather(*tasks.values(), return_exceptions=True)
48
-
49
- # Map results back to their keys, handling potential errors
50
  api_data = dict(zip(tasks.keys(), results))
51
- for key, value in api_data.items():
52
- if isinstance(value, Exception):
53
- print(f"Error fetching data from {key}: {value}")
54
- api_data[key] = None # Nullify data if fetch failed
55
-
56
- # --- Step 3: Format Data for the Synthesis Prompt ---
57
- # Convert raw JSON/list data into clean, readable strings for the AI
58
- pubmed_formatted = "\n".join([f"- Title: {a.get('title', 'N/A')}, PMID: {a.get('uid', 'N/A')}" for a in api_data.get('pubmed', [])])
59
 
60
- # In a real implementation, you'd format trials and fda data here too
61
- trials_formatted = "Trial data fetching is not yet fully implemented in this demo."
62
- fda_formatted = "FDA data fetching is not yet fully implemented in this demo."
63
-
 
 
 
64
 
65
- # --- Step 4: The Grand Synthesis with Gemini ---
66
  synthesis_prompt = prompts.get_synthesis_prompt(
67
- user_query,
68
- concepts,
69
  pubmed_data=pubmed_formatted,
70
  trials_data=trials_formatted,
71
- fda_data=fda_formatted
 
72
  )
73
 
74
- final_report = await gemini_handler.generate_gemini_response(synthesis_prompt)
75
-
76
- # --- Step 5: Prepend Disclaimer and Return ---
77
  return f"{prompts.DISCLAIMER}\n\n{final_report}"
78
 
79
- # You would create similar orchestrator functions for other tabs
80
- # e.g., async def run_drug_interaction_analysis(drug_list): ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # modules/orchestrator.py
2
  """
3
  The main conductor. This module sequences the calls to APIs and the AI model.
4
+ It contains the core application logic for each feature tab, orchestrating
5
+ data fetching, processing, and AI synthesis.
6
  """
7
  import asyncio
8
  import aiohttp
9
  import ast
10
+ from itertools import chain
11
+ from PIL import Image
12
+
13
+ # Import all our tools
14
  from . import gemini_handler, prompts
15
+ from .api_clients import (
16
+ umls_client,
17
+ pubmed_client,
18
+ clinicaltrials_client,
19
+ openfda_client,
20
+ rxnorm_client
21
+ )
22
+
23
+ # --- Helper function for formatting data for prompts ---
24
+ def _format_data_for_prompt(data: list | dict, source_name: str) -> str:
25
+ """Converts API result lists/dicts into a clean string for Gemini prompts."""
26
+ if not data:
27
+ return f"No data found from {source_name}."
28
+
29
+ report_lines = [f"--- Data from {source_name} ---"]
30
+ if isinstance(data, list):
31
+ for item in data:
32
+ report_lines.append(str(item))
33
+ elif isinstance(data, dict):
34
+ for key, value in data.items():
35
+ report_lines.append(f"{key}: {value}")
36
+
37
+ return "\n".join(report_lines)
38
+
39
 
40
+ # --- Main Orchestrator for the Symptom Synthesizer ---
41
+ async def run_symptom_synthesis(user_query: str, image_input: Image.Image | None) -> str:
42
+ """The complete, asynchronous pipeline for the Symptom Synthesizer tab."""
 
43
  if not user_query:
44
+ return "Please enter a symptom description or a medical question to begin."
45
 
46
+ # 1. Extract concepts with Gemini
47
  term_extraction_prompt = prompts.get_term_extraction_prompt(user_query)
48
+ concepts_str = await gemini_handler.generate_text_response(term_extraction_prompt)
49
  try:
 
50
  concepts = ast.literal_eval(concepts_str)
51
+ if not isinstance(concepts, list) or not concepts:
52
  concepts = [user_query] # Fallback
53
  except (ValueError, SyntaxError):
54
+ concepts = [user_query] # Fallback
55
 
56
+ search_query = " OR ".join(concepts)
57
 
58
+ # 2. Gather all evidence concurrently
59
  async with aiohttp.ClientSession() as session:
 
 
 
 
60
  tasks = {
61
  "pubmed": pubmed_client.search_pubmed(session, search_query, max_results=3),
62
+ "trials": clinicaltrials_client.find_trials(session, search_query, max_results=3),
63
+ "openfda": asyncio.gather(*(openfda_client.get_adverse_events(session, c, top_n=3) for c in concepts))
 
 
64
  }
65
+ if image_input:
66
+ tasks["vision"] = gemini_handler.analyze_image_with_text(
67
+ "Analyze this image in a medical context. Describe what you see objectively. Do not diagnose.", image_input
68
+ )
69
 
 
70
  results = await asyncio.gather(*tasks.values(), return_exceptions=True)
 
 
71
  api_data = dict(zip(tasks.keys(), results))
72
+
73
+ # 3. Format all gathered data for the final prompt
74
+ pubmed_formatted = _format_data_for_prompt(api_data.get('pubmed'), "PubMed")
75
+ trials_formatted = _format_data_for_prompt(api_data.get('trials'), "ClinicalTrials.gov")
 
 
 
 
76
 
77
+ # Flatten the list of lists from the OpenFDA gather call
78
+ fda_results = list(chain.from_iterable(api_data.get('openfda', [])))
79
+ fda_formatted = _format_data_for_prompt(fda_results, "OpenFDA Adverse Events")
80
+
81
+ vision_formatted = api_data.get('vision', "")
82
+ if isinstance(vision_formatted, Exception):
83
+ vision_formatted = "Error analyzing image."
84
 
85
+ # 4. The Grand Synthesis with Gemini
86
  synthesis_prompt = prompts.get_synthesis_prompt(
87
+ user_query=user_query,
88
+ concepts=concepts,
89
  pubmed_data=pubmed_formatted,
90
  trials_data=trials_formatted,
91
+ fda_data=fda_formatted,
92
+ vision_analysis=vision_formatted
93
  )
94
 
95
+ final_report = await gemini_handler.generate_text_response(synthesis_prompt)
96
+
 
97
  return f"{prompts.DISCLAIMER}\n\n{final_report}"
98
 
99
+
100
+ # --- Main Orchestrator for the Drug Interaction Analyzer ---
101
+ async def run_drug_interaction_analysis(drug_list_str: str) -> str:
102
+ """The complete, asynchronous pipeline for the Drug Interaction Analyzer tab."""
103
+ if not drug_list_str:
104
+ return "Please enter a comma-separated list of medications."
105
+
106
+ drug_names = [name.strip() for name in drug_list_str.split(',') if name.strip()]
107
+ if len(drug_names) < 2:
108
+ return "Please enter at least two medications to check for interactions."
109
+
110
+ # 1. Gather all drug data concurrently
111
+ async with aiohttp.ClientSession() as session:
112
+ tasks = {
113
+ "interactions": rxnorm_client.run_interaction_check(drug_names),
114
+ "safety_profiles": asyncio.gather(*(openfda_client.get_safety_profile(session, name) for name in drug_names))
115
+ }
116
+ results = await asyncio.gather(*tasks.values(), return_exceptions=True)
117
+ api_data = dict(zip(tasks.keys(), results))
118
+
119
+ # 2. Format data for the final prompt
120
+ interaction_data = api_data.get('interactions', [])
121
+ if isinstance(interaction_data, Exception):
122
+ interaction_data = [{"error": str(interaction_data)}]
123
+
124
+ safety_profiles = api_data.get('safety_profiles', [])
125
+ if isinstance(safety_profiles, Exception):
126
+ safety_profiles = [{"error": str(safety_profiles)}]
127
+
128
+ # Combine safety profiles with their drug names
129
+ safety_data_dict = dict(zip(drug_names, safety_profiles))
130
+
131
+ interaction_formatted = _format_data_for_prompt(interaction_data, "RxNorm Interactions")
132
+ safety_formatted = _format_data_for_prompt(safety_data_dict, "OpenFDA Safety Profiles")
133
+
134
+ # 3. Synthesize the safety report with Gemini
135
+ synthesis_prompt = prompts.get_drug_interaction_synthesis_prompt(
136
+ drug_names=drug_names,
137
+ interaction_data=interaction_formatted,
138
+ safety_data=safety_formatted
139
+ )
140
+
141
+ final_report = await gemini_handler.generate_text_response(synthesis_prompt)
142
+
143
+ return f"{prompts.DISCLAIMER}\n\n{final_report}"