mgbam committed on
Commit
a791ee6
·
verified ·
1 Parent(s): 4290ea7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -519
app.py CHANGED
@@ -12,7 +12,7 @@ import torch
12
  from dotenv import load_dotenv
13
  from loguru import logger
14
  from huggingface_hub import login
15
- import openai
16
  from reportlab.pdfgen import canvas
17
  from transformers import (
18
  AutoTokenizer,
@@ -30,17 +30,24 @@ import PyPDF2
30
  # 1) ENVIRONMENT & LOGGING #
31
  ###############################################################################
32
 
33
- # Initialize Logging
 
 
 
 
 
 
 
 
34
  logger.add("error_logs.log", rotation="1 MB", level="ERROR")
35
 
36
  # Load environment variables
37
  load_dotenv()
38
  HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
39
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
40
- BIOPORTAL_API_KEY = os.getenv("BIOPORTAL_API_KEY")
41
  ENTREZ_EMAIL = os.getenv("ENTREZ_EMAIL")
42
 
43
- # Validate API Keys
44
  if not HUGGINGFACE_TOKEN or not OPENAI_API_KEY:
45
  logger.error("Missing Hugging Face or OpenAI credentials.")
46
  raise ValueError("Missing credentials for Hugging Face or OpenAI.")
@@ -52,52 +59,42 @@ if not BIOPORTAL_API_KEY:
52
  # Hugging Face login
53
  login(HUGGINGFACE_TOKEN)
54
 
55
- # OpenAI Initialization
56
- openai.api_key = OPENAI_API_KEY
57
 
58
- # Device Configuration
59
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
  logger.info(f"Using device: {device}")
61
 
62
- # Ensure spaCy model is downloaded (English Core Web)
63
- try:
64
- nlp = spacy.load("en_core_web_sm")
65
- except OSError:
66
- logger.info("Downloading SpaCy 'en_core_web_sm' model...")
67
- spacy.cli.download("en_core_web_sm")
68
- nlp = spacy.load("en_core_web_sm")
69
-
70
  ###############################################################################
71
  # 2) HUGGING FACE & TRANSLATION MODEL SETUP #
72
  ###############################################################################
73
 
74
- # Outcome Prediction Model (Fine-Tuned BERT)
75
- OUTCOME_MODEL_NAME = "mgbam/bert-base-finetuned-mgbam"
76
  try:
77
- outcome_model = AutoModelForSequenceClassification.from_pretrained(
78
- OUTCOME_MODEL_NAME, use_auth_token=HUGGINGFACE_TOKEN
79
  ).to(device)
80
- outcome_tokenizer = AutoTokenizer.from_pretrained(
81
- OUTCOME_MODEL_NAME, use_auth_token=HUGGINGFACE_TOKEN
82
  )
83
  except Exception as e:
84
- logger.error(f"Outcome Model load error: {e}")
85
  raise
86
 
87
- # Translation Model (English ↔ French)
88
- TRANSLATION_MODEL_NAME = "Helsinki-NLP/opus-mt-en-fr"
89
  try:
 
90
  translation_model = MarianMTModel.from_pretrained(
91
- TRANSLATION_MODEL_NAME, use_auth_token=HUGGINGFACE_TOKEN
92
  ).to(device)
93
  translation_tokenizer = MarianTokenizer.from_pretrained(
94
- TRANSLATION_MODEL_NAME, use_auth_token=HUGGINGFACE_TOKEN
95
  )
96
  except Exception as e:
97
  logger.error(f"Translation model load error: {e}")
98
  raise
99
 
100
- # Language Mapping for Translation
101
  LANGUAGE_MAP: Dict[str, Tuple[str, str]] = {
102
  "English to French": ("en", "fr"),
103
  "French to English": ("fr", "en"),
@@ -153,130 +150,33 @@ def parse_pubmed_xml(xml_data: str) -> List[Dict[str, Any]]:
153
  })
154
  return articles
155
 
