Spaces:
Sleeping
Sleeping
# modules/orchestrator.py
"""
The central nervous system of Project Asclepius.

This module is the master conductor: it orchestrates high-performance,
asynchronous workflows for each of the application's features, sequencing
calls to the API clients and the Gemini handler to transform user queries
into comprehensive, synthesized reports. (v1.1)
"""
import asyncio | |
import aiohttp | |
from itertools import chain | |
from PIL import Image | |
# Import all our specialized tools | |
from . import gemini_handler, prompts, utils | |
from api_clients import ( | |
pubmed_client, | |
clinicaltrials_client, | |
openfda_client, | |
rxnorm_client | |
) | |
# --- Internal Helper for Data Formatting ---
# Turns the raw results gathered from the API clients into prompt-ready text.
def _format_list_section(data, render_item, *, empty_message: str, source_label: str) -> str:
    """Render one source's result list as bullet lines, or explain why there are none.

    ``data`` is whatever the source produced: a list of result dicts, an
    exception raised by the source, or anything else (treated as "no data").
    Non-dict list entries are skipped. ``render_item`` maps one dict to one
    bullet line.
    """
    if isinstance(data, BaseException):
        # Report the failure explicitly so an outage is not mistaken for an
        # absence of evidence by the synthesis model.
        return f"An error occurred while querying {source_label}: {data}"
    items = [item for item in data if isinstance(item, dict)] if isinstance(data, list) else []
    if not items:
        return empty_message
    return "\n".join(render_item(item) for item in items)


def _format_api_data_for_prompt(api_results: dict) -> dict[str, str]:
    """Convert raw evidence-gathering results into prompt-ready text sections.

    Args:
        api_results: Mapping from source name to what that source's coroutine
            produced under ``asyncio.gather(..., return_exceptions=True)``:
            either the client's result or the exception it raised. Recognized
            keys are ``'pubmed'`` (list of article dicts), ``'trials'`` (list
            of trial dicts), ``'openfda'`` (one adverse-event list per
            concept) and ``'vision'`` (image-analysis text). Missing keys are
            treated as "no data".

    Returns:
        A dict with exactly the keys ``'pubmed'``, ``'trials'``, ``'openfda'``
        and ``'vision'``. Each value is a newline-separated bullet list, a
        "nothing found" sentence, or an error sentence when the source raised.
        ``'vision'`` is ``""`` when no image analysis is available.
    """
    pubmed = _format_list_section(
        api_results.get('pubmed', []),
        lambda a: f"- Title: {a.get('title', 'N/A')} (Journal: {a.get('journal', 'N/A')}, URL: {a.get('url', 'N/A')})",
        empty_message="No relevant review articles were found on PubMed for this query.",
        source_label="PubMed",
    )

    trials = _format_list_section(
        api_results.get('trials', []),
        lambda t: f"- Title: {t.get('title', 'N/A')} (Status: {t.get('status', 'N/A')}, URL: {t.get('url', 'N/A')})",
        empty_message="No actively recruiting clinical trials were found matching this query.",
        source_label="ClinicalTrials.gov",
    )

    fda_data = api_results.get('openfda', [])
    if isinstance(fda_data, list):
        # One entry per extracted concept. Skip concepts that returned nothing
        # or failed individually so the other concepts' events still appear.
        fda_data = [
            event
            for per_concept in fda_data
            if isinstance(per_concept, (list, tuple))
            for event in per_concept
        ]
    openfda = _format_list_section(
        fda_data,
        lambda e: f"- {e.get('term', 'N/A')} (Reported {e.get('count', 'N/A')} times)",
        empty_message="No specific adverse event data was found for this query.",
        source_label="OpenFDA",
    )

    vision_data = api_results.get('vision', "")
    if isinstance(vision_data, BaseException):
        vision = f"An error occurred during image analysis: {vision_data}"
    elif isinstance(vision_data, str):
        vision = vision_data
    else:
        vision = ""

    return {'pubmed': pubmed, 'trials': trials, 'openfda': openfda, 'vision': vision}
