Spaces:
Runtime error
Runtime error
import os | |
import ast | |
import spaces | |
import gradio as gr | |
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer | |
from huggingface_hub import login as hf_login | |
import xgrammar as xgr | |
from pydantic import BaseModel | |
# Authenticate against the Hugging Face Hub only when a token is configured.
# `login(token=None)` falls back to an interactive prompt, which cannot be
# answered in a headless deployment, so the call is skipped when HF_TOKEN is
# unset or empty.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    hf_login(token=_hf_token)

# Hub identifier of the checkpoint; reused below for the config and tokenizer.
model_name = "gregorlied/Llama-3.2-1B-Instruct-Medical-Report-Summarization-FP32"

# NOTE(review): `trust_remote_code=True` allows Python code shipped with the
# model repository to run at load time -- confirm this checkpoint needs it.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    attn_implementation='eager',
    trust_remote_code=True,
)
class Person(BaseModel):
    # Flat record of the facts extracted from one clinical report. Every field
    # is free text. The class is handed to the grammar compiler further down
    # the file, so field names and their order are part of the output format.
    #
    # Documented with comments rather than a docstring so that the JSON schema
    # pydantic derives from this class is left unchanged.

    # Background of the patient.
    life_style: str
    family_history: str
    social_history: str
    medical_surgical_history: str

    # Presentation at the time of the report.
    signs_symptoms: str
    comorbidities: str

    # Work-up and findings.
    diagnostic_techniques_procedures: str
    diagnosis: str
    laboratory_values: str
    pathology: str

    # Treatment and result.
    pharmacological_therapy: str
    interventional_therapy: str
    patient_outcome_assessment: str

    # Demographics.
    age: str
    gender: str
# Model configuration; consulted for the size of the model's output vocabulary.
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The grammar's token mask has to match the width of the logits produced by the
# model, which is given by the model config and is not guaranteed to equal
# `len(tokenizer)`. Fall back to the tokenizer length only when the config does
# not expose a vocabulary size.
_vocab_size = getattr(config, "vocab_size", None) or len(tokenizer)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
    tokenizer, vocab_size=_vocab_size
)

# Compile the `Person` schema once at start-up; the compiled grammar is reused
# for every request.
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_json_schema(Person)