156
- ###############################################################################
157
- # 5) ASYNC FETCH FUNCTIONS #
158
- ###############################################################################
159
-
160
- async def fetch_articles_by_nct_id(nct_id: str) -> Dict[str, Any]:
161
- """Fetch articles from Europe PMC using NCT ID."""
162
- params = {"query": nct_id, "format": "json"}
163
- async with httpx.AsyncClient() as client_http:
164
- try:
165
- resp = await client_http.get(EUROPE_PMC_BASE_URL, params=params)
166
- resp.raise_for_status()
167
- return resp.json()
168
- except Exception as e:
169
- logger.error(f"Error fetching articles for {nct_id}: {e}")
170
- return {"error": str(e)}
171
-
172
- async def fetch_articles_by_query(query_params: str) -> Dict[str, Any]:
173
- """Fetch articles from Europe PMC based on query parameters."""
174
- parsed_params = safe_json_parse(query_params)
175
- if not parsed_params or not isinstance(parsed_params, dict):
176
- return {"error": "Invalid JSON."}
177
- query_string = " AND ".join(f"{k}:{v}" for k, v in parsed_params.items())
178
- req_params = {"query": query_string, "format": "json"}
179
- async with httpx.AsyncClient() as client_http:
180
- try:
181
- resp = await client_http.get(EUROPE_PMC_BASE_URL, params=req_params)
182
- resp.raise_for_status()
183
- return resp.json()
184
- except Exception as e:
185
- logger.error(f"Error fetching Europe PMC articles: {e}")
186
- return {"error": str(e)}
187
-
188
- async def fetch_pubmed_by_query(query_params: str) -> Dict[str, Any]:
189
- """Fetch articles from PubMed based on query parameters."""
190
- parsed_params = safe_json_parse(query_params)
191
- if not parsed_params or not isinstance(parsed_params, dict):
192
- return {"error": "Invalid JSON for PubMed."}
193
-
194
- search_params = {
195
- "db": "pubmed",
196
- "retmode": "json",
197
- "email": ENTREZ_EMAIL,
198
- "retmax": parsed_params.get("retmax", "10"),
199
- "term": parsed_params.get("term", ""),
200
- }
201
- async with httpx.AsyncClient() as client_http:
202
- try:
203
- # Search PubMed
204
- search_resp = await client_http.get(PUBMED_SEARCH_URL, params=search_params)
205
- search_resp.raise_for_status()
206
- data = search_resp.json()
207
- id_list = data.get("esearchresult", {}).get("idlist", [])
208
- if not id_list:
209
- return {"result": ""}
210
-
211
- # Fetch PubMed Articles
212
- fetch_params = {
213
- "db": "pubmed",
214
- "id": ",".join(id_list),
215
- "retmode": "xml",
216
- "email": ENTREZ_EMAIL,
217
- }
218
- fetch_resp = await client_http.get(PUBMED_FETCH_URL, params=fetch_params)
219
- fetch_resp.raise_for_status()
220
- return {"result": fetch_resp.text}
221
- except Exception as e:
222
- logger.error(f"Error fetching PubMed articles: {e}")
223
- return {"error": str(e)}
224
-
225
- async def fetch_crossref_by_query(query_params: str) -> Dict[str, Any]:
226
- """Fetch articles from Crossref based on query parameters."""
227
- parsed_params = safe_json_parse(query_params)
228
- if not parsed_params or not isinstance(parsed_params, dict):
229
- return {"error": "Invalid JSON for Crossref."}
230
- async with httpx.AsyncClient() as client_http:
231
- try:
232
- resp = await client_http.get(CROSSREF_API_URL, params=parsed_params)
233
- resp.raise_for_status()
234
- return resp.json()
235
- except Exception as e:
236
- logger.error(f"Error fetching Crossref data: {e}")
237
- return {"error": str(e)}
238
-
239
- async def fetch_bioportal_by_query(query_params: str) -> Dict[str, Any]:
240
- """
241
- Fetch ontology data from BioPortal based on query parameters.
242
- Expects JSON like: {"q": "cancer"}
243
- """
244
- if not BIOPORTAL_API_KEY:
245
- return {"error": "No BioPortal API Key set."}
246
- parsed_params = safe_json_parse(query_params)
247
- if not parsed_params or not isinstance(parsed_params, dict):
248
- return {"error": "Invalid JSON for BioPortal."}
249
-
250
- search_term = parsed_params.get("q", "")
251
- if not search_term:
252
- return {"error": "No 'q' found in JSON. Provide a search term."}
253
-
254
- url = f"{BIOPORTAL_API_BASE}/search"
255
- headers = {"Authorization": f"apikey token={BIOPORTAL_API_KEY}"}
256
- req_params = {"q": search_term}
257
-
258
- async with httpx.AsyncClient() as client_http:
259
- try:
260
- resp = await client_http.get(url, params=req_params, headers=headers)
261
- resp.raise_for_status()
262
- return resp.json()
263
- except Exception as e:
264
- logger.error(f"Error fetching BioPortal data: {e}")
265
- return {"error": str(e)}
266
 