# --- FEATURE 1: Symptom Synthesizer Pipeline (v1.1) --- | |
async def run_symptom_synthesis(user_query: str, image_input: Image.Image | None) -> str:
    """Run the complete asynchronous Symptom Synthesizer pipeline for one request.

    Steps:
      1. Ask Gemini to correct the free-text query; fall back to the original
         text if the model returns nothing usable.
      2. Ask Gemini to extract search concepts from the corrected query and
         parse them with ``utils.safe_literal_eval``; fall back to the
         corrected query as the single concept.
      3. Query PubMed, ClinicalTrials.gov and OpenFDA concurrently, plus
         Gemini image analysis when an image is supplied.
      4. Format the raw results into prompt-ready text.
      5. Ask Gemini to synthesize the final report.

    Args:
        user_query: The user's symptom description or medical question.
        image_input: Optional image to analyze alongside the query.

    Returns:
        The report prefixed with ``prompts.DISCLAIMER``, or a request for
        input when ``user_query`` is empty or blank.
    """
    if not user_query or not user_query.strip():
        return "Please enter a symptom description or a medical question to begin."

    # STEP 1 (v1.1): AI-powered query correction (the "Medical Translator").
    correction_prompt = prompts.get_query_correction_prompt(user_query)
    corrected = await gemini_handler.generate_text_response(correction_prompt)
    # Fall back to the original query if correction failed or came back blank.
    corrected_query = corrected.strip() if isinstance(corrected, str) and corrected.strip() else user_query

    # STEP 2: AI-powered concept extraction (on the CLEANED query).
    term_prompt = prompts.get_term_extraction_prompt(corrected_query)
    concepts_str = await gemini_handler.generate_text_response(term_prompt)
    parsed = utils.safe_literal_eval(concepts_str)
    concepts: list[str] = []
    if isinstance(parsed, list):
        # Remove embedded double quotes (they would break the quoted OR query
        # below), drop blanks, and de-duplicate while keeping the model's order
        # so each concept is sent to OpenFDA only once.
        cleaned = (str(c).replace('"', '').strip() for c in parsed)
        concepts = list(dict.fromkeys(c for c in cleaned if c))
    if not concepts:
        concepts = [corrected_query]  # Fallback if parsing failed or yielded nothing
    # Use "OR" for a broader, more inclusive search across APIs.
    search_query = " OR ".join(f'"{c}"' for c in concepts)

    # STEP 3: Parallel evidence gathering.
    async with aiohttp.ClientSession() as session:
        tasks = {
            "pubmed": pubmed_client.search_pubmed(session, search_query, max_results=3),
            "trials": clinicaltrials_client.find_trials(session, search_query, max_results=3),
            # return_exceptions=True: a failure for one concept must not throw
            # away the adverse events already found for the other concepts.
            "openfda": asyncio.gather(
                *(openfda_client.get_adverse_events(session, c, top_n=3) for c in concepts),
                return_exceptions=True,
            ),
        }
        if image_input is not None:
            # The instructions refer to "the user query", so include it.
            vision_prompt = (
                "In the context of the user query, analyze this image objectively. "
                "Describe visual features like color, shape, texture, and patterns. "
                "Do not diagnose or offer medical advice."
                f"\n\nUser query: {user_query}"
            )
            tasks["vision"] = gemini_handler.analyze_image_with_text(vision_prompt, image_input)
        raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True)
    api_data = dict(zip(tasks.keys(), raw_results))

    # Drop per-concept OpenFDA failures so only real results reach the formatter.
    fda_results = api_data.get("openfda")
    if isinstance(fda_results, list):
        api_data["openfda"] = [r for r in fda_results if not isinstance(r, BaseException)]

    # STEP 4: Data formatting.
    formatted_data = _format_api_data_for_prompt(api_data)

    # STEP 5: The grand synthesis.
    synthesis_prompt = prompts.get_synthesis_prompt(
        user_query=user_query,  # Original query, for context
        concepts=concepts,
        pubmed_data=formatted_data['pubmed'],
        trials_data=formatted_data['trials'],
        fda_data=formatted_data['openfda'],
        vision_analysis=formatted_data['vision'],
    )
    final_report = await gemini_handler.generate_text_response(synthesis_prompt)
    if not final_report:
        # Avoid rendering "None" or an empty report beneath the disclaimer.
        final_report = "Sorry, the report could not be generated. Please try again."

    # STEP 6: Final delivery.
    return f"{prompts.DISCLAIMER}\n\n{final_report}"
