|
import ast
import html
import json
import os

import gradio as gr
import spaces
import xgrammar as xgr
from huggingface_hub import login as hf_login
from pydantic import BaseModel
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
|
|
# Authenticate with the Hugging Face Hub only when a token is configured.
# huggingface_hub.login(token=None) falls back to an interactive prompt,
# which would block (or fail) in a headless deployment such as a Space.
# Without a token, public repos still load; private ones fail at
# from_pretrained with an explicit authentication error.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    hf_login(token=_hf_token)

# Fine-tuned Llama-3.2-1B-Instruct checkpoint used for extraction.
model_name = "gregorlied/Llama-3.2-1B-Instruct-Medical-Report-Summarization-FP32"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    attn_implementation='eager',
    # NOTE(review): AutoConfig/AutoTokenizer load this repo without remote
    # code, so this flag is probably unnecessary; consider dropping it to
    # avoid executing code shipped in the model repo.
    trust_remote_code=True,
)
|
|
|
class Person(BaseModel):
    """Structured fields extracted from a single clinical report.

    Every field is free text. Per the system prompt, absent facts are
    reported as "N/A" and multiple facts are joined with "; ".

    The class is passed to ``compile_json_schema`` to build the decoding
    grammar. Keep the declaration order in sync with the key listing in the
    system prompt. Reordering changes the generated JSON schema, and
    presumably the key order the grammar allows, so verify before reordering.
    """

    # Daily habits (e.g. alcohol consumption, diet, smoking status).
    life_style: str
    # Relevant conditions or genetic disorders in the patient's family.
    family_history: str
    # Social factors (e.g. housing situation, marital status).
    social_history: str
    # Past medical conditions, previous surgeries or treatments.
    medical_surgical_history: str
    # Current symptoms.
    signs_symptoms: str
    # Other conditions that may influence treatment.
    comorbidities: str
    # Diagnostic tests or procedures performed.
    diagnostic_techniques_procedures: str
    # Diagnoses mentioned in the report.
    diagnosis: str
    # Laboratory test results.
    laboratory_values: str
    # Findings from pathological examinations.
    pathology: str
    # Medications prescribed.
    pharmacological_therapy: str
    # Surgical or non-surgical interventions performed.
    interventional_therapy: str
    # Health status at the end of treatment.
    patient_outcome_assessment: str
    # Patient age, as free text (e.g. "57 year").
    age: str
    # Patient gender, as free text.
    gender: str
|
|
|
# Model config: its vocab_size is the width of the LM head's logits.
config = AutoConfig.from_pretrained(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)

# xgrammar's vocab_size must be the model's vocabulary (the lm_head / logits
# dimension), not len(tokenizer). The two can differ when the embedding
# matrix is padded or tokens were added without resizing. The token bitmask
# applied to the logits has to cover the full logits width.
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
    tokenizer, vocab_size=config.vocab_size
)

# Compile the Person JSON schema once at import time. summarize() builds a
# LogitsProcessor from this compiled grammar for each request.
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_json_schema(Person)
|
|
|
# Built-in example report, pre-filled in the input box. summarize() returns a
# hard-coded extraction when the input equals this text exactly, so the demo
# answers without running the model.
default_value = "A 57-year-old male presented with fever (38.9°C), chest pain, cough, and progressive dyspnea. The patient exhibited tachypnea (34 breaths/min) and tachycardia (134 bpm). Auscultation revealed decreased breath sounds in both lung bases, with crackles on the left. A chest X-ray revealed bilateral pleural opacities and enlargement of the cardiac silhouette ( A). Echocardiography showed moderate pericardial effusion affecting the entire cardiac silhouette. Pericardiocentesis yielded 250 mL of exudative fluid. A CT scan of the chest showed pneumonia in the left lower lobe, bilateral pleural effusion, and moderate pericardial effusion ( B). Thoracentesis was performed and yielded 1,050 mL of exudative fluid. Laboratory tests yielded the following data: white blood cell count, 11.78 × 109 cells/L (84.3% neutrophils, 4.3% lymphocytes, and 9.1% monocytes); platelet count, 512 × 109/L; serum C-reactive protein, 31.27 mg/dL; serum creatinine, 0.94 mg/dL; serum sodium, 133 mEq/L; and serum potassium, 3.72 mEq/L. Examination of the pleural fluid showed a pH of 7.16, a glucose level of 4.5 mg/dL, proteins at 49.1 g/L, and an LDH content of 1,385 U/L. A urinary pneumococcal antigen test was positive. Pleural fluid culture was positive for S. pneumoniae. The patient was treated for four weeks with amoxicillin-clavulanate (2.2 g/8 h, i.v.) plus levofloxacin (500 mg twice a day), together with a nonsteroidal anti-inflammatory drug (ibuprofen, 800 mg/day), after which there was nearly complete resolution of the alterations seen on the chest X-ray and CT scan."