267
  ###############################################################################
268
  # 6) CORE FUNCTIONS #
269
  ###############################################################################
270
 
271
  def summarize_text(text: str) -> str:
272
- """Summarize clinical text using OpenAI GPT-3.5."""
273
  if not text.strip():
274
  return "No text provided for summarization."
275
  try:
276
- response = openai.ChatCompletion.create(
277
  model="gpt-3.5-turbo",
278
  messages=[{"role": "user", "content": f"Summarize this clinical data:\n{text}"}],
279
- max_tokens=500,
280
  temperature=0.7,
281
  )
282
  return response.choices[0].message.content.strip()
@@ -284,67 +184,19 @@ def summarize_text(text: str) -> str:
284
  logger.error(f"Summarization error: {e}")
285
  return "Summarization failed."
286
 
287
- def predict_outcome(text: str) -> Union[Dict[str, float], str]:
288
- """Predict outcomes using a fine-tuned Hugging Face BERT model."""
289
- if not text.strip():
290
- return "No text provided for prediction."
291
- try:
292
- inputs = outcome_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
293
- inputs = {k: v.to(device) for k, v in inputs.items()}
294
- with torch.no_grad():
295
- outputs = outcome_model(**inputs)
296
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
297
- labels = outcome_model.config.id2label
298
- return {labels[i]: float(prob.item()) for i, prob in enumerate(probabilities)}
299
- except Exception as e:
300
- logger.error(f"Prediction error: {e}")
301
- return "Prediction failed."
302
-
303
- def translate_text(text: str, translation_option: str) -> str:
304
- """Translate text between English and French using MarianMT."""
305
- if not text.strip():
306
- return "No text provided for translation."
307
- try:
308
- if translation_option not in LANGUAGE_MAP:
309
- return "Unsupported translation option."
310
- inputs = translation_tokenizer(text, return_tensors="pt", padding=True).to(device)
311
- translated_tokens = translation_model.generate(**inputs)
312
- translated_text = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
313
- return translated_text
314
- except Exception as e:
315
- logger.error(f"Translation error: {e}")
316
- return "Translation failed."
317
-
318
- def perform_named_entity_recognition(text: str) -> str:
319
- """Perform Named Entity Recognition using spaCy."""
320
- if not text.strip():
321
- return "No text provided for NER."
322
- try:
323
- doc = nlp(text)
324
- entities = [(ent.text, ent.label_) for ent in doc.ents]
325
- if not entities:
326
- return "No named entities found."
327
- return "\n".join(f"{t} -> {lbl}" for t, lbl in entities)
328
- except Exception as e:
329
- logger.error(f"NER error: {e}")
330
- return "NER failed."
331
-
332
  def generate_report(text: str, filename: str = "clinical_report.pdf") -> Optional[str]:
333
- """Generate a professional PDF report from the text using ReportLab."""
334
  try:
335
  if not text.strip():
336
  logger.warning("No text provided for the report.")
337
  c = canvas.Canvas(filename)
338
- c.setFont("Helvetica-Bold", 16)
339
- c.drawString(100, 800, "Clinical Research Report")
340
- c.setFont("Helvetica", 12)
341
  lines = text.split("\n")
342
- y = 780
343
  for line in lines:
344
  if y < 50:
345
  c.showPage()
346
- c.setFont("Helvetica", 12)
347
- y = 800
348
  c.drawString(100, y, line)
349
  y -= 15
350
  c.save()
@@ -355,7 +207,7 @@ def generate_report(text: str, filename: str = "clinical_report.pdf") -> Optiona
355
  return None
356
 
357
  def visualize_predictions(predictions: Dict[str, float]) -> alt.Chart:
358
- """Visualize prediction probabilities using Altair."""
359
  data = pd.DataFrame(list(predictions.items()), columns=["Label", "Probability"])
360
  chart = (
361
  alt.Chart(data)
@@ -369,345 +221,66 @@ def visualize_predictions(predictions: Dict[str, float]) -> alt.Chart:
369
  )
370
  return chart
371
 
