# MedAI: medical literature review and summarization app (Hugging Face Space).
import importlib
import os  # For environment variables and file paths
import subprocess
import sys

import gradio as gr
import spacy
from Bio import Entrez
from transformers import pipeline
# ---------------------------- Configuration ----------------------------
# NCBI E-utilities expects clients to supply a contact email. Set ENTREZ_EMAIL
# in the environment; the fallback below is only a placeholder.
ENTREZ_EMAIL = os.environ.get("ENTREZ_EMAIL", "your.email@example.com")
# Hugging Face access token, or None for anonymous access. An unset (or empty)
# variable must not turn into a placeholder string: any non-empty token is sent
# as credentials, and an invalid one can make the Hub reject downloads of
# public models that would otherwise work anonymously.
HUGGINGFACE_API_TOKEN = os.environ.get("HUGGINGFACE_API_TOKEN") or None
SUMMARIZATION_MODEL = "facebook/bart-large-cnn"
SPACY_MODEL = "en_core_web_sm"
# ---------------------------- Global Variables ----------------------------
# Populated by setup(); None means "not initialized yet, or failed to load".
summarizer = None  # transformers summarization pipeline
nlp = None  # spaCy pipeline returned by spacy.load
initialization_status = "Initializing..."  # Human-readable init log shown in the UI
# ---------------------------- Helper Functions ----------------------------
def log_error(message: str):
    """Print an error to the console and append it to ``error_log.txt``.

    Writing the file is best-effort: if it cannot be opened or written, a
    notice is printed instead and no exception propagates to the caller.

    Args:
        message: Text to log; it is written to the file followed by a newline.
    """
    print(f"ERROR: {message}")
    try:
        # Relative path: the log lands in the current working directory.
        # backslashreplace keeps unencodable characters from raising.
        with open("error_log.txt", "a", encoding="utf-8", errors="backslashreplace") as f:
            f.write(f"{message}\n")
    except OSError:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit and
        # programming errors are not silently swallowed.
        print("Couldn't write to error log file.")  # If logging fails, still print to console
# ---------------------------- Language Model Loading ----------------------------
def load_spacy_model(model_name="en_core_web_sm"):
    """Load a spaCy pipeline, downloading the model package first if needed.

    Progress and failures are printed and appended to the global
    ``initialization_status`` so the UI can display them.

    Args:
        model_name: Name of the spaCy model package to load.

    Returns:
        The object returned by ``spacy.load``, or None if the model could not
        be loaded or downloaded (the failure is also recorded via ``log_error``).
    """
    global initialization_status  # Appended to as loading progresses
    try:
        print(f"Attempting to load SpaCy model '{model_name}'...")
        nlp_model = spacy.load(model_name)
        print(f"Successfully loaded SpaCy model '{model_name}'.")
        initialization_status += f"\nSpaCy model '{model_name}' loaded."
        return nlp_model
    except OSError:
        # Treated as "model not installed": fall through to the download below.
        print(f"SpaCy model '{model_name}' not found. Downloading...")
        initialization_status += f"\nSpaCy model '{model_name}' not found. Downloading..."
    except Exception as e:
        log_error(f"Error loading SpaCy model '{model_name}': {e}")
        initialization_status += f"\nError loading SpaCy model '{model_name}': {e}"
        return None

    try:
        # sys.executable rather than a bare "python": install into the same
        # interpreter/environment this process imports from, not whichever
        # "python" happens to be first on PATH.
        subprocess.check_call([sys.executable, "-m", "spacy", "download", model_name])
        # The package was installed after startup; clear the import system's
        # cached directory listings so the new package can be found.
        importlib.invalidate_caches()
        nlp_model = spacy.load(model_name)
        print(f"Successfully loaded SpaCy model '{model_name}' after downloading.")
        initialization_status += f"\nSuccessfully loaded SpaCy model '{model_name}' after downloading."
        return nlp_model
    except Exception as e:
        log_error(f"Failed to download or load SpaCy model '{model_name}': {e}")
        initialization_status += f"\nFailed to download or load SpaCy model '{model_name}': {e}"
        return None  # Indicate failure