# System message sent (after .strip()) with every report that goes through
# the model. The key listing mirrors the Person fields and their order.
# NOTE(review): the wording, including the "multile" typo, may match the
# prompt used for fine-tuning. Verify before editing, since any change
# alters the model input.
prompt = """You are a text extraction system for clinical reports.

Please extract relevant clinical information from the report.



### Instructions



- Use the JSON Schema given below.

- Return only a valid JSON object – no markdown, no comments.

- If no relevant facts are given for a field, set its value to "N/A".

- If multile relevant facts are given for a field, separate them with "; ".



### JSON Schema



{

'life_style': '',

'family_history': '',

'social_history': '',

'medical_surgical_history': '',

'signs_symptoms': '',

'comorbidities': '',

'diagnostic_techniques_procedures': '',

'diagnosis': '',

'laboratory_values': '',

'pathology': '',

'pharmacological_therapy': '',

'interventional_therapy': '',

'patient_outcome_assessment': '',

'age': '',

'gender': '',

}



### Clinical Report

"""
|
|
|
def generate_html_tables(data, selected_fields):
    """Render extracted clinical fields as grouped HTML tables.

    Args:
        data: Mapping from schema key (e.g. ``"signs_symptoms"``) to the
            extracted text. A missing key, ``None``, or the literal ``"N/A"``
            is shown as "Not Available". Multiple facts separated by ``";"``
            are rendered as a bullet list.
        selected_fields: Display labels (e.g. ``"Symptoms"``) the user chose
            to show. Sections without any selected label are omitted.

    Returns:
        One HTML string with a table per non-empty section. The first four
        tables are laid out two per row; later tables go three per row.
        All extracted text is HTML-escaped.
    """
    key_label_map = {
        'age': 'Age',
        'gender': 'Gender',
        'life_style': 'Lifestyle',
        'social_history': 'Social Background',
        'medical_surgical_history': 'Personal',
        'family_history': 'Family Members',
        'signs_symptoms': 'Symptoms',
        'comorbidities': 'Comorbid Conditions',
        'diagnostic_techniques_procedures': 'Diagnostic Procedures',
        'laboratory_values': 'Laboratory Results',
        'pathology': 'Pathology Report',
        'diagnosis': 'Diagnosis',
        'interventional_therapy': 'Interventional Therapy',
        'pharmacological_therapy': 'Pharmacological Therapy',
        'patient_outcome_assessment': 'Patient Outcome',
    }
    label_key_map = {v: k for k, v in key_label_map.items()}

    # Section title -> display labels, in rendering order.
    categories = {
        "Personal Information": ["Age", "Gender", "Lifestyle", "Social Background"],
        "Medical History": ["Personal", "Family Members"],
        "Clinical Presentation": ["Symptoms", "Comorbid Conditions"],
        "Medical Assessment": ["Diagnostic Procedures", "Laboratory Results", "Pathology Report"],
        "Diagnosis": ["Diagnosis"],
        "Treatment": ["Interventional Therapy", "Pharmacological Therapy"],
        "Patient Outcome": ["Patient Outcome"],
    }

    def format_bullets(value):
        # Escape every item: values are model output derived from user-pasted
        # text, and clinical values routinely contain "<" / ">" (e.g.
        # "CRP < 5 mg/dL") that would otherwise be parsed as markup.
        items = [html.escape(item.strip()) for item in value.split(";") if item.strip()]
        if not items:
            return "<i>Not Available</i>"
        if len(items) == 1:
            return items[0]
        return "<ul style='margin: 0; padding-left: 1em'>" + "".join(f"<li>{item}</li>" for item in items) + "</ul>"

    table_style = (
        "width: 100%;"
        "height: 100%;"
        "table-layout: fixed;"
    )
    th_td_style = (
        "padding: 8px;"
        "border: 1px solid #ccc;"
        "vertical-align: top;"
        "text-align: left;"
    )

    selected = set(selected_fields)
    html_tables = []
    for section, labels in categories.items():
        section_fields = [label for label in labels if label in selected]
        if not section_fields:
            continue

        table_html = f"<h3 style='margin-bottom: 0.5em;'>{section}</h3>"
        table_html += f"<table style='{table_style}'>"
        table_html += f"<tr><th style='height: 30px; {th_td_style}; width: 150px;'>Field</th><th style='height: 30px; {th_td_style};'>Details</th></tr>"
        for label in section_fields:
            value = data.get(label_key_map[label], "N/A")
            # Tolerate non-string values (e.g. None from a partial parse)
            # instead of crashing on .split().
            if value is None:
                value = "N/A"
            elif not isinstance(value, str):
                value = str(value)
            details = "<i>Not Available</i>" if value == "N/A" else format_bullets(value)
            table_html += f"<tr><td style='{th_td_style}; width: 150px;'><b>{label}</b></td><td style='{th_td_style}'>{details}</td></tr>"
        table_html += "</table>"
        html_tables.append(table_html)

    # Lay out tables in flex rows: two per row for the first four, then three.
    rows = []
    start = 0
    while start < len(html_tables):
        per_row = 2 if start < 4 else 3
        cells = "".join(
            f"<div style='display: flex; flex-direction: column;'>{table}</div>"
            for table in html_tables[start:start + per_row]
        )
        rows.append(f"<div style='display: flex; gap: 1em; margin-bottom: 2em;'>{cells}</div>")
        start += per_row

    grouped_html = "".join(rows)
    return f"<div style='font-family: sans-serif;'>{grouped_html}</div>"