372
- def fetch_web_search(query: str) -> str:
373
- """Use OpenAI to perform a web search and provide explanations."""
374
- if not query.strip():
375
- return "No query provided for web search."
376
- try:
377
- response = openai.ChatCompletion.create(
378
- model="gpt-3.5-turbo",
379
- messages=[
380
- {"role": "system", "content": "You are a helpful assistant that provides detailed explanations based on the latest research."},
381
- {"role": "user", "content": f"Explain the following query using the latest research: {query}"},
382
- ],
383
- max_tokens=700,
384
- temperature=0.7,
385
- )
386
- return response.choices[0].message.content.strip()
387
- except Exception as e:
388
- logger.error(f"Web search error: {e}")
389
- return "Web search failed."
390
-
391
- ###############################################################################
392
- # 7) FILE PARSING (TXT, PDF, CSV, XLS) #
393
- ###############################################################################
394
-
395
- def parse_pdf_file_as_str(file_up: gr.File) -> str:
396
- """Extract text from a PDF file using PyPDF2."""
397
- try:
398
- pdf_bytes = file_up.read()
399
- reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
400
- return "\n".join(page.extract_text() or "" for page in reader.pages)
401
- except Exception as e:
402
- logger.error(f"PDF parse error: {e}")
403
- return "Failed to extract text from PDF."
404
-
405
- def parse_text_file_as_str(file_up: gr.File) -> str:
406
- """Extract text from a TXT file."""
407
- try:
408
- return file_up.read().decode("utf-8", errors="replace")
409
- except Exception as e:
410
- logger.error(f"TXT parse error: {e}")
411
- return "Failed to extract text from TXT file."
412
-
413
- def parse_csv_file_to_df(file_up: gr.File) -> pd.DataFrame:
414
- """Parse CSV file into a pandas DataFrame with multiple encoding attempts."""
415
- try:
416
- return pd.read_csv(io.StringIO(file_up.read().decode("utf-8", errors="replace")))
417
- except UnicodeDecodeError:
418
- try:
419
- return pd.read_csv(io.StringIO(file_up.read().decode("latin1", errors="replace")))
420
- except Exception as e:
421
- logger.error(f"CSV parse error: {e}")
422
- return pd.DataFrame()
423
- except Exception as e:
424
- logger.error(f"CSV parse error: {e}")
425
- return pd.DataFrame()
426
-
427
- def parse_excel_file_to_df(file_up: gr.File) -> pd.DataFrame:
428
- """Parse Excel file into a pandas DataFrame."""
429
- try:
430
- return pd.read_excel(io.BytesIO(file_up.read()), engine="openpyxl")
431
- except Exception as e:
432
- logger.error(f"Excel parse error: {e}")
433
- return pd.DataFrame()
434
-
435
  ###############################################################################
436
- # 8) BUILDING THE GRADIO APP #
437
  ###############################################################################
438
 