# ---------------------------- Tool Functions ----------------------------
def search_pubmed(query: str, max_results: int = 5) -> list:
    """Search PubMed and return the IDs of matching articles.

    Args:
        query: PubMed search term. A blank query returns ``[]`` without
            contacting NCBI.
        max_results: Maximum number of IDs to request (sent as ``retmax``).

    Returns:
        The ``IdList`` from the esearch result. On failure, the error is
        logged and a one-element list holding an
        "Error during PubMed search: ..." message is returned instead, so
        callers must check for it.
    """
    if not query or not query.strip():
        return []
    try:
        Entrez.email = ENTREZ_EMAIL  # NCBI asks clients to identify themselves
        handle = Entrez.esearch(db="pubmed", term=query, retmax=str(max_results))
        try:
            record = Entrez.read(handle)
        finally:
            handle.close()  # Close even if parsing the response fails
        return record["IdList"]
    except Exception as e:
        log_error(f"PubMed search error: {e}")
        return [f"Error during PubMed search: {e}"]
def fetch_abstract(article_id: str) -> str:
    """Fetch the plain-text abstract record for one PubMed article.

    Args:
        article_id: PubMed ID of the article.

    Returns:
        The text read from NCBI efetch (``rettype="abstract"``,
        ``retmode="text"``). On failure, the error is logged and a string
        starting with "Error fetching abstract for <id>:" is returned instead.
    """
    try:
        Entrez.email = ENTREZ_EMAIL
        handle = Entrez.efetch(db="pubmed", id=article_id, rettype="abstract", retmode="text")
        try:
            return handle.read()
        finally:
            handle.close()  # Release the handle even if reading fails
    except Exception as e:
        log_error(f"Error fetching abstract for {article_id}: {e}")
        return f"Error fetching abstract for {article_id}: {e}"
def summarize_abstract(abstract: str) -> str:
    """Summarize text with the global transformers summarization pipeline.

    Args:
        abstract: Text to summarize.

    Returns:
        The generated summary. Instead of a summary, a message is returned
        when the pipeline is not initialized, the input is blank, or
        summarization fails (failures are also logged).
    """
    if summarizer is None:
        log_error("Summarizer not initialized.")
        return "Summarizer not initialized. Check initialization status."
    if not abstract or not abstract.strip():
        return "No text to summarize."
    try:
        # truncation=True: an efetch record (title, authors, affiliations,
        # abstract) can exceed the model's maximum input length (1024 tokens
        # for facebook/bart-large-cnn). Without truncation the pipeline raises
        # instead of summarizing.
        result = summarizer(abstract, max_length=130, min_length=30, do_sample=False, truncation=True)
        return result[0]["summary_text"]
    except Exception as e:
        log_error(f"Summarization error: {e}")
        return f"Error during summarization: {e}"
def extract_entities(text: str) -> list:
    """Extract named entities from text with the global spaCy pipeline.

    Args:
        text: Text to analyze.

    Returns:
        A list of ``(entity_text, entity_label)`` tuples. If the spaCy model
        is not initialized or extraction fails, a one-element list holding an
        error message is returned instead (failures are also logged).
    """
    if nlp is None:
        log_error("SpaCy model not initialized.")
        # Wrapped in a list to honor the declared return type, matching the
        # extraction-failure path below.
        return ["SpaCy model not initialized. Check initialization status."]
    try:
        return [(ent.text, ent.label_) for ent in nlp(text).ents]
    except Exception as e:
        log_error(f"Entity extraction error: {e}")
        return [f"Error during entity extraction: {e}"]