|
|
|
# Schema keys (same order as Person); used to build the all-"N/A" fallback.
_SCHEMA_KEYS = (
    'life_style',
    'family_history',
    'social_history',
    'medical_surgical_history',
    'signs_symptoms',
    'comorbidities',
    'diagnostic_techniques_procedures',
    'diagnosis',
    'laboratory_values',
    'pathology',
    'pharmacological_therapy',
    'interventional_therapy',
    'patient_outcome_assessment',
    'age',
    'gender',
)

# Precomputed extraction for the built-in example report (default_value), so
# the demo responds without spending GPU time on it.
_EXAMPLE_RESPONSE = '{"life_style": "N/A", "family_history": "N/A", "social_history": "N/A", "medical_surgical_history": "N/A", "signs_symptoms": "Fever; Chest pain; Cough; Progressive dyspnea; Tachypnea; Tachycardia; Decreased breath sounds in both lung bases; Crackles on the left", "comorbidities": "N/A", "diagnostic_techniques_procedures": "Chest X-ray; Echocardiography; Thoracentesis; Laboratory tests; Pleural fluid analysis; Urinary pneumococcal antigen test; Pleural fluid culture", "diagnosis": "Pneumonia; Pericardial effusion; S. pneumoniae infection", "laboratory_values": "White blood cell count: 11.78 \\u00d7 10^9 cells/L (84.3% neutrophils, 4.3% lymphocytes, 9.1% monocytes); Platelet count: 512 \\u00d7 10^9/L; Serum C-reactive protein: 31.27 mg/dL; Serum creatinine: 0.94 mg/dL; Serum sodium: 133 mEq/L; Serum potassium: 3.72 mEq/L; Pleural fluid pH: 7.16; Pleural fluid glucose: 4.5 mg/dL; Pleural fluid proteins: 49.1 g/L; Pleural fluid LDH: 1,385 U/L", "pathology": "N/A", "pharmacological_therapy": "Amoxicillin-clavulanate (2.2 g/8 h, i.v.); Levofloxacin (500 mg twice a day); Ibuprofen (800 mg/day)", "interventional_therapy": "Pericardiocentesis; Thoracentesis", "patient_outcome_assessment": "Nearly complete resolution of alterations on chest X-ray and CT scan", "age": "57 year", "gender": "Male"}'