439
- def format_articles(articles: List[Dict[str, Any]]) -> str:
440
- """Format fetched articles into a readable string."""
441
- formatted = ""
442
- for article in articles:
443
- title = article.get("title", "No Title")
444
- journal = article.get("journalTitle", "No Journal")
445
- pub_year = article.get("pubYear", "No Year")
446
- formatted += f"Title: {title}\nJournal: {journal} ({pub_year})\n\n"
447
- return formatted.strip()
448
-
449
- def format_bioportal_results(collection: List[Dict[str, Any]]) -> str:
450
- """Format BioPortal results into a readable string."""
451
- formatted = ""
452
- for col in collection:
453
- label = col.get("prefLabel", "No Label")
454
- ontology = col.get("ontology", {}).get("name", "No Ontology")
455
- formatted += f"Label: {label}\nOntology: {ontology}\n\n"
456
- return formatted.strip()
457
-
458
- async def handle_action(
459
- action: str,
460
- txt: Optional[str],
461
- file_up: Optional[gr.File],
462
- translation_opt: Optional[str],
463
- query_str: Optional[str],
464
- nct_id: Optional[str],
465
- report_fn: Optional[str],
466
- exp_fmt: Optional[str]
467
- ) -> Tuple[Optional[str], Optional[Any], Optional[Any], Optional[str]]:
468
- """
469
- Master function to handle user actions.
470
- Returns a 4-tuple mapped to (output_text, output_chart, output_chart2, output_file).
471
- """
472
- try:
473
- combined_text = txt.strip() if txt else ""
474
-
475
- # 1) If user uploaded a file, parse text from it
476
- if file_up is not None:
477
- ext = os.path.splitext(file_up.name)[1].lower()
478
- if ext == ".txt":
479
- parsed_text = parse_text_file_as_str(file_up)
480
- combined_text += "\n" + parsed_text
481
- elif ext == ".pdf":
482
- parsed_text = parse_pdf_file_as_str(file_up)
483
- combined_text += "\n" + parsed_text
484
- elif ext == ".csv":
485
- df_csv = parse_csv_file_to_df(file_up)
486
- combined_text += "\n" + df_csv.to_csv(index=False)
487
- elif ext in [".xls", ".xlsx"]:
488
- df_xl = parse_excel_file_to_df(file_up)
489
- combined_text += "\n" + df_xl.to_csv(index=False)
490
- else:
491
- return "Unsupported file format.", None, None, None
492
-
493
- # 2) Branch by action
494
- if action == "Summarize":
495
- summary = summarize_text(combined_text)
496
- return summary, None, None, None
497
-
498
- elif action == "Predict Outcome":
499
- preds = predict_outcome(combined_text)
500
- if isinstance(preds, dict):
501
- chart = visualize_predictions(preds)
502
- return json.dumps(preds, indent=2), chart, None, None
503
- return preds, None, None, None
504
-
505
- elif action == "Generate Report":
506
- path = generate_report(combined_text, report_fn or "clinical_report.pdf")
507
- msg = f"Report generated: {path}" if path else "Report generation failed."
508
- return msg, None, None, path
509
-
510
- elif action == "Translate":
511
- translated = translate_text(combined_text, translation_opt or "English to French")
512
- return translated, None, None, None
513
-
514
- elif action == "Perform Named Entity Recognition":
515
- ner_result = perform_named_entity_recognition(combined_text)
516
- return ner_result, None, None, None
517
-
518
- elif action == "Fetch Clinical Studies":
519
- if nct_id:
520
- result = await fetch_articles_by_nct_id(nct_id)
521
- elif query_str:
522
- result = await fetch_articles_by_query(query_str)
523
- else:
524
- return "Provide either an NCT ID or valid query parameters.", None, None, None
525
-
526
- articles = result.get("resultList", {}).get("result", [])
527
- if not articles:
528
- return "No articles found.", None, None, None
529
-
530
- formatted = format_articles(articles)
531
- return formatted, None, None, None
532
-
533
- elif action in ["Fetch PubMed Articles (Legacy)", "Fetch PubMed by Query"]:
534
- pubmed_result = await fetch_pubmed_by_query(query_str or "")
535
- xml_data = pubmed_result.get("result")
536
- if xml_data:
537
- articles = parse_pubmed_xml(xml_data)
538
- if not articles:
539
- return "No articles found.", None, None, None
540
- formatted = "\n\n".join(
541
- f"{a['Title']} - {a['Journal']} ({a['PublicationDate']})"
542
- for a in articles if a['Title']
543
- )
544
- return formatted if formatted else "No articles found.", None, None, None
545
- return "No articles found or error in fetching PubMed data.", None, None, None
546
-
547
- elif action == "Fetch Crossref by Query":
548
- crossref_result = await fetch_crossref_by_query(query_str or "")
549
- items = crossref_result.get("message", {}).get("items", [])
550
- if not items:
551
- return "No results found.", None, None, None
552
- crossref_formatted = "\n\n".join(
553
- f"Title: {it.get('title', ['No title'])[0]}, DOI: {it.get('DOI')}"
554
- for it in items
555
- )
556
- return crossref_formatted, None, None, None
557
-
558
- elif action == "Fetch BioPortal by Query":
559
- bp_result = await fetch_bioportal_by_query(query_str or "")
560
- collection = bp_result.get("collection", [])
561
- if not collection:
562
- return "No BioPortal results found.", None, None, None
563
- formatted = format_bioportal_results(collection)
564
- return formatted, None, None, None
565
-
566
- elif action == "Web Search Explanation":
567
- explanation = fetch_web_search(combined_text)
568
- return explanation, None, None, None
569
-
570
- else:
571
- return "Invalid action selected.", None, None, None
572
-
573
- except Exception as ex:
574
- # Catch all exceptions, log, and return traceback to 'output_text'
575
- tb_str = traceback.format_exc()
576
- logger.error(f"Exception in handle_action:\n{tb_str}")
577
- return f"Traceback:\n{tb_str}", None, None, None
578
-
579
- ###############################################################################
580
- # 9) BUILDING THE GRADIO APP #
581
- ###############################################################################
582
-
583
- with gr.Blocks(css="""
584
- .gradio-container {
585
- background-color: #f5f5f5;
586
- }
587
- .gr-button-primary {
588
- background-color: #4CAF50;
589
- }
590
- .gradio-tabs {
591
- background-color: #ffffff;
592
- }
593
- """) as demo:
594
- gr.Markdown("# 🏥 **AI-Driven Clinical Assistant**")
595
  gr.Markdown("""
596
- **Highlights**:
597
- - **Summarize** clinical text (OpenAI GPT-3.5)
598
- - **Predict** outcomes (Hugging Face fine-tuned model)
599
- - **Translate** (English French)
600
- - **Named Entity Recognition** (spaCy)
601
- - **Fetch** from PubMed, Crossref, Europe PMC, and **BioPortal**
602
- - **Generate** professional PDF reports
603
- - **Web Search Explanations** (OpenAI)
604
-
605
- *Disclaimer*: This is a research demo, **not** a medical device.
606
- """)
607
-
608
- with gr.Row():
609
- text_input = gr.Textbox(
610
- label="Input Clinical Text",
611
- lines=5,
612
- placeholder="Enter clinical text, research notes, or queries...",
613
- interactive=True
614
- )
615
- file_input = gr.File(
616
- label="Upload File",
617
- file_types=[".txt", ".csv", ".xls", ".xlsx", ".pdf"],
618
- interactive=True
619
- )
620
-
621
  action = gr.Radio(
622
  [
623
  "Summarize",
624
- "Predict Outcome",
625
  "Generate Report",
626
- "Translate",
627
- "Perform Named Entity Recognition",
628
- "Fetch Clinical Studies",
629
- "Fetch PubMed Articles (Legacy)",
630
- "Fetch PubMed by Query",
631
- "Fetch Crossref by Query",
632
- "Fetch BioPortal by Query",
633
- "Web Search Explanation"
634
  ],
635
  label="Select an Action",
636
- interactive=True
637
- )
638
-
639
- translation_option = gr.Dropdown(
640
- choices=list(LANGUAGE_MAP.keys()),
641
- label="Translation Option",
642
- value="English to French",
643
- interactive=True
644
- )
645
-
646
- query_params_input = gr.Textbox(
647
- label="Query Parameters (JSON)",
648
- placeholder='{"term": "cancer"} or {"q": "cancer"} for BioPortal',
649
- interactive=True
650
  )
