project-asclepius / modules /orchestrator.py
mgbam's picture
Update modules/orchestrator.py
666355f verified
# modules/orchestrator.py
"""
The Central Nervous System of Project Asclepius.
(v2.0 - The "Clinical Insight Engine" Upgrade)
This version uses a smarter post-processing function to guarantee clean output.
"""
import asyncio
import aiohttp
from itertools import chain
from PIL import Image
from . import gemini_handler, prompts, utils
from api_clients import (
pubmed_client, clinicaltrials_client, openfda_client, rxnorm_client
)
# --- Internal Helper for Data Formatting (Unchanged) ---
def _format_api_data_for_prompt(api_results: dict) -> dict[str, str]:
# This function is unchanged.
formatted_strings = {}
pubmed_data = api_results.get('pubmed', [])
if isinstance(pubmed_data, list) and pubmed_data: lines = [f"- Title: {a.get('title', 'N/A')} (Journal: {a.get('journal', 'N/A')}, URL: {a.get('url')})" for a in pubmed_data]; formatted_strings['pubmed'] = "\n".join(lines)
else: formatted_strings['pubmed'] = "No relevant review articles were found on PubMed for this query."
trials_data = api_results.get('trials', [])
if isinstance(trials_data, list) and trials_data: lines = [f"- Title: {t.get('title', 'N/A')} (Status: {t.get('status', 'N/A')}, URL: {t.get('url')})" for t in trials_data]; formatted_strings['trials'] = "\n".join(lines)
else: formatted_strings['trials'] = "No actively recruiting clinical trials were found matching this query."
fda_data = api_results.get('openfda', [])
if isinstance(fda_data, list):
all_events = list(chain.from_iterable(filter(None, fda_data)))
if all_events: lines = [f"- {evt['term']} (Reported {evt['count']} times)" for evt in all_events]; formatted_strings['openfda'] = "\n".join(lines)
else: formatted_strings['openfda'] = "No specific adverse event data was found for this query."
else: formatted_strings['openfda'] = "No specific adverse event data was found for this query."
vision_data = api_results.get('vision', "")
if isinstance(vision_data, str) and vision_data: formatted_strings['vision'] = vision_data
elif isinstance(vision_data, Exception): formatted_strings['vision'] = f"An error occurred during image analysis: {vision_data}"
else: formatted_strings['vision'] = ""
return formatted_strings
# ==============================================================================
# V2.0 UPGRADE: A robust function to remove any AI-generated preamble/disclaimer.
# ==============================================================================
def _clean_ai_preamble(report_text: str) -> str:
"""Intelligently removes redundant disclaimers or preambles added by the AI."""
lines = report_text.strip().split('\n')
# AI disclaimers are often short, in the first few lines, and contain specific keywords.
# We find the first line that looks like real content (starts with '##' for our format).
start_index = 0
for i, line in enumerate(lines):
if line.strip().startswith('##'):
start_index = i
break
# Failsafe for the first 5 lines if no '##' is found
if i > 5:
break
return '\n'.join(lines[start_index:])
# --- FEATURE 1: Symptom Synthesizer Pipeline (v2.0) ---
async def run_symptom_synthesis(user_query: str, image_input: Image.Image | None) -> str:
    """Run the symptom-synthesis pipeline and return the final report text.

    Steps: correct the query, extract search concepts, query PubMed,
    ClinicalTrials and openFDA concurrently (plus image analysis when an
    image is supplied), synthesize a report with the model, strip any
    AI-added preamble, and prepend ``prompts.DISCLAIMER``.

    Args:
        user_query: Free-text symptom description or medical question.
        image_input: Optional image to analyze alongside the query.

    Returns:
        A prompt to enter a query when ``user_query`` is empty; otherwise the
        disclaimer followed by the cleaned report.
    """
    if not user_query:
        return "Please enter a symptom description or a medical question to begin."

    # Step 1: query correction; fall back to the raw query on an empty reply.
    correction_prompt = prompts.get_query_correction_prompt(user_query)
    corrected_query = await gemini_handler.generate_text_response(correction_prompt)
    if not corrected_query:
        corrected_query = user_query

    # Step 2: concept extraction; fall back to the whole query as one concept
    # when the model's reply does not parse into a non-empty list.
    term_prompt = prompts.get_term_extraction_prompt(corrected_query)
    concepts_str = await gemini_handler.generate_text_response(term_prompt)
    concepts = utils.safe_literal_eval(concepts_str)
    if not isinstance(concepts, list) or not concepts:
        concepts = [corrected_query]
    search_query = " OR ".join(f'"{c}"' for c in concepts)

    # Step 3: fan out to every data source concurrently.
    async with aiohttp.ClientSession() as session:
        tasks = {
            "pubmed": pubmed_client.search_pubmed(session, search_query, max_results=3),
            "trials": clinicaltrials_client.find_trials(session, search_query, max_results=3),
            # return_exceptions=True keeps one failed concept lookup from
            # discarding the others, and makes the gather wait for every
            # request before the session is closed.
            "openfda": asyncio.gather(
                *(openfda_client.get_adverse_events(session, c, top_n=3) for c in concepts),
                return_exceptions=True,
            ),
        }
        # Compare against None explicitly rather than relying on the
        # truthiness of the image object.
        if image_input is not None:
            tasks["vision"] = gemini_handler.analyze_image_with_text("In the context of the user query, analyze this image objectively...", image_input)
        raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True)

    api_data = dict(zip(tasks.keys(), raw_results))

    # Drop per-concept failures so the formatter only sees successful lookups.
    openfda_results = api_data.get("openfda")
    if isinstance(openfda_results, list):
        api_data["openfda"] = [
            result for result in openfda_results
            if not isinstance(result, BaseException)
        ]

    # Step 4: format the raw data for the prompt.
    formatted_data = _format_api_data_for_prompt(api_data)

    # Step 5: the grand synthesis.
    synthesis_prompt = prompts.get_synthesis_prompt(
        user_query=user_query,
        concepts=concepts,
        pubmed_data=formatted_data['pubmed'],
        trials_data=formatted_data['trials'],
        fda_data=formatted_data['openfda'],
        vision_analysis=formatted_data['vision'],
    )
    final_report = await gemini_handler.generate_text_response(synthesis_prompt)

    # Step 6: post-processing. The earlier steps already guard against an
    # empty model reply, so the same is assumed possible here; coerce it to
    # "" so the cleaner never receives a non-string.
    cleaned_report = _clean_ai_preamble(final_report or "")

    # Step 7: final delivery.
    return f"{prompts.DISCLAIMER}\n\n{cleaned_report}"