def _generate_response(report):
    """Run grammar-constrained generation on *report*; return the decoded completion.

    Called from summarize(), which holds the @spaces.GPU allocation.
    """
    messages = [
        {"role": "system", "content": prompt.strip()},
        {"role": "user", "content": report.strip()},
    ]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )

    # One processor per request: it appears to hold per-sequence grammar
    # matcher state, so it should not be shared across generations.
    xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)

    # NOTE(review): the rendered chat template likely already starts with the
    # BOS token, and tokenizing it again adds a second one. Consider
    # add_special_tokens=False once confirmed against how the model was
    # fine-tuned.
    model_inputs = tokenizer([chat_text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        max_new_tokens=2048,
        logits_processor=[xgr_logits_processor],
    )
    # Strip the prompt tokens so only the newly generated JSON is decoded.
    completion_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]


def _parse_response(raw):
    """Parse the model's JSON output into a dict, or return an all-"N/A" dict.

    The grammar constrains output to JSON, so json.loads is the matching
    parser. Output cut off at max_new_tokens is not valid JSON and falls back.
    """
    try:
        parsed = json.loads(raw)
    except ValueError:  # includes json.JSONDecodeError
        parsed = None
    if not isinstance(parsed, dict):
        return dict.fromkeys(_SCHEMA_KEYS, "N/A")
    return parsed


@spaces.GPU(duration=60)
def summarize(
    text,
    personal_info,
    medical_history,
    clinical_presentation,
    medical_assessment,
    diagnosis,
    treatment,
    patient_outcome,
):
    """Extract structured fields from a clinical report and render them as HTML.

    Args:
        text: The clinical report pasted by the user.
        personal_info, medical_history, clinical_presentation,
        medical_assessment, diagnosis, treatment, patient_outcome: Lists of
            display labels selected in each CheckboxGroup. Only these fields
            are rendered.

    Returns:
        An HTML string of result tables. If the input is blank, returns a
        plain hint message instead.
    """
    if not text.strip():
        return "Please enter some text to summarize."

    if text == default_value:
        raw_response = _EXAMPLE_RESPONSE
    else:
        raw_response = _generate_response(text)

    data = _parse_response(raw_response)

    selected_fields = []
    for group in (
        personal_info,
        medical_history,
        clinical_presentation,
        medical_assessment,
        diagnosis,
        treatment,
        patient_outcome,
    ):
        selected_fields.extend(group or [])

    return generate_html_tables(data, selected_fields)
|
|
|
def _preselected_checkbox_group(label, choices):
    """Create a CheckboxGroup titled *label* with every choice ticked by default."""
    return gr.CheckboxGroup(label=label, choices=list(choices), value=list(choices))