651
-
652
- nct_id_input = gr.Textbox(
653
- label="NCT ID",
654
- placeholder="Enter NCT ID (e.g., NCT00000000)",
655
- interactive=True
656
- )
657
-
658
- report_filename_input = gr.Textbox(
659
- label="Report Filename",
660
- value="clinical_report.pdf",
661
- interactive=True
662
- )
663
-
664
- exp_fmt = gr.Dropdown(
665
- choices=["None", "CSV", "JSON"],
666
- label="Export Format",
667
- value="None",
668
- interactive=True
669
- )
670
-
671
- # Outputs
672
- output_text = gr.Textbox(
673
- label="Output",
674
- lines=20,
675
- interactive=False
676
- )
677
-
678
- with gr.Row():
679
- output_chart = gr.Plot(label="Prediction Probabilities")
680
- output_chart2 = gr.Plot(label="Additional Visualization") # Placeholder for future use
681
-
682
- output_file = gr.File(label="Generated File", interactive=False)
683
-
684
- submit_btn = gr.Button("Submit", variant="primary")
685
-
686
- gr.Markdown("""
687
- ---
688
-
689
- ### **Important Disclaimers**
690
-
691
- - **Not a Medical Device**: This tool is not intended to provide clinical diagnoses or final medical decisions. Always consult qualified healthcare professionals for clinical decisions.
692
- - **AI/ML Limitations**: GPT-based summaries and classification models offer powerful insights but may generate incomplete or inaccurate results. Always verify AI-generated content.
693
- - **Credential Security**: Ensure the security of your API keys (`OPENAI_API_KEY`, `HF_TOKEN`, `BIOPORTAL_API_KEY`) to safely access external services.
694
- - **Data Privacy**: If handling real patient data, ensure compliance with applicable data protection regulations (e.g., HIPAA, GDPR).
695
-
696
- ---
697
- """)
698
-
699
- # Connect the submit button to the action handler
700
  submit_btn.click(
701
- fn=lambda action, txt, file_up, trans_opt, query, nct_id, report_fn, exp_fm: asyncio.run(
702
- handle_action(action, txt, file_up, trans_opt, query, nct_id, report_fn, exp_fm)
703
- ),
704
- inputs=[action, text_input, file_input, translation_option, query_params_input, nct_id_input, report_filename_input, exp_fmt],
705
- outputs=[output_text, output_chart, output_chart2, output_file],
706
  )