# --- FEATURE 2: Drug Interaction & Safety Analyzer Pipeline ---
# Checks a user-supplied medication list for interactions and per-drug safety data.
async def run_drug_interaction_analysis(drug_list_str: str) -> str:
    """Run the complete asynchronous Drug Interaction Analyzer pipeline.

    Parses a comma-separated medication list and runs two lookups
    concurrently: the RxNorm interaction check and one OpenFDA safety-profile
    lookup per drug. It then asks Gemini to synthesize a report from both.

    Args:
        drug_list_str: Comma-separated medication names as typed by the user.

    Returns:
        The report prefixed with ``prompts.DISCLAIMER``, or a message asking
        for more input when fewer than two distinct medications were given.
    """
    if not drug_list_str:
        return "Please enter a comma-separated list of medications."

    # Normalize and de-duplicate case-insensitively, keeping the first spelling,
    # so "aspirin, Aspirin" is not mistaken for two medications.
    drug_names: list[str] = []
    seen: set[str] = set()
    for raw_name in drug_list_str.split(','):
        name = raw_name.strip()
        key = name.casefold()
        if name and key not in seen:
            seen.add(key)
            drug_names.append(name)
    if len(drug_names) < 2:
        return "Please enter at least two medications to check for interactions."

    def describe_error(exc: BaseException) -> dict[str, str]:
        # str() of some exceptions (e.g. timeouts) is empty; keep the type name.
        return {"error": str(exc) or type(exc).__name__}

    async with aiohttp.ClientSession() as session:
        tasks = {
            "interactions": rxnorm_client.run_interaction_check(drug_names),
            # return_exceptions=True so a failure for one drug is reported
            # against that drug only instead of discarding every profile.
            "safety_profiles": asyncio.gather(
                *(openfda_client.get_safety_profile(session, name) for name in drug_names),
                return_exceptions=True,
            ),
        }
        raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True)
    api_data = dict(zip(tasks.keys(), raw_results))

    interaction_data = api_data.get('interactions', [])
    if isinstance(interaction_data, BaseException):
        interaction_data = [describe_error(interaction_data)]

    safety_profiles = api_data.get('safety_profiles', [])
    if isinstance(safety_profiles, BaseException):
        # The whole lookup failed: attribute the error to every drug, not just
        # the first one, so no drug silently disappears from the prompt.
        safety_profiles = [safety_profiles] * len(drug_names)
    safety_data_dict = {
        drug: describe_error(profile) if isinstance(profile, BaseException) else profile
        for drug, profile in zip(drug_names, safety_profiles)
    }

    interaction_formatted = (
        utils.format_list_as_markdown([str(i) for i in interaction_data])
        if interaction_data
        else "No interactions found."
    )
    safety_formatted = "\n".join(f"Profile for {drug}: {profile}" for drug, profile in safety_data_dict.items())

    synthesis_prompt = prompts.get_drug_interaction_synthesis_prompt(
        drug_names=drug_names,
        interaction_data=interaction_formatted,
        safety_data=safety_formatted,
    )
    final_report = await gemini_handler.generate_text_response(synthesis_prompt)
    if not final_report:
        # Avoid rendering "None" or an empty report beneath the disclaimer.
        final_report = "Sorry, the report could not be generated. Please try again."
    return f"{prompts.DISCLAIMER}\n\n{final_report}"