with gr.Blocks() as demo:
    with gr.Column():
        gr.HTML("""
<div align="center">
<img src="https://huggingface.co/spaces/gregorlied/medical-text-summarization/resolve/main/assets/LlamaMD-logo.png" alt="LlamaMD Logo" width="120" style="margin-bottom: 10px;">
<h2><strong>LlamaMD</strong></h2>
<p><em>Structured Information Extraction from Clinical Reports</em></p>
</div>
""")

    with gr.Tabs():
        with gr.Tab("LlamaMD"):
            with gr.Row():
                input_text = gr.Textbox(
                    label="Clinical Report",
                    autoscroll=False,
                    lines=15,
                    max_lines=15,
                    placeholder="Paste your clinical report here...",
                    value=default_value,
                )

            # Checkbox labels must match the display labels used by
            # generate_html_tables to select rows.
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    with gr.Column():
                        personal_info = _preselected_checkbox_group(
                            "Personal Information",
                            ["Age", "Gender", "Lifestyle", "Social Background"],
                        )
                    with gr.Column():
                        medical_history = _preselected_checkbox_group(
                            "Medical History",
                            ["Personal", "Family Members"],
                        )

                with gr.Row():
                    with gr.Column():
                        clinical_presentation = _preselected_checkbox_group(
                            "Clinical Presentation",
                            ["Symptoms", "Comorbid Conditions"],
                        )
                    with gr.Column():
                        medical_assessment = _preselected_checkbox_group(
                            "Medical Assessment",
                            ["Diagnostic Procedures", "Laboratory Results", "Pathology Report"],
                        )

                with gr.Row():
                    with gr.Column():
                        diagnosis = _preselected_checkbox_group(
                            "Diagnosis",
                            ["Diagnosis"],
                        )
                    with gr.Column():
                        treatment = _preselected_checkbox_group(
                            "Treatment",
                            ["Interventional Therapy", "Pharmacological Therapy"],
                        )
                    with gr.Column():
                        patient_outcome = _preselected_checkbox_group(
                            "Patient Outcome",
                            ["Patient Outcome"],
                        )

            with gr.Row():
                summarize_btn = gr.Button("Extract")

            with gr.Row():
                output_text = gr.HTML()

            summarize_btn.click(
                fn=summarize,
                inputs=[input_text, personal_info, medical_history, clinical_presentation, medical_assessment, diagnosis, treatment, patient_outcome],
                outputs=output_text,
                show_progress=True,
            )

        with gr.Tab("Help"):
            gr.Markdown("""## Help

### Personal Information

**Age**: Age of the patient.<br>
**Gender**: Gender of the patient.<br>
**Lifestyle**: Daily habits and activities of the patient (e.g. alcohol consumption, diet, smoking status).<br>
**Social Background**: Social factors of the patient (e.g. housing situation, marital status).<br>

### Medical History

**Personal**: Past medical conditions, previous surgeries or treatments of the patient.<br>
**Family Members**: Relevant medical conditions or genetic disorders in the patient’s family (e.g. cancer, heart disease).<br>

### Clinical Presentation

**Symptoms**: Current symptoms of the patient.<br>
**Comorbid Conditions**: Other medical conditions of the patient that may influence the treatment.<br>

### Medical Assessment

**Diagnostic Procedures**: Description of the diagnostic tests or procedures performed (e.g. X-rays, MRIs).<br>
**Laboratory Results**: Results from laboratory tests (e.g. blood counts, electrolyte levels).<br>
**Pathology Report**: Findings from pathological examinations (e.g. biopsy results).<br>

### Diagnosis

**Diagnosis**: All levels of diagnosis mentioned in the report.<br>

### Treatment

**Interventional Therapy**: Information on surgical or non-surgical interventions performed.<br>
**Pharmacological Therapy**: Medications prescribed to the patient.<br>

### Patient Outcome

**Patient Outcome**: Evaluation of the patient’s health status at the end of treatment.<br>
""")

        with gr.Tab("About"):
            gr.Markdown("""## About

LlamaMD is a project developed as part of the "NLP for Social Good" course at TU Berlin.

The goal of this project is to perform structured information extraction from clinical reports, helping doctors to have more time for their patients.

The system is based on `meta-llama/Llama-3.2-1B-Instruct`, which has been fine-tuned on the ELMTEX dataset.
""")


if __name__ == "__main__":
    demo.launch()