707
 
708
- ###############################################################################
709
- # 10) LAUNCHING THE GRADIO APP #
710
- ###############################################################################
711
-
712
  # Launch the Gradio interface
713
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
12
  from dotenv import load_dotenv
13
  from loguru import logger
14
  from huggingface_hub import login
15
+ from openai import OpenAI
16
  from reportlab.pdfgen import canvas
17
  from transformers import (
18
  AutoTokenizer,
 
30
  # 1) ENVIRONMENT & LOGGING #
31
  ###############################################################################
32
 
33
+ # Ensure spaCy model is downloaded (English Core Web)
34
+ try:
35
+ nlp = spacy.load("en_core_web_sm")
36
+ except OSError:
37
+ logger.info("Downloading SpaCy 'en_core_web_sm' model...")
38
+ spacy.cli.download("en_core_web_sm")
39
+ nlp = spacy.load("en_core_web_sm")
40
+
41
+ # Logging
42
  logger.add("error_logs.log", rotation="1 MB", level="ERROR")
43
 
44
  # Load environment variables
45
  load_dotenv()
46
  HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
47
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
48
+ BIOPORTAL_API_KEY = os.getenv("BIOPORTAL_API_KEY") # For BioPortal integration
49
  ENTREZ_EMAIL = os.getenv("ENTREZ_EMAIL")
50
 
 
51
  if not HUGGINGFACE_TOKEN or not OPENAI_API_KEY:
52
  logger.error("Missing Hugging Face or OpenAI credentials.")
53
  raise ValueError("Missing credentials for Hugging Face or OpenAI.")
 
59
  # Hugging Face login
60
  login(HUGGINGFACE_TOKEN)
61
 
62
+ # OpenAI
63
+ client = OpenAI(api_key=OPENAI_API_KEY)
64
 
65
+ # Device: CPU or GPU
66
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
67
  logger.info(f"Using device: {device}")
68
 
 
 
 
 
 
 
 
 
69
  ###############################################################################
70
  # 2) HUGGING FACE & TRANSLATION MODEL SETUP #
71
  ###############################################################################
72
 
73
+ MODEL_NAME = "mgbam/bert-base-finetuned-mgbam"
 
74
  try:
75
+ model = AutoModelForSequenceClassification.from_pretrained(
76
+ MODEL_NAME, use_auth_token=HUGGINGFACE_TOKEN
77
  ).to(device)
78
+ tokenizer = AutoTokenizer.from_pretrained(
79
+ MODEL_NAME, use_auth_token=HUGGINGFACE_TOKEN
80
  )
81
  except Exception as e:
82
+ logger.error(f"Model load error: {e}")
83
  raise
84
 
 
 
85
  try:
86
+ translation_model_name = "Helsinki-NLP/opus-mt-en-fr"
87
  translation_model = MarianMTModel.from_pretrained(
88
+ translation_model_name, use_auth_token=HUGGINGFACE_TOKEN
89
  ).to(device)
90
  translation_tokenizer = MarianTokenizer.from_pretrained(
91
+ translation_model_name, use_auth_token=HUGGINGFACE_TOKEN
92
  )
93
  except Exception as e:
94
  logger.error(f"Translation model load error: {e}")
95
  raise
96
 
97
+ # Language map for translation
98
  LANGUAGE_MAP: Dict[str, Tuple[str, str]] = {
99
  "English to French": ("en", "fr"),
100
  "French to English": ("fr", "en"),
 
150
  })
151
  return articles
152
 
153
def explain_clinical_results(results: str) -> str:
    """Generate a plain-language explanation of raw clinical results.

    Sends ``results`` to the OpenAI chat-completions endpoint through the
    module-level ``client`` and returns the model's reply.

    Args:
        results: Raw clinical test results or trial outcomes as text.

    Returns:
        The explanation text; a fixed notice when ``results`` is empty or
        whitespace-only; or ``"Failed to generate explanation."`` when the
        API call (or reading its response) raises.
    """
    # Mirror summarize_text: do not spend an API call on empty input.
    if not results or not results.strip():
        return "No results provided for explanation."
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "user",
                    "content": f"Explain the clinical test results:\n{results}",
                }
            ],
            max_tokens=500,
            temperature=0.7,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        # Boundary handler: the UI expects a string, never an exception.
        logger.error(f"Explanation error: {e}")
        return "Failed to generate explanation."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  ###############################################################################