# --- FEATURE 2: Drug Interaction & Safety Analyzer Pipeline (v2.0) ---
async def run_drug_interaction_analysis(drug_list_str: str) -> str:
    """Run the drug interaction and safety pipeline and return the report text.

    Parses a comma-separated medication list, fetches interaction data and a
    per-drug safety profile concurrently, asks the model to synthesize a
    report, strips any AI-added preamble, and prepends ``prompts.DISCLAIMER``.

    Args:
        drug_list_str: Comma-separated medication names.

    Returns:
        A validation message when the input is empty or names fewer than two
        medications; otherwise the disclaimer followed by the cleaned report.
    """
    if not drug_list_str:
        return "Please enter a comma-separated list of medications."
    drug_names = [name.strip() for name in drug_list_str.split(',') if name.strip()]
    if len(drug_names) < 2:
        return "Please enter at least two medications to check for interactions."

    async with aiohttp.ClientSession() as session:
        tasks = {
            "interactions": rxnorm_client.run_interaction_check(drug_names),
            # return_exceptions=True isolates failures per drug, so one bad
            # lookup no longer discards every other drug's profile.
            "safety_profiles": asyncio.gather(
                *(openfda_client.get_safety_profile(session, name) for name in drug_names),
                return_exceptions=True,
            ),
        }
        raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True)
    api_data = dict(zip(tasks.keys(), raw_results))

    interaction_data = api_data.get('interactions', [])
    if isinstance(interaction_data, BaseException):
        interaction_data = [{"error": str(interaction_data)}]

    safety_profiles = api_data.get('safety_profiles', [])
    if isinstance(safety_profiles, BaseException):
        # zip() stops at the shorter input, so a single error entry would be
        # paired with the first drug only; repeat it for every drug instead.
        safety_profiles = [{"error": str(safety_profiles)} for _ in drug_names]
    else:
        safety_profiles = [
            {"error": str(profile)} if isinstance(profile, BaseException) else profile
            for profile in safety_profiles
        ]
    safety_data_dict = dict(zip(drug_names, safety_profiles))

    interaction_formatted = (
        utils.format_list_as_markdown([str(i) for i in interaction_data])
        if interaction_data
        else "No interactions found."
    )
    safety_formatted = "\n".join(
        f"Profile for {drug}: {profile}" for drug, profile in safety_data_dict.items()
    )

    # Synthesis.
    synthesis_prompt = prompts.get_drug_interaction_synthesis_prompt(
        drug_names=drug_names,
        interaction_data=interaction_formatted,
        safety_data=safety_formatted,
    )
    final_report = await gemini_handler.generate_text_response(synthesis_prompt)

    # Post-processing: coerce an empty model reply to "" so the cleaner never
    # receives a non-string.
    cleaned_report = _clean_ai_preamble(final_report or "")
    return f"{prompts.DISCLAIMER}\n\n{cleaned_report}"