mgbam committed on
Commit 19c2c87 · verified · Parent: cd0b15a

Update app.py: read uploaded Excel files from raw bytes via io.BytesIO in parse_excel_file (avoiding the Gradio NamedString error), add the io import, and tidy comments and docstrings.

Files changed (1)
  1. app.py +658 -653
app.py CHANGED
@@ -1,653 +1,658 @@
- import os
- import json
- import csv
- import asyncio
- import xml.etree.ElementTree as ET
- from typing import Any, Dict, Optional, Tuple, Union, List
-
- import httpx
- import gradio as gr
- import torch
- from dotenv import load_dotenv
- from loguru import logger
- from huggingface_hub import login
- from openai import OpenAI
- from reportlab.pdfgen import canvas
- from transformers import (
-     AutoTokenizer,
-     AutoModelForSequenceClassification,
-     MarianMTModel,
-     MarianTokenizer,
- )
- import pandas as pd
- import altair as alt
- import spacy
- import spacy.cli
- import PyPDF2 # For PDF reading
-
- # Ensure spaCy model is downloaded
- try:
-     nlp = spacy.load("en_core_web_sm")
- except OSError:
-     logger.info("Downloading SpaCy 'en_core_web_sm' model...")
-     spacy.cli.download("en_core_web_sm")
-     nlp = spacy.load("en_core_web_sm")
-
- # Logging
- logger.add("error_logs.log", rotation="1 MB", level="ERROR")
-
- # Load environment variables
- load_dotenv()
- HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
- ENTREZ_EMAIL = os.getenv("ENTREZ_EMAIL")
-
- # Basic checks
- if not HUGGINGFACE_TOKEN or not OPENAI_API_KEY:
-     logger.error("Missing Hugging Face or OpenAI credentials.")
-     raise ValueError("Missing credentials for Hugging Face or OpenAI.")
-
- # API endpoints
- PUBMED_SEARCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
- PUBMED_FETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
- EUROPE_PMC_BASE_URL = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
-
- # Hugging Face login
- login(HUGGINGFACE_TOKEN)
-
- # Initialize OpenAI
- client = OpenAI(api_key=OPENAI_API_KEY)
-
- # Device setting
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- logger.info(f"Using device: {device}")
-
- # Model settings
- MODEL_NAME = "mgbam/bert-base-finetuned-mgbam"
- try:
-     model = AutoModelForSequenceClassification.from_pretrained(
-         MODEL_NAME, use_auth_token=HUGGINGFACE_TOKEN
-     ).to(device)
-     tokenizer = AutoTokenizer.from_pretrained(
-         MODEL_NAME, use_auth_token=HUGGINGFACE_TOKEN
-     )
- except Exception as e:
-     logger.error(f"Model load error: {e}")
-     raise
-
- # Translation model settings
- try:
-     translation_model_name = "Helsinki-NLP/opus-mt-en-fr"
-     translation_model = MarianMTModel.from_pretrained(
-         translation_model_name, use_auth_token=HUGGINGFACE_TOKEN
-     ).to(device)
-     translation_tokenizer = MarianTokenizer.from_pretrained(
-         translation_model_name, use_auth_token=HUGGINGFACE_TOKEN
-     )
- except Exception as e:
-     logger.error(f"Translation model load error: {e}")
-     raise
-
- LANGUAGE_MAP: Dict[str, Tuple[str, str]] = {
-     "English to French": ("en", "fr"),
-     "French to English": ("fr", "en"),
- }
-
- ### Utility Functions ###
- def safe_json_parse(text: str) -> Union[Dict, None]:
-     """Safely parse JSON string into a Python dictionary."""
-     try:
-         return json.loads(text)
-     except json.JSONDecodeError as e:
-         logger.error(f"JSON parsing error: {e}")
-         return None
-
- def parse_pubmed_xml(xml_data: str) -> List[Dict[str, Any]]:
-     """Parses PubMed XML data and returns a list of structured articles."""
-     root = ET.fromstring(xml_data)
-     articles = []
-     for article in root.findall(".//PubmedArticle"):
-         pmid = article.findtext(".//PMID")
-         title = article.findtext(".//ArticleTitle")
-         abstract = article.findtext(".//AbstractText")
-         journal = article.findtext(".//Journal/Title")
-         pub_date_elem = article.find(".//JournalIssue/PubDate")
-         pub_date = None
-         if pub_date_elem is not None:
-             year = pub_date_elem.findtext("Year")
-             month = pub_date_elem.findtext("Month")
-             day = pub_date_elem.findtext("Day")
-             if year and month and day:
-                 pub_date = f"{year}-{month}-{day}"
-             else:
-                 pub_date = year
-         articles.append({
-             "PMID": pmid,
-             "Title": title,
-             "Abstract": abstract,
-             "Journal": journal,
-             "PublicationDate": pub_date,
-         })
-     return articles
-
- ### Async Functions for Europe PMC ###
- async def fetch_articles_by_nct_id(nct_id: str) -> Dict[str, Any]:
-     params = {"query": nct_id, "format": "json"}
-     async with httpx.AsyncClient() as client_http:
-         try:
-             response = await client_http.get(EUROPE_PMC_BASE_URL, params=params)
-             response.raise_for_status()
-             return response.json()
-         except Exception as e:
-             logger.error(f"Error fetching articles for {nct_id}: {e}")
-             return {"error": str(e)}
-
- async def fetch_articles_by_query(query_params: str) -> Dict[str, Any]:
-     parsed_params = safe_json_parse(query_params)
-     if not parsed_params or not isinstance(parsed_params, dict):
-         return {"error": "Invalid JSON."}
-     query_string = " AND ".join(f"{k}:{v}" for k, v in parsed_params.items())
-     params = {"query": query_string, "format": "json"}
-     async with httpx.AsyncClient() as client_http:
-         try:
-             response = await client_http.get(EUROPE_PMC_BASE_URL, params=params)
-             response.raise_for_status()
-             return response.json()
-         except Exception as e:
-             logger.error(f"Error fetching articles: {e}")
-             return {"error": str(e)}
-
- ### PubMed Integration ###
- async def fetch_pubmed_by_query(query_params: str) -> Dict[str, Any]:
-     parsed_params = safe_json_parse(query_params)
-     if not parsed_params or not isinstance(parsed_params, dict):
-         return {"error": "Invalid JSON for PubMed."}
-
-     search_params = {
-         "db": "pubmed",
-         "retmode": "json",
-         "email": ENTREZ_EMAIL,
-         "retmax": parsed_params.get("retmax", "10"),
-         "term": parsed_params.get("term", ""),
-     }
-
-     async with httpx.AsyncClient() as client_http:
-         try:
-             search_response = await client_http.get(PUBMED_SEARCH_URL, params=search_params)
-             search_response.raise_for_status()
-             search_data = search_response.json()
-             id_list = search_data.get("esearchresult", {}).get("idlist", [])
-             if not id_list:
-                 return {"result": ""}
-
-             fetch_params = {
-                 "db": "pubmed",
-                 "id": ",".join(id_list),
-                 "retmode": "xml",
-                 "email": ENTREZ_EMAIL,
-             }
-             fetch_response = await client_http.get(PUBMED_FETCH_URL, params=fetch_params)
-             fetch_response.raise_for_status()
-             return {"result": fetch_response.text}
-         except Exception as e:
-             logger.error(f"Error fetching PubMed articles: {e}")
-             return {"error": str(e)}
-
- ### Crossref Integration ###
- async def fetch_crossref_by_query(query_params: str) -> Dict[str, Any]:
-     parsed_params = safe_json_parse(query_params)
-     if not parsed_params or not isinstance(parsed_params, dict):
-         return {"error": "Invalid JSON for Crossref."}
-     CROSSREF_API_URL = "https://api.crossref.org/works"
-     async with httpx.AsyncClient() as client_http:
-         try:
-             response = await client_http.get(CROSSREF_API_URL, params=parsed_params)
-             response.raise_for_status()
-             return response.json()
-         except Exception as e:
-             logger.error(f"Error fetching Crossref data: {e}")
-             return {"error": str(e)}
-
- ### Core Functions ###
- def summarize_text(text: str) -> str:
-     """Summarize text using OpenAI."""
-     if not text.strip():
-         return "No text provided for summarization."
-     try:
-         response = client.chat.completions.create(
-             model="gpt-3.5-turbo",
-             messages=[{"role": "user", "content": f"Summarize the following clinical data:\n{text}"}],
-             max_tokens=200,
-             temperature=0.7,
-         )
-         return response.choices[0].message.content.strip()
-     except Exception as e:
-         logger.error(f"Summarization Error: {e}")
-         return "Summarization failed."
-
- def predict_outcome(text: str) -> Union[Dict[str, float], str]:
-     """Predict outcomes (classification) using a fine-tuned model."""
-     if not text.strip():
-         return "No text provided for prediction."
-     try:
-         inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-         inputs = {k: v.to(device) for k, v in inputs.items()}
-         with torch.no_grad():
-             outputs = model(**inputs)
-         probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
-         return {f"Label {i+1}": float(prob.item()) for i, prob in enumerate(probabilities)}
-     except Exception as e:
-         logger.error(f"Prediction Error: {e}")
-         return "Prediction failed."
-
- def generate_report(text: str, filename: str = "clinical_report.pdf") -> Optional[str]:
-     """Generate a PDF report from the given text."""
-     try:
-         if not text.strip():
-             logger.warning("No text provided for the report.")
-         c = canvas.Canvas(filename)
-         c.drawString(100, 750, "Clinical Research Report")
-         lines = text.split("\n")
-         y = 730
-         for line in lines:
-             if y < 50:
-                 c.showPage()
-                 y = 750
-             c.drawString(100, y, line)
-             y -= 15
-         c.save()
-         logger.info(f"Report generated: {filename}")
-         return filename
-     except Exception as e:
-         logger.error(f"Report Generation Error: {e}")
-         return None
-
- def visualize_predictions(predictions: Dict[str, float]) -> Optional[alt.Chart]:
-     """Visualize model prediction probabilities using Altair."""
-     try:
-         data = pd.DataFrame(list(predictions.items()), columns=["Label", "Probability"])
-         chart = (
-             alt.Chart(data)
-             .mark_bar()
-             .encode(
-                 x=alt.X("Label:N", sort=None),
-                 y="Probability:Q",
-                 tooltip=["Label", "Probability"],
-             )
-             .properties(title="Prediction Probabilities", width=500, height=300)
-         )
-         return chart
-     except Exception as e:
-         logger.error(f"Visualization Error: {e}")
-         return None
-
- def translate_text(text: str, translation_option: str) -> str:
-     """Translate text between English and French."""
-     if not text.strip():
-         return "No text provided for translation."
-     try:
-         if translation_option not in LANGUAGE_MAP:
-             return "Unsupported translation option."
-         inputs = translation_tokenizer(text, return_tensors="pt", padding=True).to(device)
-         translated_tokens = translation_model.generate(**inputs)
-         return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
-     except Exception as e:
-         logger.error(f"Translation Error: {e}")
-         return "Translation failed."
-
- def perform_named_entity_recognition(text: str) -> str:
-     """Perform Named Entity Recognition (NER) using spaCy."""
-     if not text.strip():
-         return "No text provided for NER."
-     try:
-         doc = nlp(text)
-         entities = [(ent.text, ent.label_) for ent in doc.ents]
-         if not entities:
-             return "No named entities found."
-         return "\n".join(f"{ent_text} -> {ent_label}" for ent_text, ent_label in entities)
-     except Exception as e:
-         logger.error(f"NER Error: {e}")
-         return "Named Entity Recognition failed."
-
- ### Enhanced EDA ###
- def perform_enhanced_eda(df: pd.DataFrame) -> Tuple[str, Optional[alt.Chart], Optional[alt.Chart]]:
-     """
-     Perform a more advanced EDA given a DataFrame:
-     - Show dataset info (columns, shape, numeric summary).
-     - Generate a correlation heatmap (for numeric columns).
-     - Generate distribution plots (histograms) for numeric columns.
-     Returns (text_summary, correlation_chart, distribution_chart).
-     """
-     try:
-         # Basic info
-         columns_info = f"Columns: {list(df.columns)}"
-         shape_info = f"Shape: {df.shape[0]} rows x {df.shape[1]} columns"
-
-         # Use describe with "include='all'" to show all columns summary
-         with pd.option_context("display.max_colwidth", 200, "display.max_rows", None):
-             describe_info = df.describe(include="all").to_string()
-
-         summary_text = (
-             f"--- Enhanced EDA Summary ---\n"
-             f"{columns_info}\n{shape_info}\n\n"
-             f"Summary Statistics:\n{describe_info}\n"
-         )
-
-         # Correlation heatmap
-         numeric_cols = df.select_dtypes(include="number")
-         corr_chart = None
-         if numeric_cols.shape[1] >= 2:
-             corr = numeric_cols.corr()
-             corr_melted = corr.reset_index().melt(id_vars="index")
-             corr_melted.columns = ["Feature1", "Feature2", "Correlation"]
-             corr_chart = (
-                 alt.Chart(corr_melted)
-                 .mark_rect()
-                 .encode(
-                     x="Feature1:O",
-                     y="Feature2:O",
-                     color="Correlation:Q",
-                     tooltip=["Feature1", "Feature2", "Correlation"]
-                 )
-                 .properties(width=400, height=400, title="Correlation Heatmap")
-             )
-
-         # Distribution plots (histograms) for numeric columns
-         distribution_chart = None
-         if numeric_cols.shape[1] >= 1:
-             df_long = numeric_cols.melt(var_name='Column', value_name='Value')
-             distribution_chart = (
-                 alt.Chart(df_long)
-                 .mark_bar()
-                 .encode(
-                     alt.X("Value:Q", bin=alt.Bin(maxbins=30)),
-                     alt.Y('count()'),
-                     alt.Facet('Column:N', columns=2),
-                     tooltip=["Value"]
-                 )
-                 .properties(
-                     title='Distribution of Numeric Columns',
-                     width=300,
-                     height=200
-                 )
-                 .interactive()
-             )
-
-         return summary_text, corr_chart, distribution_chart
-
-     except Exception as e:
-         logger.error(f"Enhanced EDA Error: {e}")
-         return f"Enhanced EDA failed: {e}", None, None
-
- ### File Handling ###
- def read_uploaded_file(uploaded_file: Optional[gr.File]) -> str:
-     """
-     Reads the content of an uploaded file (txt, csv, xls, xlsx, pdf).
-     Returns the extracted text or CSV-like content.
-     """
-     if uploaded_file is None:
-         return ""
-
-     file_name = uploaded_file.name
-     file_ext = os.path.splitext(file_name)[1].lower()
-
-     try:
-         # For text
-         if file_ext == ".txt":
-             return uploaded_file.read().decode("utf-8")
-
-         # For CSV
-         elif file_ext == ".csv":
-             return uploaded_file.read().decode("utf-8")
-
-         # For Excel
-         elif file_ext in [".xls", ".xlsx"]:
-             # We'll just return empty here and parse it later into a DataFrame
-             # because we can read the binary directly into pd.read_excel().
-             # Or store as bytes for later use in EDA.
-             return "EXCEL_FILE_PLACEHOLDER" # We'll handle it differently in EDA step
-
-         # For PDF
-         elif file_ext == ".pdf":
-             pdf_reader = PyPDF2.PdfReader(uploaded_file)
-             text_content = []
-             for page in pdf_reader.pages:
-                 text_content.append(page.extract_text())
-             return "\n".join(text_content)
-
-         else:
-             return f"Unsupported file format: {file_ext}"
-     except Exception as e:
-         logger.error(f"File read error: {e}")
-         return f"Error reading file: {e}"
-
- def parse_excel_file(uploaded_file) -> pd.DataFrame:
-     """
-     Parse an Excel file into a pandas DataFrame.
-     We assume the user wants the first sheet or we can guess.
-     """
-     try:
-         # For Excel, we can do:
-         df = pd.read_excel(uploaded_file, engine="openpyxl")
-         return df
-     except Exception as e:
-         logger.error(f"Excel parsing error: {e}")
-         raise
-
- def parse_csv_content(csv_content: str) -> pd.DataFrame:
-     """
-     Attempt to parse CSV content with both utf-8 and utf-8-sig to handle BOM issues.
-     """
-     from io import StringIO
-     errors = []
-     for encoding_try in ["utf-8", "utf-8-sig"]:
-         try:
-             df = pd.read_csv(StringIO(csv_content), encoding=encoding_try)
-             return df
-         except Exception as e:
-             errors.append(f"Encoding {encoding_try} failed: {e}")
-     error_msg = "Could not parse CSV content.\n" + "\n".join(errors)
-     logger.error(error_msg)
-     raise ValueError(error_msg)
-
- ### Gradio Interface ###
- with gr.Blocks() as demo:
-     gr.Markdown("# ✨ Advanced Clinical Research Assistant with Enhanced EDA ✨")
-     gr.Markdown("""
-     Welcome to the **Enhanced** AI-Powered Clinical Assistant!
-     - **Summarize** large blocks of clinical text.
-     - **Predict** outcomes with a fine-tuned model.
-     - **Translate** text between English & French.
-     - **Perform Named Entity Recognition** with spaCy.
-     - **Fetch** from PubMed, Crossref, Europe PMC.
-     - **Generate** professional PDF reports.
-     - **Perform Enhanced EDA** on CSV/Excel data with correlation heatmaps & distribution plots.
-     """)
-
-     # Inputs
-     with gr.Row():
-         text_input = gr.Textbox(label="Input Text", lines=5, placeholder="Enter clinical text or query...")
-         file_input = gr.File(
-             label="Upload File (txt/csv/xls/xlsx/pdf)",
-             file_types=[".txt", ".csv", ".xls", ".xlsx", ".pdf"]
-         )
-
-     action = gr.Radio(
-         [
-             "Summarize",
-             "Predict Outcome",
-             "Generate Report",
-             "Translate",
-             "Perform Named Entity Recognition",
-             "Perform Enhanced EDA",
-             "Fetch Clinical Studies",
-             "Fetch PubMed Articles (Legacy)",
-             "Fetch PubMed by Query",
-             "Fetch Crossref by Query",
-         ],
-         label="Select an Action",
-     )
-     translation_option = gr.Dropdown(
-         choices=list(LANGUAGE_MAP.keys()),
-         label="Translation Option",
-         value="English to French"
-     )
-     query_params_input = gr.Textbox(
-         label="Query Parameters (JSON Format)",
-         placeholder='{"term": "cancer", "retmax": "5"}'
-     )
-     nct_id_input = gr.Textbox(label="NCT ID for Article Search")
-     report_filename_input = gr.Textbox(
-         label="Report Filename",
-         placeholder="clinical_report.pdf",
-         value="clinical_report.pdf"
-     )
-     export_format = gr.Dropdown(["None", "CSV", "JSON"], label="Export Format")
-
-     # Outputs
-     output_text = gr.Textbox(label="Output", lines=10)
-
-     with gr.Row():
-         output_chart = gr.Plot(label="Visualization 1")
-         output_chart2 = gr.Plot(label="Visualization 2")
-
-     output_file = gr.File(label="Generated File")
-
-     submit_button = gr.Button("Submit")
-
-     # Async function for handling actions
-     async def handle_action(
-         action: str,
-         text: str,
-         file_up: gr.File,
-         translation_opt: str,
-         query_params: str,
-         nct_id: str,
-         report_filename: str,
-         export_format: str
-     ) -> Tuple[Optional[str], Optional[Any], Optional[Any], Optional[str]]:
-
-         # Read the uploaded file
-         file_content = read_uploaded_file(file_up)
-         combined_text = (text + "\n" + file_content).strip() if file_content else text
-
-         # Branch by action
-         if action == "Summarize":
-             return summarize_text(combined_text), None, None, None
-
-         elif action == "Predict Outcome":
-             predictions = predict_outcome(combined_text)
-             if isinstance(predictions, dict):
-                 chart = visualize_predictions(predictions)
-                 return json.dumps(predictions, indent=2), chart, None, None
-             return predictions, None, None, None
-
-         elif action == "Generate Report":
-             file_path = generate_report(combined_text, filename=report_filename)
-             msg = f"Report generated: {file_path}" if file_path else "Report generation failed."
-             return msg, None, None, file_path
-
-         elif action == "Translate":
-             return translate_text(combined_text, translation_opt), None, None, None
-
-         elif action == "Perform Named Entity Recognition":
-             ner_result = perform_named_entity_recognition(combined_text)
-             return ner_result, None, None, None
-
-         elif action == "Perform Enhanced EDA":
-             # We expect the user to either upload a CSV or Excel, or paste CSV content.
-             if file_up is None and not combined_text:
-                 return "No data provided for EDA.", None, None, None
-
-             # If Excel was uploaded
-             if file_up and file_up.name.lower().endswith((".xls", ".xlsx")):
-                 try:
-                     df_excel = parse_excel_file(file_up)
-                     eda_summary, corr_chart, dist_chart = perform_enhanced_eda(df_excel)
-                     return eda_summary, corr_chart, dist_chart, None
-                 except Exception as e:
-                     return f"Excel EDA failed: {e}", None, None, None
-
-             # If CSV was uploaded
-             if file_up and file_up.name.lower().endswith(".csv"):
-                 try:
-                     df_csv = parse_csv_content(file_content)
-                     eda_summary, corr_chart, dist_chart = perform_enhanced_eda(df_csv)
-                     return eda_summary, corr_chart, dist_chart, None
-                 except Exception as e:
-                     return f"CSV EDA failed: {e}", None, None, None
-
-             # If user just pasted CSV content (no file)
-             if not file_up and "," in combined_text:
-                 try:
-                     df_csv = parse_csv_content(combined_text)
-                     eda_summary, corr_chart, dist_chart = perform_enhanced_eda(df_csv)
-                     return eda_summary, corr_chart, dist_chart, None
-                 except Exception as e:
-                     return f"CSV EDA failed: {e}", None, None, None
-
-             # Otherwise, not supported
-             return "No valid CSV/Excel data found for EDA.", None, None, None
-
-         elif action == "Fetch Clinical Studies":
-             if nct_id:
-                 result = await fetch_articles_by_nct_id(nct_id)
-             elif query_params:
-                 result = await fetch_articles_by_query(query_params)
-             else:
-                 return "Provide either an NCT ID or valid query parameters.", None, None, None
-
-             articles = result.get("resultList", {}).get("result", [])
-             if not articles:
-                 return "No articles found.", None, None, None
-
-             formatted_results = "\n\n".join(
-                 f"Title: {a.get('title')}\nJournal: {a.get('journalTitle')} ({a.get('pubYear')})"
-                 for a in articles
-             )
-             return formatted_results, None, None, None
-
-         elif action in ["Fetch PubMed Articles (Legacy)", "Fetch PubMed by Query"]:
-             pubmed_result = await fetch_pubmed_by_query(query_params)
-             xml_data = pubmed_result.get("result")
-             if xml_data:
-                 articles = parse_pubmed_xml(xml_data)
-                 if not articles:
-                     return "No articles found.", None, None, None
-                 formatted = "\n\n".join(
-                     f"{a['Title']} - {a['Journal']} ({a['PublicationDate']})"
-                     for a in articles if a['Title']
-                 )
-                 return formatted if formatted else "No articles found.", None, None, None
-             return "No articles found or error fetching data.", None, None, None
-
-         elif action == "Fetch Crossref by Query":
-             crossref_result = await fetch_crossref_by_query(query_params)
-             items = crossref_result.get("message", {}).get("items", [])
-             if not items:
-                 return "No results found.", None, None, None
-             formatted = "\n\n".join(
-                 f"Title: {item.get('title', ['No title'])[0]}, DOI: {item.get('DOI')}"
-                 for item in items
-             )
-             return formatted, None, None, None
-
-         return "Invalid action.", None, None, None
-
-     submit_button.click(
-         handle_action,
-         inputs=[
-             action,
-             text_input,
-             file_input,
-             translation_option,
-             query_params_input,
-             nct_id_input,
-             report_filename_input,
-             export_format,
-         ],
-         outputs=[output_text, output_chart, output_chart2, output_file],
-     )
-
- # Launch the Gradio app
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)

+ import os
+ import json
+ import csv
+ import asyncio
+ import xml.etree.ElementTree as ET
+ from typing import Any, Dict, Optional, Tuple, Union, List
+
+ import httpx
+ import gradio as gr
+ import torch
+ from dotenv import load_dotenv
+ from loguru import logger
+ from huggingface_hub import login
+ from openai import OpenAI
+ from reportlab.pdfgen import canvas
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForSequenceClassification,
+     MarianMTModel,
+     MarianTokenizer,
+ )
+ import pandas as pd
+ import altair as alt
+ import spacy
+ import spacy.cli
+ import PyPDF2
+ import io # For handling in-memory files for Excel
+
+ # Ensure spaCy model is downloaded
+ try:
+     nlp = spacy.load("en_core_web_sm")
+ except OSError:
+     logger.info("Downloading SpaCy 'en_core_web_sm' model...")
+     spacy.cli.download("en_core_web_sm")
+     nlp = spacy.load("en_core_web_sm")
+
+ # Logging
+ logger.add("error_logs.log", rotation="1 MB", level="ERROR")
+
+ # Load environment variables
+ load_dotenv()
+ HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ ENTREZ_EMAIL = os.getenv("ENTREZ_EMAIL")
+
+ # Basic checks
+ if not HUGGINGFACE_TOKEN or not OPENAI_API_KEY:
+     logger.error("Missing Hugging Face or OpenAI credentials.")
+     raise ValueError("Missing credentials for Hugging Face or OpenAI.")
+
+ # API endpoints
+ PUBMED_SEARCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
+ PUBMED_FETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
+ EUROPE_PMC_BASE_URL = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
+
+ # Hugging Face login
+ login(HUGGINGFACE_TOKEN)
+
+ # Initialize OpenAI
+ client = OpenAI(api_key=OPENAI_API_KEY)
+
+ # Device setting
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ logger.info(f"Using device: {device}")
+
+ # Model settings
+ MODEL_NAME = "mgbam/bert-base-finetuned-mgbam"
+ try:
+     model = AutoModelForSequenceClassification.from_pretrained(
+         MODEL_NAME, use_auth_token=HUGGINGFACE_TOKEN
+     ).to(device)
+     tokenizer = AutoTokenizer.from_pretrained(
+         MODEL_NAME, use_auth_token=HUGGINGFACE_TOKEN
+     )
+ except Exception as e:
+     logger.error(f"Model load error: {e}")
+     raise
+
+ # Translation model settings
+ try:
+     translation_model_name = "Helsinki-NLP/opus-mt-en-fr"
+     translation_model = MarianMTModel.from_pretrained(
+         translation_model_name, use_auth_token=HUGGINGFACE_TOKEN
+     ).to(device)
+     translation_tokenizer = MarianTokenizer.from_pretrained(
+         translation_model_name, use_auth_token=HUGGINGFACE_TOKEN
+     )
+ except Exception as e:
+     logger.error(f"Translation model load error: {e}")
+     raise
+
+ LANGUAGE_MAP: Dict[str, Tuple[str, str]] = {
+     "English to French": ("en", "fr"),
+     "French to English": ("fr", "en"),
+ }
+
+ ### Utility Functions ###
+ def safe_json_parse(text: str) -> Union[Dict, None]:
+     """Safely parse JSON string into a Python dictionary."""
+     try:
+         return json.loads(text)
+     except json.JSONDecodeError as e:
+         logger.error(f"JSON parsing error: {e}")
+         return None
+
+ def parse_pubmed_xml(xml_data: str) -> List[Dict[str, Any]]:
+     """Parses PubMed XML data and returns a list of structured articles."""
+     root = ET.fromstring(xml_data)
+     articles = []
+     for article in root.findall(".//PubmedArticle"):
+         pmid = article.findtext(".//PMID")
+         title = article.findtext(".//ArticleTitle")
+         abstract = article.findtext(".//AbstractText")
+         journal = article.findtext(".//Journal/Title")
+         pub_date_elem = article.find(".//JournalIssue/PubDate")
+         pub_date = None
+         if pub_date_elem is not None:
+             year = pub_date_elem.findtext("Year")
+             month = pub_date_elem.findtext("Month")
+             day = pub_date_elem.findtext("Day")
+             if year and month and day:
+                 pub_date = f"{year}-{month}-{day}"
+             else:
+                 pub_date = year
+         articles.append({
+             "PMID": pmid,
+             "Title": title,
+             "Abstract": abstract,
+             "Journal": journal,
+             "PublicationDate": pub_date,
+         })
+     return articles
+
+ ### Asynchronous Functions for Europe PMC ###
+ async def fetch_articles_by_nct_id(nct_id: str) -> Dict[str, Any]:
+     params = {"query": nct_id, "format": "json"}
+     async with httpx.AsyncClient() as client_http:
+         try:
+             response = await client_http.get(EUROPE_PMC_BASE_URL, params=params)
+             response.raise_for_status()
+             return response.json()
+         except Exception as e:
+             logger.error(f"Error fetching articles for {nct_id}: {e}")
+             return {"error": str(e)}
+
+ async def fetch_articles_by_query(query_params: str) -> Dict[str, Any]:
+     parsed_params = safe_json_parse(query_params)
+     if not parsed_params or not isinstance(parsed_params, dict):
+         return {"error": "Invalid JSON."}
+     query_string = " AND ".join(f"{k}:{v}" for k, v in parsed_params.items())
+     params = {"query": query_string, "format": "json"}
+     async with httpx.AsyncClient() as client_http:
+         try:
+             response = await client_http.get(EUROPE_PMC_BASE_URL, params=params)
+             response.raise_for_status()
+             return response.json()
+         except Exception as e:
+             logger.error(f"Error fetching articles: {e}")
+             return {"error": str(e)}
+
+ ### PubMed Integration ###
+ async def fetch_pubmed_by_query(query_params: str) -> Dict[str, Any]:
+     parsed_params = safe_json_parse(query_params)
+     if not parsed_params or not isinstance(parsed_params, dict):
+         return {"error": "Invalid JSON for PubMed."}
+
+     search_params = {
+         "db": "pubmed",
+         "retmode": "json",
+         "email": ENTREZ_EMAIL,
+         "retmax": parsed_params.get("retmax", "10"),
+         "term": parsed_params.get("term", ""),
+     }
+
+     async with httpx.AsyncClient() as client_http:
+         try:
+             search_response = await client_http.get(PUBMED_SEARCH_URL, params=search_params)
+             search_response.raise_for_status()
+             search_data = search_response.json()
+             id_list = search_data.get("esearchresult", {}).get("idlist", [])
+             if not id_list:
+                 return {"result": ""}
+
+             fetch_params = {
+                 "db": "pubmed",
+                 "id": ",".join(id_list),
+                 "retmode": "xml",
+                 "email": ENTREZ_EMAIL,
+             }
+             fetch_response = await client_http.get(PUBMED_FETCH_URL, params=fetch_params)
+             fetch_response.raise_for_status()
+             return {"result": fetch_response.text}
+         except Exception as e:
+             logger.error(f"Error fetching PubMed articles: {e}")
+             return {"error": str(e)}
+
+ ### Crossref Integration ###
+ async def fetch_crossref_by_query(query_params: str) -> Dict[str, Any]:
+     parsed_params = safe_json_parse(query_params)
+     if not parsed_params or not isinstance(parsed_params, dict):
+         return {"error": "Invalid JSON for Crossref."}
+     CROSSREF_API_URL = "https://api.crossref.org/works"
+     async with httpx.AsyncClient() as client_http:
+         try:
+             response = await client_http.get(CROSSREF_API_URL, params=parsed_params)
+             response.raise_for_status()
+             return response.json()
+         except Exception as e:
+             logger.error(f"Error fetching Crossref data: {e}")
+             return {"error": str(e)}
+
+ ### Core Functions ###
+ def summarize_text(text: str) -> str:
+     """Summarize text using OpenAI."""
+     if not text.strip():
+         return "No text provided for summarization."
+     try:
+         response = client.chat.completions.create(
+             model="gpt-3.5-turbo",
+             messages=[{"role": "user", "content": f"Summarize the following clinical data:\n{text}"}],
+             max_tokens=200,
+             temperature=0.7,
+         )
+         return response.choices[0].message.content.strip()
+     except Exception as e:
+         logger.error(f"Summarization Error: {e}")
+         return "Summarization failed."
+
+ def predict_outcome(text: str) -> Union[Dict[str, float], str]:
+     """Predict outcomes (classification) using a fine-tuned model."""
+     if not text.strip():
+         return "No text provided for prediction."
+     try:
+         inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+         with torch.no_grad():
+             outputs = model(**inputs)
+         probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
+         return {f"Label {i+1}": float(prob.item()) for i, prob in enumerate(probabilities)}
+     except Exception as e:
+         logger.error(f"Prediction Error: {e}")
+         return "Prediction failed."
+
+ def generate_report(text: str, filename: str = "clinical_report.pdf") -> Optional[str]:
+     """Generate a PDF report from the given text."""
+     try:
+         if not text.strip():
+             logger.warning("No text provided for the report.")
+         c = canvas.Canvas(filename)
+         c.drawString(100, 750, "Clinical Research Report")
+         lines = text.split("\n")
+         y = 730
+         for line in lines:
+             if y < 50:
+                 c.showPage()
+                 y = 750
+             c.drawString(100, y, line)
+             y -= 15
+         c.save()
+         logger.info(f"Report generated: {filename}")
+         return filename
+     except Exception as e:
+         logger.error(f"Report Generation Error: {e}")
+         return None
+
+ def visualize_predictions(predictions: Dict[str, float]) -> Optional[alt.Chart]:
+     """Visualize model prediction probabilities using Altair."""
+     try:
+         data = pd.DataFrame(list(predictions.items()), columns=["Label", "Probability"])
+         chart = (
+             alt.Chart(data)
+             .mark_bar()
+             .encode(
+                 x=alt.X("Label:N", sort=None),
+                 y="Probability:Q",
+                 tooltip=["Label", "Probability"],
+             )
+             .properties(title="Prediction Probabilities", width=500, height=300)
+         )
+         return chart
+     except Exception as e:
+         logger.error(f"Visualization Error: {e}")
+         return None
+
+ def translate_text(text: str, translation_option: str) -> str:
+     """Translate text between English and French."""
+     if not text.strip():
+         return "No text provided for translation."
+     try:
+         if translation_option not in LANGUAGE_MAP:
+             return "Unsupported translation option."
+         inputs = translation_tokenizer(text, return_tensors="pt", padding=True).to(device)
+         translated_tokens = translation_model.generate(**inputs)
+         return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
+     except Exception as e:
+         logger.error(f"Translation Error: {e}")
+         return "Translation failed."
+
+ def perform_named_entity_recognition(text: str) -> str:
+     """Perform Named Entity Recognition (NER) using spaCy."""
+     if not text.strip():
+         return "No text provided for NER."
+     try:
+         doc = nlp(text)
+         entities = [(ent.text, ent.label_) for ent in doc.ents]
+         if not entities:
+             return "No named entities found."
+         return "\n".join(f"{ent_text} -> {ent_label}" for ent_text, ent_label in entities)
+     except Exception as e:
+         logger.error(f"NER Error: {e}")
+         return "Named Entity Recognition failed."
+
+ ### Enhanced EDA ###
+ def perform_enhanced_eda(df: pd.DataFrame) -> Tuple[str, Optional[alt.Chart], Optional[alt.Chart]]:
+     """
+     Perform a more advanced EDA given a DataFrame:
+     - Show dataset info (columns, shape, numeric summary).
+     - Generate a correlation heatmap (for numeric columns).
+     - Generate distribution plots (histograms) for numeric columns.
+     Returns (text_summary, correlation_chart, distribution_chart).
+     """
+     try:
+         # Basic info
+         columns_info = f"Columns: {list(df.columns)}"
+         shape_info = f"Shape: {df.shape[0]} rows x {df.shape[1]} columns"
+
+         # Describe with include="all" to show all columns
+         with pd.option_context("display.max_colwidth", 200, "display.max_rows", None):
+             describe_info = df.describe(include="all").to_string()
+
+         summary_text = (
+             f"--- Enhanced EDA Summary ---\n"
+             f"{columns_info}\n{shape_info}\n\n"
+             f"Summary Statistics:\n{describe_info}\n"
+         )
+
+         # Correlation heatmap (if at least 2 numeric columns)
+         numeric_cols = df.select_dtypes(include="number")
+         corr_chart = None
+         if numeric_cols.shape[1] >= 2:
+             corr = numeric_cols.corr()
+             corr_melted = corr.reset_index().melt(id_vars="index")
+             corr_melted.columns = ["Feature1", "Feature2", "Correlation"]
+             corr_chart = (
+                 alt.Chart(corr_melted)
+                 .mark_rect()
+                 .encode(
+                     x="Feature1:O",
+                     y="Feature2:O",
+                     color="Correlation:Q",
+                     tooltip=["Feature1", "Feature2", "Correlation"]
+                 )
+                 .properties(width=400, height=400, title="Correlation Heatmap")
+             )
+
+         # Distribution plots (histograms) for numeric columns
+         distribution_chart = None
+         if numeric_cols.shape[1] >= 1:
+             df_long = numeric_cols.melt(var_name='Column', value_name='Value')
+             distribution_chart = (
+                 alt.Chart(df_long)
+                 .mark_bar()
+                 .encode(
+                     alt.X("Value:Q", bin=alt.Bin(maxbins=30)),
+                     alt.Y('count()'),
+                     alt.Facet('Column:N', columns=2),
+                     tooltip=["Value"]
+                 )
+                 .properties(
+                     title='Distribution of Numeric Columns',
+                     width=300,
+                     height=200
+                 )
+                 .interactive()
+             )
+
+         return summary_text, corr_chart, distribution_chart
+
+     except Exception as e:
+         logger.error(f"Enhanced EDA Error: {e}")
+         return f"Enhanced EDA failed: {e}", None, None
+
+ ### File Handling ###
+
+ def read_uploaded_file(uploaded_file: Optional[gr.File]) -> str:
+     """
+     Reads the content of an uploaded file (txt, csv, xls, xlsx, pdf).
+     Returns the extracted text or CSV-like content for non-Excel files.
+     For Excel, we return a placeholder string; we'll handle it later.
+     """
+     if uploaded_file is None:
+         return ""
+
+     file_name = uploaded_file.name
+     file_ext = os.path.splitext(file_name)[1].lower()
+
+     try:
+         # TXT
+         if file_ext == ".txt":
+             return uploaded_file.read().decode("utf-8")
+
+         # CSV
+         elif file_ext == ".csv":
+             return uploaded_file.read().decode("utf-8")
+
+         # Excel
+         elif file_ext in [".xls", ".xlsx"]:
+             # We won't parse here; we'll parse in parse_excel_file(...)
+             # Return a placeholder so we know an Excel file was uploaded
+             return "EXCEL_FILE_PLACEHOLDER"
+
+         # PDF
+         elif file_ext == ".pdf":
+             pdf_reader = PyPDF2.PdfReader(uploaded_file)
+             text_content = []
+             for page in pdf_reader.pages:
+                 text_content.append(page.extract_text())
+             return "\n".join(text_content)
+
+         else:
+             return f"Unsupported file format: {file_ext}"
+     except Exception as e:
+         logger.error(f"File read error: {e}")
+         return f"Error reading file: {e}"
+
+ def parse_excel_file(uploaded_file: gr.File) -> pd.DataFrame:
+     """
+     Parse an Excel file into a pandas DataFrame using raw bytes.
+     This avoids the NamedString error from calling .read() on a Gradio file.
+     """
+     try:
+         excel_bytes = uploaded_file.data # raw file content in bytes
+         df = pd.read_excel(io.BytesIO(excel_bytes), engine="openpyxl")
+         return df
+     except Exception as e:
+         logger.error(f"Excel parsing error: {e}")
+         raise ValueError(f"Excel parsing error: {e}")
+
+ def parse_csv_content(csv_content: str) -> pd.DataFrame:
+     """
+     Attempt to parse CSV content with both utf-8 and utf-8-sig
+     to handle BOM issues or encoding complexities.
+     """
+     from io import StringIO
+     errors = []
+     for encoding_try in ["utf-8", "utf-8-sig"]:
+         try:
+             df = pd.read_csv(StringIO(csv_content), encoding=encoding_try)
+             return df
+         except Exception as e:
+             errors.append(f"Encoding {encoding_try} failed: {e}")
+     error_msg = "Could not parse CSV content.\n" + "\n".join(errors)
+     logger.error(error_msg)
+     raise ValueError(error_msg)
+
+ ### Gradio Interface ###
+ with gr.Blocks() as demo:
+     gr.Markdown("# Advanced Clinical Research Assistant with Enhanced EDA ✨")
+     gr.Markdown("""
+     Welcome to the **Enhanced** AI-Powered Clinical Assistant!
+     - **Summarize** large blocks of clinical text.
+     - **Predict** outcomes with a fine-tuned model.
+     - **Translate** text (English ↔ French).
+     - **Perform Named Entity Recognition** (spaCy).
+     - **Fetch** from PubMed, Crossref, Europe PMC.
+     - **Generate** professional PDF reports.
+     - **Perform Enhanced EDA** on CSV/Excel data (correlation heatmaps + distribution plots).
+     """)
+
+     # Inputs
+     with gr.Row():
+         text_input = gr.Textbox(label="Input Text", lines=5, placeholder="Enter clinical text or query...")
+         file_input = gr.File(
+             label="Upload File (txt/csv/xls/xlsx/pdf)",
+             file_types=[".txt", ".csv", ".xls", ".xlsx", ".pdf"]
+         )
+
+     action = gr.Radio(
+         [
+             "Summarize",
+             "Predict Outcome",
+             "Generate Report",
+             "Translate",
+             "Perform Named Entity Recognition",
+             "Perform Enhanced EDA",
+             "Fetch Clinical Studies",
+             "Fetch PubMed Articles (Legacy)",
+             "Fetch PubMed by Query",
+             "Fetch Crossref by Query",
+         ],
+         label="Select an Action",
+     )
+     translation_option = gr.Dropdown(
+         choices=list(LANGUAGE_MAP.keys()),
+         label="Translation Option",
+         value="English to French"
+     )
+     query_params_input = gr.Textbox(
+         label="Query Parameters (JSON Format)",
+         placeholder='{"term": "cancer", "retmax": "5"}'
+     )
+     nct_id_input = gr.Textbox(label="NCT ID for Article Search")
+     report_filename_input = gr.Textbox(
+         label="Report Filename",
+         placeholder="clinical_report.pdf",
+         value="clinical_report.pdf"
+     )
+     export_format = gr.Dropdown(["None", "CSV", "JSON"], label="Export Format")
+
+     # Outputs
+     output_text = gr.Textbox(label="Output", lines=10)
+
+     with gr.Row():
+         output_chart = gr.Plot(label="Visualization 1")
+         output_chart2 = gr.Plot(label="Visualization 2")
+
+     output_file = gr.File(label="Generated File")
+
+     submit_button = gr.Button("Submit")
+
+     # Async function for handling actions
+     async def handle_action(
+         action: str,
+         text: str,
+         file_up: gr.File,
+         translation_opt: str,
+         query_params: str,
+         nct_id: str,
+         report_filename: str,
+         export_format: str
+     ) -> Tuple[Optional[str], Optional[Any], Optional[Any], Optional[str]]:
+
+         # 1) Read the uploaded file (if any) -> returns a string or placeholder
+         file_content = read_uploaded_file(file_up)
+
+         # 2) Combine user text with file text if needed
+         combined_text = (text + "\n" + file_content).strip() if file_content else text
+
+         ### Branch by action ###
+         if action == "Summarize":
+             return summarize_text(combined_text), None, None, None
+
+         elif action == "Predict Outcome":
+             predictions = predict_outcome(combined_text)
+             if isinstance(predictions, dict):
+                 chart = visualize_predictions(predictions)
+                 return json.dumps(predictions, indent=2), chart, None, None
+             return predictions, None, None, None
+
+         elif action == "Generate Report":
+             file_path = generate_report(combined_text, filename=report_filename)
+             msg = f"Report generated: {file_path}" if file_path else "Report generation failed."
+             return msg, None, None, file_path
+
+         elif action == "Translate":
+             return translate_text(combined_text, translation_opt), None, None, None
+
+         elif action == "Perform Named Entity Recognition":
+             ner_result = perform_named_entity_recognition(combined_text)
+             return ner_result, None, None, None
+
+         elif action == "Perform Enhanced EDA":
+             # Ensure some data is provided
+             if not file_up and not combined_text:
+                 return "No data provided for EDA.", None, None, None
+
+             # If the user uploaded an Excel file
+             if file_up and file_up.name.lower().endswith((".xls", ".xlsx")):
+                 try:
+                     df_excel = parse_excel_file(file_up)
+                     eda_summary, corr_chart, dist_chart = perform_enhanced_eda(df_excel)
+                     return eda_summary, corr_chart, dist_chart, None
+                 except Exception as e:
+                     return f"Excel EDA failed: {e}", None, None, None
+
+             # If the user uploaded a CSV
+             if file_up and file_up.name.lower().endswith(".csv"):
+                 try:
+                     df_csv = parse_csv_content(file_content)
+                     eda_summary, corr_chart, dist_chart = perform_enhanced_eda(df_csv)
+                     return eda_summary, corr_chart, dist_chart, None
+                 except Exception as e:
+                     return f"CSV EDA failed: {e}", None, None, None
+
+             # If no file but possibly CSV text in the text box
+             if not file_up and "," in combined_text:
+                 try:
+                     df_csv = parse_csv_content(combined_text)
+                     eda_summary, corr_chart, dist_chart = perform_enhanced_eda(df_csv)
+                     return eda_summary, corr_chart, dist_chart, None
+                 except Exception as e:
+                     return f"CSV EDA failed: {e}", None, None, None
+
+             return "No valid CSV/Excel data found for EDA.", None, None, None
+
+         elif action == "Fetch Clinical Studies":
+             if nct_id:
+                 result = await fetch_articles_by_nct_id(nct_id)
+             elif query_params:
+                 result = await fetch_articles_by_query(query_params)
+             else:
+                 return "Provide either an NCT ID or valid query parameters.", None, None, None
+
+             articles = result.get("resultList", {}).get("result", [])
+             if not articles:
+                 return "No articles found.", None, None, None
+
+             formatted_results = "\n\n".join(
+                 f"Title: {a.get('title')}\nJournal: {a.get('journalTitle')} ({a.get('pubYear')})"
+                 for a in articles
+             )
+             return formatted_results, None, None, None
+
+         elif action in ["Fetch PubMed Articles (Legacy)", "Fetch PubMed by Query"]:
+             pubmed_result = await fetch_pubmed_by_query(query_params)
+             xml_data = pubmed_result.get("result")
+             if xml_data:
+                 articles = parse_pubmed_xml(xml_data)
+                 if not articles:
+                     return "No articles found.", None, None, None
+                 formatted = "\n\n".join(
+                     f"{a['Title']} - {a['Journal']} ({a['PublicationDate']})"
+                     for a in articles if a['Title']
+                 )
+                 return formatted if formatted else "No articles found.", None, None, None
+             return "No articles found or error fetching data.", None, None, None
+
+         elif action == "Fetch Crossref by Query":
+             crossref_result = await fetch_crossref_by_query(query_params)
+             items = crossref_result.get("message", {}).get("items", [])
+             if not items:
+                 return "No results found.", None, None, None
+             formatted = "\n\n".join(
+                 f"Title: {item.get('title', ['No title'])[0]}, DOI: {item.get('DOI')}"
+                 for item in items
+             )
+             return formatted, None, None, None
+
+         # Default fallback
+         return "Invalid action.", None, None, None
+
+     submit_button.click(
+         handle_action,
+         inputs=[
+             action,
+             text_input,
+             file_input,
+             translation_option,
+             query_params_input,
+             nct_id_input,
+             report_filename_input,
+             export_format,
+         ],
+         outputs=[output_text, output_chart, output_chart2, output_file],
+     )
+
+ # Launch the Gradio app
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
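
A note on the main change above: parse_excel_file now reads raw bytes and wraps them in io.BytesIO instead of handing the Gradio file object to pandas directly. Whether an upload exposes raw bytes via a .data attribute or a temp-file path via .name varies across Gradio versions, so a defensive loader might try both. The helper below is a minimal sketch under that assumption, not part of the commit; load_excel_upload is a hypothetical name.

import io
import pandas as pd

def load_excel_upload(uploaded_file) -> pd.DataFrame:
    """Hypothetical helper: load a Gradio upload into a DataFrame.

    Assumes the upload exposes either a temp-file path via .name
    (common in recent Gradio releases) or raw bytes via .data
    (the attribute the commit above relies on).
    """
    path = getattr(uploaded_file, "name", None)
    if path:
        try:
            # pandas can read straight from a filesystem path.
            return pd.read_excel(path, engine="openpyxl")
        except (FileNotFoundError, OSError):
            pass  # fall back to the raw-bytes route below
    raw = getattr(uploaded_file, "data", None)
    if raw is None:
        raise ValueError("Upload exposes neither a path nor raw bytes.")
    # Wrap the bytes in an in-memory buffer, as parse_excel_file does.
    return pd.read_excel(io.BytesIO(raw), engine="openpyxl")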
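
Relatedly, the utf-8 / utf-8-sig retry loop in parse_csv_content exists because Excel's "CSV UTF-8" export prepends a byte-order mark. A small standalone illustration follows; note that the encoding argument only matters when reading raw bytes or files, so on the already-decoded StringIO used above it has no practical effect.

import pandas as pd

# Write a CSV whose first three bytes are the UTF-8 byte-order mark,
# as Excel's "CSV UTF-8" export does.
with open("bom.csv", "wb") as f:
    f.write(b"\xef\xbb\xbfname,age\nada,36\ngrace,45\n")

# Decoded as plain utf-8, the BOM can survive as "\ufeff" glued to the
# first header name (pandas-version dependent); utf-8-sig strips it.
print(pd.read_csv("bom.csv", encoding="utf-8").columns.tolist())
print(pd.read_csv("bom.csv", encoding="utf-8-sig").columns.tolist())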