168
  # 6) CORE FUNCTIONS #
169
  ###############################################################################
170
 
171
  def summarize_text(text: str) -> str:
172
+ """OpenAI GPT-3.5 summarization."""
173
  if not text.strip():
174
  return "No text provided for summarization."
175
  try:
176
+ response = client.chat.completions.create(
177
  model="gpt-3.5-turbo",
178
  messages=[{"role": "user", "content": f"Summarize this clinical data:\n{text}"}],
179
+ max_tokens=200,
180
  temperature=0.7,
181
  )
182
  return response.choices[0].message.content.strip()
 
184
  logger.error(f"Summarization error: {e}")
185
  return "Summarization failed."
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  def generate_report(text: str, filename: str = "clinical_report.pdf") -> Optional[str]:
188
+ """Generate a professional PDF report from the text."""
189
  try:
190
  if not text.strip():
191
  logger.warning("No text provided for the report.")
192
  c = canvas.Canvas(filename)
193
+ c.drawString(100, 750, "Clinical Research Report")
 
 
194
  lines = text.split("\n")
195
+ y = 730
196
  for line in lines:
197
  if y < 50:
198
  c.showPage()
199
+ y = 750
 
200
  c.drawString(100, y, line)
201
  y -= 15
202
  c.save()
 
207
  return None
208
 
209
  def visualize_predictions(predictions: Dict[str, float]) -> alt.Chart:
210
+ """Simple Altair bar chart to visualize classification probabilities."""
211
  data = pd.DataFrame(list(predictions.items()), columns=["Label", "Probability"])
212
  chart = (
213
  alt.Chart(data)
 
221
  )
222
  return chart
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  ###############################################################################
225
+ # 7) BUILDING THE GRADIO APP #
226
  ###############################################################################
227
 
228
with gr.Blocks() as demo:
    gr.Markdown("# 🏥 AI-Driven Clinical Assistant")
    gr.Markdown("""
    **Highlights**:
    - **Summarize** clinical text (OpenAI GPT-3.5)
    - **Explain** clinical test results and trial outcomes
    - **Generate** professional PDF reports
    """)

    text_input = gr.Textbox(
        label="Input Text",
        lines=5,
        placeholder="Enter clinical text or test results...",
    )
    action = gr.Radio(
        [
            "Summarize",
            "Explain Clinical Results",
            "Generate Report",
        ],
        label="Select an Action",
    )
    # The click handler below takes a report filename as its third input, so
    # the component must exist; without it the wiring raises NameError.
    report_filename_input = gr.Textbox(
        label="Report Filename",
        value="clinical_report.pdf",
    )

    output_text = gr.Textbox(label="Output", lines=8)
    output_file = gr.File(label="Generated File")

    submit_btn = gr.Button("Submit")

    async def handle_action(
        action: str,
        txt: str,
        report_fn: str,
    ) -> Tuple[Optional[str], Optional[str]]:
        """Dispatch the selected action and return ``(text, file_path)``.

        Args:
            action: Label of the selected radio option.
            txt: Text entered by the user; may be ``None`` when left empty.
            report_fn: Requested PDF filename; may be ``None`` or empty.

        Returns:
            A 2-tuple mapped to ``(output_text, output_file)``. The second
            element is a path only for "Generate Report", otherwise ``None``.
            Any exception is logged and reported as ``"Error: ..."`` text.
        """
        try:
            # The textbox value can be None; treat it as empty text.
            combined_text = (txt or "").strip()

            if action == "Summarize":
                return summarize_text(combined_text), None

            if action == "Explain Clinical Results":
                return explain_clinical_results(combined_text), None

            if action == "Generate Report":
                # The filename is user-controlled: keep only its final path
                # component so the report cannot be written outside the
                # working directory, and fall back to the default when blank.
                filename = (
                    os.path.basename((report_fn or "").strip())
                    or "clinical_report.pdf"
                )
                path = generate_report(combined_text, filename)
                msg = f"Report generated: {path}" if path else "Report generation failed."
                return msg, path

            return "Invalid action.", None
        except Exception as e:
            # Boundary handler: surface the failure in the UI instead of crashing.
            logger.error(f"Exception: {e}")
            return f"Error: {str(e)}", None

    submit_btn.click(
        fn=handle_action,
        inputs=[action, text_input, report_filename_input],
        outputs=[output_text, output_file],
    )

# Launch the Gradio interface
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)