# NOTE(review): xgrammar documents this processor as stateful (one instance per
# `generate()` call) -- verify before sharing this module-level instance
# between requests.
xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
# Example report pre-filled in the input box. The degree sign and the
# multiplication signs were mis-decoded ("Β°", "Γ") and are restored here.
default_value = "A 57-year-old male presented with fever (38.9°C), chest pain, cough, and progressive dyspnea. The patient exhibited tachypnea (34 breaths/min) and tachycardia (134 bpm). Auscultation revealed decreased breath sounds in both lung bases, with crackles on the left. A chest X-ray revealed bilateral pleural opacities and enlargement of the cardiac silhouette ( A). Echocardiography showed moderate pericardial effusion affecting the entire cardiac silhouette. Pericardiocentesis yielded 250 mL of exudative fluid. A CT scan of the chest showed pneumonia in the left lower lobe, bilateral pleural effusion, and moderate pericardial effusion ( B). Thoracentesis was performed and yielded 1,050 mL of exudative fluid. Laboratory tests yielded the following data: white blood cell count, 11.78 × 109 cells/L (84.3% neutrophils, 4.3% lymphocytes, and 9.1% monocytes); platelet count, 512 × 109/L; serum C-reactive protein, 31.27 mg/dL; serum creatinine, 0.94 mg/dL; serum sodium, 133 mEq/L; and serum potassium, 3.72 mEq/L. Examination of the pleural fluid showed a pH of 7.16, a glucose level of 4.5 mg/dL, proteins at 49.1 g/L, and an LDH content of 1,385 U/L. A urinary pneumococcal antigen test was positive. Pleural fluid culture was positive for S. pneumoniae. The patient was treated for four weeks with amoxicillin-clavulanate (2.2 g/8 h, i.v.) plus levofloxacin (500 mg twice a day), together with a nonsteroidal anti-inflammatory drug (ibuprofen, 800 mg/day), after which there was nearly complete resolution of the alterations seen on the chest X-ray and CT scan."
# System prompt sent with every report. The dash in the second instruction was
# mis-decoded ("β") and is restored here.
# NOTE(review): "multile" looks like a typo for "multiple". It is left as-is
# because the fine-tuned model may have been trained on this exact wording --
# confirm against the training prompt before correcting it. The same applies
# to the single-quoted, Python-style schema listing.
prompt = """You are a text extraction system for clinical reports.
Please extract relevant clinical information from the report.
### Instructions
- Use the JSON Schema given below.
- Return only a valid JSON object — no markdown, no comments.
- If no relevant facts are given for a field, set its value to "N/A".
- If multile relevant facts are given for a field, separate them with "; ".
### JSON Schema
{
'life_style': '',
'family_history': '',
'social_history': '',
'medical_surgical_history': '',
'signs_symptoms': '',
'comorbidities': '',
'diagnostic_techniques_procedures': '',
'diagnosis': '',
'laboratory_values': '',
'pathology': '',
'pharmacological_therapy': '',
'interventional_therapy': '',
'patient_outcome_assessment': '',
'age': '',
'gender': '',
}
### Clinical Report
"""
def generate_html_tables(data, selected_fields):
    """Render extracted clinical fields as grouped HTML tables.

    Args:
        data: Mapping from schema keys (e.g. ``'signs_symptoms'``) to the
            extracted text. Several facts inside one field are separated by
            ``;``. A missing key, ``None``, an empty string or the literal
            ``"N/A"`` is shown as "Not Available"; other non-string values
            are converted with ``str``.
        selected_fields: Display labels (e.g. ``"Symptoms"``) to include.
            Sections without any selected label are omitted.

    Returns:
        One HTML string. Each section is a two-column table (field, details);
        the first two rows hold two tables each, later rows hold three.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    key_label_map = {
        'age': 'Age',
        'gender': 'Gender',
        'life_style': 'Lifestyle',
        'social_history': 'Social Background',
        'medical_surgical_history': 'Personal',
        'family_history': 'Family Members',
        'signs_symptoms': 'Symptoms',
        'comorbidities': 'Comorbid Conditions',
        'diagnostic_techniques_procedures': 'Diagnostic Procedures',
        'laboratory_values': 'Laboratory Results',
        'pathology': 'Pathology Report',
        'diagnosis': 'Diagnosis',
        'interventional_therapy': 'Interventional Therapy',
        'pharmacological_therapy': 'Pharmacological Therapy',
        'patient_outcome_assessment': 'Patient Outcome',
    }
    label_key_map = {label: key for key, label in key_label_map.items()}
    categories = {
        "Personal Information": ["Age", "Gender", "Lifestyle", "Social Background"],
        "Medical History": ["Personal", "Family Members"],
        "Clinical Presentation": ["Symptoms", "Comorbid Conditions"],
        "Medical Assessment": ["Diagnostic Procedures", "Laboratory Results", "Pathology Report"],
        "Diagnosis": ["Diagnosis"],
        "Treatment": ["Interventional Therapy", "Pharmacological Therapy"],
        "Patient Outcome": ["Patient Outcome"],
    }
    not_available = "<i>Not Available</i>"

    def format_bullets(value):
        # Split BEFORE escaping: escaped entities such as "&amp;" contain ';'
        # and would otherwise be torn apart.
        items = [escape(part.strip()) for part in value.split(";") if part.strip()]
        if not items:
            return not_available
        if len(items) == 1:
            return items[0]
        bullets = "".join(f"<li>{item}</li>" for item in items)
        return f"<ul style='margin: 0; padding-left: 1.2em'>{bullets}</ul>"

    def format_details(value):
        # The text originates from the model / the user's report, so it is
        # treated as untrusted and may not even be a string.
        if value is None:
            return not_available
        text = value if isinstance(value, str) else str(value)
        if text.strip() == "N/A":
            return not_available
        return format_bullets(text)

    table_style = (
        "width: 100%;"
        "table-layout: fixed;"
        "border-collapse: collapse;"
        "word-wrap: break-word;"
        "height: 100%;"
    )
    th_td_style_first = (
        "padding: 8px;"
        "border: 1px solid #ccc;"
        "vertical-align: top;"
        "text-align: left;"
        "height: 30px;"
        "overflow: hidden;"
    )
    th_td_style_other = (
        "padding: 8px;"
        "border: 1px solid #ccc;"
        "vertical-align: top;"
        "text-align: left;"
    )

    selected = set(selected_fields or ())

    html_tables = []
    for section, labels in categories.items():
        section_fields = [label for label in labels if label in selected]
        if not section_fields:
            continue
        parts = [
            f"<h3 style='margin-bottom: 0.5em;'>{section}</h3>",
            f"<table style='{table_style}'>",
            f"<tr><th style='{th_td_style_first}; width: 150px;'>Field</th>"
            f"<th style='{th_td_style_first};'>Details</th></tr>",
        ]
        for label in section_fields:
            details = format_details(data.get(label_key_map[label], "N/A"))
            parts.append(
                f"<tr><td style='{th_td_style_other}; width: 150px;'><b>{label}</b></td>"
                f"<td style='{th_td_style_other}'>{details}</td></tr>"
            )
        parts.append("</table>")
        html_tables.append("".join(parts))

    # Lay the tables out in flex rows: two per row for the first two rows,
    # three per row afterwards.
    rows = []
    start = 0
    while start < len(html_tables):
        per_row = 2 if start < 4 else 3
        cells = "".join(
            "<div style='flex: 1; min-width: 0; display: flex; flex-direction: column;'>"
            f"{table}"
            "</div>"
            for table in html_tables[start:start + per_row]
        )
        rows.append(
            "<div style='display: flex; gap: 1em; margin-bottom: 2em; align-items: stretch;'>"
            f"{cells}"
            "</div>"
        )
        start += per_row

    return f"<div style='font-family: sans-serif;'>{''.join(rows)}</div>"
def _parse_model_response(raw):
    """Decode the model's answer into a dict of field values.

    JSON is tried first because that is the format the model is asked for; a
    Python-literal parse is kept as a lenient fallback. If neither yields a
    dict, every field is reported as "N/A".
    """
    import json

    for loader in (json.loads, ast.literal_eval):
        try:
            parsed = loader(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        # Anything but a dict (list, bare string, number) cannot be rendered.
        if isinstance(parsed, dict):
            return parsed

    return {
        'life_style': 'N/A',
        'family_history': 'N/A',
        'social_history': 'N/A',
        'medical_surgical_history': 'N/A',
        'signs_symptoms': 'N/A',
        'comorbidities': 'N/A',
        'diagnostic_techniques_procedures': 'N/A',
        'diagnosis': 'N/A',
        'laboratory_values': 'N/A',
        'pathology': 'N/A',
        'pharmacological_therapy': 'N/A',
        'interventional_therapy': 'N/A',
        'patient_outcome_assessment': 'N/A',
        'age': 'N/A',
        'gender': 'N/A',
    }


def summarize(
    text,
    personal_info,
    medical_history,
    clinical_presentation,
    medical_assessment,
    diagnosis,
    treatment,
    patient_outcome,
):
    """Extract structured fields from a clinical report and render them as HTML.

    Args:
        text: The clinical report entered by the user.
        personal_info, medical_history, clinical_presentation,
        medical_assessment, diagnosis, treatment, patient_outcome: Lists of
            display labels ticked in the corresponding checkbox group; only
            these fields appear in the output.

    Returns:
        The HTML produced by ``generate_html_tables``, or a plain hint when
        ``text`` is empty or whitespace-only.
    """
    if not text or not text.strip():
        return "Please enter some text to summarize."

    if text == default_value:
        # Pre-computed answer for the bundled example; avoids a model call.
        raw_response = '{"life_style": "N/A", "family_history": "N/A", "social_history": "N/A", "medical_surgical_history": "N/A", "signs_symptoms": "Fever; Chest pain; Cough; Progressive dyspnea; Tachypnea; Tachycardia; Decreased breath sounds in both lung bases; Crackles on the left", "comorbidities": "N/A", "diagnostic_techniques_procedures": "Chest X-ray; Echocardiography; Thoracentesis; Laboratory tests; Pleural fluid analysis; Urinary pneumococcal antigen test; Pleural fluid culture", "diagnosis": "Pneumonia; Pericardial effusion; S. pneumoniae infection", "laboratory_values": "White blood cell count: 11.78 \\u00d7 10^9 cells/L (84.3% neutrophils, 4.3% lymphocytes, 9.1% monocytes); Platelet count: 512 \\u00d7 10^9/L; Serum C-reactive protein: 31.27 mg/dL; Serum creatinine: 0.94 mg/dL; Serum sodium: 133 mEq/L; Serum potassium: 3.72 mEq/L; Pleural fluid pH: 7.16; Pleural fluid glucose: 4.5 mg/dL; Pleural fluid proteins: 49.1 g/L; Pleural fluid LDH: 1,385 U/L", "pathology": "N/A", "pharmacological_therapy": "Amoxicillin-clavulanate (2.2 g/8 h, i.v.); Levofloxacin (500 mg twice a day); Ibuprofen (800 mg/day)", "interventional_therapy": "Pericardiocentesis; Thoracentesis", "patient_outcome_assessment": "Nearly complete resolution of alterations on chest X-ray and CT scan", "age": "57 year", "gender": "Male"}'
    else:
        messages = [
            {"role": "system", "content": prompt.strip()},
            {"role": "user", "content": text.strip()},
        ]
        chat_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,  # only relevant for qwen
        )
        model_inputs = tokenizer([chat_prompt], return_tensors="pt").to(model.device)

        # The xgrammar logits processor tracks how far the grammar has been
        # matched, so a fresh instance is built for every generate() call
        # instead of sharing one across requests.
        logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
        generated_ids = model.generate(
            input_ids=model_inputs["input_ids"],
            attention_mask=model_inputs["attention_mask"],
            max_new_tokens=2048,
            logits_processor=[logits_processor],
        )
        # Drop the echoed prompt tokens so only the completion is decoded.
        completion_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        raw_response = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]

    data = _parse_model_response(raw_response)

    selected_fields = []
    for group in (
        personal_info,
        medical_history,
        clinical_presentation,
        medical_assessment,
        diagnosis,
        treatment,
        patient_outcome,
    ):
        selected_fields.extend(group or [])

    return generate_html_tables(data, selected_fields)
# Static page fragments, kept out of the layout code for readability.
_HEADER_HTML = """
<div align="center">
<img src="https://huggingface.co/spaces/gregorlied/medical-text-summarization/resolve/main/assets/LlamaMD-logo.png" alt="LlamaMD Logo" width="120" style="margin-bottom: 10px;">
<h2><strong>LlamaMD</strong></h2>
<p><em>Structured Information Extraction from Clinical Reports</em></p>
</div>
"""

# Help text. The descriptions of the two therapy fields were swapped, "foom"
# was a typo for "from", and the apostrophes were mis-decoded; all fixed here.
_HELP_MARKDOWN = """## Help
### Personal Information
**Age**: Age of the patient.<br>
**Gender**: Gender of the patient.<br>
**Lifestyle**: Daily habits and activities of the patient (e.g. alcohol consumption, diet, smoking status).<br>
**Social Background**: Social factors of the patient (e.g. housing situation, marital status).<br>
### Medical History
**Personal**: Past medical conditions, previous surgeries or treatments of the patient.<br>
**Family Members**: Relevant medical conditions or genetic disorders in the patient’s family (e.g. cancer, heart disease).<br>
### Clinical Presentation
**Symptoms**: Current symptoms of the patient.<br>
**Comorbid Conditions**: Other medical conditions of the patient that may influence the treatment.<br>
### Medical Assessment
**Diagnostic Procedures**: Description of the diagnostic tests or procedures performed (e.g. X-rays, MRIs)<br>
**Laboratory Results**: Results from laboratory test (e.g. blood counts, electrolyte levels)<br>
**Pathology Report**: Findings from pathological examinations (e.g. biopsy results)<br>
### Diagnosis
**Diagnosis**: All levels of diagnosis mentioned in the report.<br>
### Treatment
**Interventional Therapy**: Information on surgical or non-surgical interventions performed.<br>
**Pharmacological Therapy**: Medications prescribed to the patient.<br>
### Patient Outcome
**Patient Outcome**: Evaluation of the patient’s health status at the end of treatment.<br>
"""

_ABOUT_MARKDOWN = """## About
LlamaMD is a project developed as part of the "NLP for Social Good" course at TU Berlin.
The goal of this project is to perform structured information extraction from clinical reports, helping doctors to have more time for their patients.
The system is based on `meta-llama/Llama-3.2-1B-Instruct`, which has been fine-tuned on the ELMTEX dataset.
"""


def _checkbox_group(label, options):
    """Create a checkbox group whose options are all ticked initially."""
    return gr.CheckboxGroup(label=label, choices=list(options), value=list(options))


# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of the file that had lost its indentation -- confirm against the
# deployed layout.
with gr.Blocks() as demo:
    # need to be combined with `hf_oauth: true` in README.md
    # button = gr.LoginButton("Sign in")
    with gr.Column():
        gr.HTML(_HEADER_HTML)
        with gr.Tabs():
            with gr.Tab("LLamaMD"):
                with gr.Row():
                    input_text = gr.Textbox(
                        label="Clinical Report",
                        autoscroll=False,
                        lines=15,
                        max_lines=15,
                        placeholder="Paste your clinical report here...",
                        value=default_value,
                    )
                # Field selection; the labels must match the ones known to
                # `generate_html_tables`.
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        with gr.Column():
                            personal_info = _checkbox_group(
                                "Personal Information",
                                ["Age", "Gender", "Lifestyle", "Social Background"],
                            )
                        with gr.Column():
                            medical_history = _checkbox_group(
                                "Medical History",
                                ["Personal", "Family Members"],
                            )
                    with gr.Row():
                        with gr.Column():
                            clinical_presentation = _checkbox_group(
                                "Clinical Presentation",
                                ["Symptoms", "Comorbid Conditions"],
                            )
                        with gr.Column():
                            medical_assessment = _checkbox_group(
                                "Medical Assessment",
                                ["Diagnostic Procedures", "Laboratory Results", "Pathology Report"],
                            )
                    with gr.Row():
                        with gr.Column():
                            diagnosis = _checkbox_group(
                                "Diagnosis",
                                ["Diagnosis"],
                            )
                        with gr.Column():
                            treatment = _checkbox_group(
                                "Treatment",
                                ["Interventional Therapy", "Pharmacological Therapy"],
                            )
                        with gr.Column():
                            patient_outcome = _checkbox_group(
                                "Patient Outcome",
                                ["Patient Outcome"],
                            )
                with gr.Row():
                    summarize_btn = gr.Button("Extract")
                with gr.Row():
                    output_text = gr.HTML()
                # Input order must match the parameter order of `summarize`.
                summarize_btn.click(
                    fn=summarize,
                    inputs=[
                        input_text,
                        personal_info,
                        medical_history,
                        clinical_presentation,
                        medical_assessment,
                        diagnosis,
                        treatment,
                        patient_outcome,
                    ],
                    outputs=output_text,
                    show_progress=True,
                )
            with gr.Tab("Help"):
                gr.Markdown(_HELP_MARKDOWN)
            with gr.Tab("About"):
                gr.Markdown(_ABOUT_MARKDOWN)

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()