# ---------------------------- Agent Function ----------------------------
def medai_agent(query: str) -> str:
    """Search PubMed for a query, then summarize and tag each matching abstract.

    Args:
        query: Free-text medical query, passed to search_pubmed.

    Returns:
        Markdown with one section per article (ID, summary, entities). If the
        search found nothing or failed, a
        "No articles found or error occurred: ..." message is returned instead.
    """
    article_ids = search_pubmed(query)
    if not isinstance(article_ids, list) or not article_ids:
        return f"No articles found or error occurred: {article_ids}"
    # search_pubmed signals failure with a one-element list holding its error
    # message; don't mistake that message for an article ID.
    if len(article_ids) == 1 and str(article_ids[0]).startswith("Error during PubMed search"):
        return f"No articles found or error occurred: {article_ids}"

    results = []
    for article_id in article_ids:
        abstract = fetch_abstract(article_id)
        # Match fetch_abstract's error prefix exactly. A plain `"Error" in abstract`
        # test misfires on real abstracts that merely mention the word "Error".
        if abstract.startswith("Error fetching abstract for "):
            results.append(f"Error processing article {article_id}: {abstract}\n\n---\n")
            continue
        summary = summarize_abstract(abstract)
        entities = extract_entities(abstract)
        results.append(f"**Article ID:** {article_id}\n\n**Summary:** {summary}\n\n**Entities:** {entities}\n\n---\n")
    return "\n".join(results)
# ---------------------------- Initialization and Setup ----------------------------
def setup():
    """Initialize the summarization pipeline and the spaCy model.

    Populates the globals ``summarizer`` and ``nlp`` and records progress in the
    global ``initialization_status``; on full success that log is replaced by
    a single success message.

    Returns:
        The final ``initialization_status``: the success message, the
        accumulated progress log if spaCy failed to load, or an
        "Initialization error: ..." message if an exception was raised.
    """
    global summarizer, nlp, initialization_status
    initialization_status = "Initializing..."
    try:
        print("Initializing summarization pipeline...")
        initialization_status += "\nInitializing summarization pipeline..."
        summarizer = pipeline("summarization", model=SUMMARIZATION_MODEL, token=HUGGINGFACE_API_TOKEN)
        print("Summarization pipeline initialized.")
        initialization_status += "\nSummarization pipeline initialized."
        print("Loading SpaCy model...")
        initialization_status += "\nLoading SpaCy model..."
        # Use the configured model name rather than load_spacy_model's own default.
        nlp = load_spacy_model(SPACY_MODEL)
        if nlp is None:
            initialization_status += "\nSpaCy model failed to load. Check the error log."
            return initialization_status
        print("SpaCy model loaded.")
        initialization_status = "MedAI Agent initialized successfully!"
        return initialization_status  # Return the status message
    except Exception as e:
        initialization_status = f"Initialization error: {e}"
        log_error(initialization_status)
        return initialization_status  # Return the error message
# ---------------------------- Gradio Interface ----------------------------
def launch_gradio():
    """Initialize the models, build the Gradio UI, and start the server.

    The UI shows the initialization status, a query box, and a Markdown
    results area; submitting runs ``medai_agent`` on the query.
    """
    # Run setup() before building the UI so the status textbox is constructed
    # with the real result. Assigning a component's ``.value`` after the Blocks
    # context has closed only shows up if Gradio regenerates its config at
    # launch; a constructor value does not depend on that.
    status = setup()
    with gr.Blocks() as iface:
        gr.Markdown("# MedAI: Medical Literature Review and Summarization")
        gr.Textbox(value=status, label="Initialization status", interactive=False)
        query_input = gr.Textbox(
            lines=3,
            label="Query",
            placeholder="Enter your medical query (e.g., 'new treatments for diabetes')...",
        )
        submit_button = gr.Button("Submit")
        output_results = gr.Markdown()
        submit_button.click(medai_agent, inputs=query_input, outputs=output_results)
    iface.launch()
# ---------------------------- Main Execution ----------------------------
if __name__ == "__main__":
    launch_gradio()