# Spaces: Runtime error
import ast
import html
import json
import os

import gradio as gr
import spaces
import xgrammar as xgr
from huggingface_hub import login as hf_login
from pydantic import BaseModel
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
# Authenticate against the Hugging Face Hub only when a token is configured.
# Calling hf_login(token=None) falls back to an interactive prompt, which
# cannot be answered in a headless deployment.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    hf_login(token=hf_token)

# Checkpoint used for the config and tokenizer loaded further down.
model_name = "gregorlied/Llama-3.2-1B-Instruct-Medical-Report-Summarization-FP32"

# NOTE: model loading is currently disabled; `model` is therefore undefined.
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     device_map="auto",
#     attn_implementation='eager',
#     trust_remote_code=True,
# )
# Output schema for one clinical report. Despite the name, this models the
# extracted report summary; it is compiled into a JSON-schema grammar via
# `grammar_compiler.compile_json_schema(Person)`.
#
# The field names are the same keys that appear in the schema shown in `prompt`
# and in `key_label_map` inside `generate_html_tables`.
# Every field is free text: the prompt asks the model to write "N/A" when a
# field has no facts and to separate multiple facts with "; ".
#
# A `#` comment is used here instead of a class docstring on purpose:
# NOTE(review): pydantic is understood to copy a class docstring into the
# generated JSON schema description — confirm before adding one.
class Person(BaseModel):
    life_style: str
    family_history: str
    social_history: str
    medical_surgical_history: str
    signs_symptoms: str
    comorbidities: str
    diagnostic_techniques_procedures: str
    diagnosis: str
    laboratory_values: str
    pathology: str
    pharmacological_therapy: str
    interventional_therapy: str
    patient_outcome_assessment: str
    age: str
    gender: str
# Only the model config and the tokenizer are loaded here; both come from the
# checkpoint named by `model_name`.
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Build xgrammar's description of the tokenizer. The vocabulary size is taken
# from the model config rather than from the tokenizer.
# NOTE(review): presumably because the logits dimension can differ from the
# tokenizer vocabulary — confirm against the xgrammar documentation.
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
    tokenizer, vocab_size=config.vocab_size
)
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
# Compile the `Person` schema into a grammar.
compiled_grammar = grammar_compiler.compile_json_schema(Person)
# One module-level logits processor built from that grammar.
# NOTE(review): this single instance would be shared by every request —
# confirm whether xgrammar allows reusing a processor across generate() calls.
xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
# System prompt for the extraction model. `summarize` strips it and sends it as
# the "system" message, followed by the clinical report as the "user" message.
# The keys listed under "JSON Schema" are the same as the fields of `Person`.
#
# NOTE(review): the "β" in the second instruction looks like a mis-encoded dash
# and "multile" looks like a typo for "multiple". The checkpoint is fine-tuned,
# so confirm which exact wording it was trained with before editing this text.
prompt = """You are a text extraction system for clinical reports.
Please extract relevant clinical information from the report.
### Instructions
- Use the JSON Schema given below.
- Return only a valid JSON object β no markdown, no comments.
- If no relevant facts are given for a field, set its value to "N/A".
- If multile relevant facts are given for a field, separate them with "; ".
### JSON Schema
{
'life_style': '',
'family_history': '',
'social_history': '',
'medical_surgical_history': '',
'signs_symptoms': '',
'comorbidities': '',
'diagnostic_techniques_procedures': '',
'diagnosis': '',
'laboratory_values': '',
'pathology': '',
'pharmacological_therapy': '',
'interventional_therapy': '',
'patient_outcome_assessment': '',
'age': '',
'gender': '',
}
### Clinical Report
"""
def generate_html_tables(data, selected_fields):
    """Render extracted report fields as grouped HTML tables.

    Args:
        data: Mapping from schema key (e.g. ``"laboratory_values"``) to the
            extracted text. Several facts in one value are separated by
            ``;``. A missing key, ``None``, an empty value and the literal
            ``"N/A"`` are all rendered as *Not Available*.
        selected_fields: Display labels (e.g. ``"Age"``, ``"Symptoms"``) to
            include. Labels that belong to no category are ignored; ``None``
            is treated as an empty selection.

    Returns:
        An HTML fragment with one table per category that has at least one
        selected field. Tables are laid out in flex rows: two per row for
        the first four tables, three per row afterwards.
    """
    key_label_map = {
        'age': 'Age',
        'gender': 'Gender',
        'life_style': 'Lifestyle',
        'social_history': 'Social Background',
        'medical_surgical_history': 'Personal',
        'family_history': 'Family Members',
        'signs_symptoms': 'Symptoms',
        'comorbidities': 'Comorbid Conditions',
        'diagnostic_techniques_procedures': 'Diagnostic Procedures',
        'laboratory_values': 'Laboratory Results',
        'pathology': 'Pathology Report',
        'diagnosis': 'Diagnosis',
        'interventional_therapy': 'Interventional Therapy',
        'pharmacological_therapy': 'Pharmacological Therapy',
        'patient_outcome_assessment': 'Patient Outcome',
    }
    label_key_map = {label: key for key, label in key_label_map.items()}
    categories = {
        "Personal Information": ["Age", "Gender", "Lifestyle", "Social Background"],
        "Medical History": ["Personal", "Family Members"],
        "Clinical Presentation": ["Symptoms", "Comorbid Conditions"],
        "Medical Assessment": ["Diagnostic Procedures", "Laboratory Results", "Pathology Report"],
        "Diagnosis": ["Diagnosis"],
        "Treatment": ["Interventional Therapy", "Pharmacological Therapy"],
        "Patient Outcome": ["Patient Outcome"],
    }
    not_available = "<i>Not Available</i>"

    def format_bullets(value):
        """Turn one raw field value into safe HTML (text, list or placeholder)."""
        if value is None:
            return not_available
        text = value if isinstance(value, str) else str(value)
        if text.strip() == "N/A":
            return not_available
        # Split first, escape second: escaping introduces ';' (e.g. "&lt;"),
        # which must not be mistaken for the fact separator.
        items = [
            html.escape(part.strip())
            for part in text.split(";")
            if part.strip()
        ]
        if not items:
            return not_available
        if len(items) == 1:
            return items[0]
        bullets = "".join(f"<li>{item}</li>" for item in items)
        return f"<ul style='margin: 0; padding-left: 1.2em'>{bullets}</ul>"

    table_style = (
        "width: 100%;"
        "table-layout: fixed;"
        "border-collapse: collapse;"
        "word-wrap: break-word;"
        "height: 100%;"
    )
    th_td_style_first = (
        "padding: 8px;"
        "border: 1px solid #ccc;"
        "vertical-align: top;"
        "text-align: left;"
        "height: 30px;"
        "overflow: hidden;"
    )
    th_td_style_other = (
        "padding: 8px;"
        "border: 1px solid #ccc;"
        "vertical-align: top;"
        "text-align: left;"
    )

    # Build the lookup once instead of scanning the list for every label.
    selected = set(selected_fields or ())

    html_tables = []
    for section, labels in categories.items():
        section_fields = [label for label in labels if label in selected]
        if not section_fields:
            continue
        parts = [
            f"<h3 style='margin-bottom: 0.5em;'>{section}</h3>",
            f"<table style='{table_style}'>",
            f"<tr><th style='{th_td_style_first}; width: 150px;'>Field</th><th style='{th_td_style_first};'>Details</th></tr>",
        ]
        for label in section_fields:
            details = format_bullets(data.get(label_key_map[label], "N/A"))
            parts.append(
                f"<tr><td style='{th_td_style_other}; width: 150px;'><b>{label}</b></td><td style='{th_td_style_other}'>{details}</td></tr>"
            )
        parts.append("</table>")
        html_tables.append("".join(parts))

    # Lay the tables out in rows: 2 + 2 for the first four, then 3 per row.
    rows = []
    index = 0
    while index < len(html_tables):
        num_per_row = 2 if index < 4 else 3
        cells = "".join(
            "<div style='flex: 1; min-width: 0; display: flex; flex-direction: column;'>"
            f"{table}"
            "</div>"
            for table in html_tables[index:index + num_per_row]
        )
        rows.append(
            "<div style='display: flex; gap: 1em; margin-bottom: 2em; align-items: stretch;'>"
            + cells
            + "</div>"
        )
        index += num_per_row

    grouped_html = "".join(rows)
    return f"<div style='font-family: sans-serif;'>{grouped_html}</div>"
def _parse_model_response(raw):
    """Parse raw model output into a dict; return None when it is unusable."""
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):
        # Not strict JSON (e.g. a single-quoted, Python-style dict literal):
        # fall back to a safe literal parser.
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
    return parsed if isinstance(parsed, dict) else None


def summarize(
    text,
    personal_info,
    medical_history,
    clinical_presentation,
    medical_assessment,
    diagnosis,
    treatment,
    patient_outcome,
):
    """Extract structured fields from a clinical report and render them as HTML.

    Args:
        text: The clinical report entered by the user.
        personal_info, medical_history, clinical_presentation,
        medical_assessment, diagnosis, treatment, patient_outcome:
            Lists of display labels ticked in the corresponding checkbox
            group; ``None`` is treated as an empty list.

    Returns:
        The HTML produced by ``generate_html_tables`` for the selected
        labels, or a plain hint string when ``text`` is empty or blank.
    """
    if not text or not text.strip():
        return "Please enter some text to summarize."

    messages = [
        {"role": "system", "content": prompt.strip()},
        {"role": "user", "content": text.strip()},
    ]
    # Rendered chat prompt; it is the input of the generation step below,
    # which is currently disabled.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # only relevant for qwen
    )

    # NOTE: generation is disabled (no `model` is loaded in this file); the
    # canned response below is used instead.
    # model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # generated_ids = model.generate(
    #     input_ids=model_inputs["input_ids"],
    #     attention_mask = model_inputs["attention_mask"],
    #     # num_beams=8,
    #     # top_p=0.9,
    #     # do_sample=True,
    #     # temperature=0.6,
    #     # min_new_tokens=50,
    #     max_new_tokens=2048,
    #     logits_processor=[xgr_logits_processor]
    # )
    # generated_ids = [
    #     output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    # ]
    # response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    response = ["{'life_style': 'N/A', 'family_history': 'N/A', 'social_history': 'N/A', 'medical_surgical_history': 'MΓ©niΓ©βdisease; Hypothyroidism; Type B1 thymoma; Adrenal insufficiency; Sputum stasis; Pulmonary infection', 'signs_symptoms': 'Tiredness; Muscle weakness; Cachexia; Decreased muscular strength; Bilateral rigid auricles; Calcified auricles', 'comorbidities': 'MΓ©niΓ©βdisease; Hypothyroidism', 'diagnostic_techniques_procedures': 'Chest X-ray; Magnetic resonance imaging scan; Metyrapone test', 'diagnosis': 'Type B1 thymoma; Hypothyroidism; Myasthenia gravis; Adrenal insufficiency; Pulmonary infection', 'laboratory_values': 'Cortisol: 51nmol/L; TSH: 11.9mU/L; Free T4: 11.2pmol/L; LH: 6.5U/L; FSH: 10.9U/L; Testosterone: 3.4nmol/L; Prolactin: 0.49U/L; IGF-1: 7.8nmol/L', 'pathology': 'Pituitary hypoplasia', 'pharmacological_therapy': 'Levothyroxine; Betahistine; Phenylephrine; Norepinephrine; Hydrocortisone', 'interventional_therapy': 'Surgery; Tracheostomy', 'patient_outcome_assessment': 'Discharged to the ward; Weaned off ventilator; Discharged to ward on 21st POD', 'age': '68 year', 'gender': 'Male'}"]

    data = _parse_model_response(response[0])
    if data is None:
        # Unparseable output: render every selected field as "Not Available"
        # (the renderer treats missing keys that way).
        data = {}

    selected_fields = [
        label
        for group in (
            personal_info,
            medical_history,
            clinical_presentation,
            medical_assessment,
            diagnosis,
            treatment,
            patient_outcome,
        )
        for label in (group or ())
    ]
    return generate_html_tables(data, selected_fields)
# Gradio UI: a header, then three tabs (extraction, help, about).
# The extraction tab wires the report textbox and the seven field-selection
# checkbox groups into `summarize` and shows its HTML result.
with gr.Blocks() as demo:
    # need to be combined with `hf_oauth: true` in README.md
    # button = gr.LoginButton("Sign in")
    with gr.Column():
        gr.HTML("""
<div align="center">
<img src="https://huggingface.co/spaces/gregorlied/playground/resolve/main/assets/LlamaMD-logo.png" alt="LlamaMD Logo" width="120" style="margin-bottom: 10px;">
<h2><strong>LlamaMD</strong></h2>
<p><em>Structured Information Extraction from Clinical Reports</em></p>
</div>
""")
    with gr.Tabs():
        with gr.Tab("LlamaMD"):
            with gr.Row():
                input_text = gr.Textbox(
                    label="Clinical Report",
                    autoscroll=False,
                    lines=15,
                    max_lines=15,
                    placeholder="Paste your clinical report here...",
                    value="A 57-year-old male presented with fever (38.9°C), chest pain, cough, and progressive dyspnea. The patient exhibited tachypnea (34 breaths/min) and tachycardia (134 bpm). Auscultation revealed decreased breath sounds in both lung bases, with crackles on the left. A chest X-ray revealed bilateral pleural opacities and enlargement of the cardiac silhouette ( A). Echocardiography showed moderate pericardial effusion affecting the entire cardiac silhouette. Pericardiocentesis yielded 250 mL of exudative fluid. A CT scan of the chest showed pneumonia in the left lower lobe, bilateral pleural effusion, and moderate pericardial effusion ( B). Thoracentesis was performed and yielded 1,050 mL of exudative fluid. Laboratory tests yielded the following data: white blood cell count, 11.78 × 109 cells/L (84.3% neutrophils, 4.3% lymphocytes, and 9.1% monocytes); platelet count, 512 × 109/L; serum C-reactive protein, 31.27 mg/dL; serum creatinine, 0.94 mg/dL; serum sodium, 133 mEq/L; and serum potassium, 3.72 mEq/L. Examination of the pleural fluid showed a pH of 7.16, a glucose level of 4.5 mg/dL, proteins at 49.1 g/L, and an LDH content of 1,385 U/L. A urinary pneumococcal antigen test was positive. Pleural fluid culture was positive for S. pneumoniae. The patient was treated for four weeks with amoxicillin-clavulanate (2.2 g/8 h, i.v.) plus levofloxacin (500 mg twice a day), together with a nonsteroidal anti-inflammatory drug (ibuprofen, 800 mg/day), after which there was nearly complete resolution of the alterations seen on the chest X-ray and CT scan.",
                )
            # One checkbox group per category; the choices are the display
            # labels that `generate_html_tables` groups by. All are ticked
            # by default.
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    with gr.Column():
                        personal_info = gr.CheckboxGroup(
                            label="Personal Information",
                            choices=[
                                "Age",
                                "Gender",
                                "Lifestyle",
                                "Social Background",
                            ],
                            value=[
                                "Age",
                                "Gender",
                                "Lifestyle",
                                "Social Background",
                            ],
                        )
                    with gr.Column():
                        medical_history = gr.CheckboxGroup(
                            label="Medical History",
                            choices=[
                                "Personal",
                                "Family Members",
                            ],
                            value=[
                                "Personal",
                                "Family Members",
                            ],
                        )
                with gr.Row():
                    with gr.Column():
                        clinical_presentation = gr.CheckboxGroup(
                            label="Clinical Presentation",
                            choices=[
                                "Symptoms",
                                "Comorbid Conditions",
                            ],
                            value=[
                                "Symptoms",
                                "Comorbid Conditions",
                            ],
                        )
                    with gr.Column():
                        medical_assessment = gr.CheckboxGroup(
                            label="Medical Assessment",
                            choices=[
                                "Diagnostic Procedures",
                                "Laboratory Results",
                                "Pathology Report",
                            ],
                            value=[
                                "Diagnostic Procedures",
                                "Laboratory Results",
                                "Pathology Report",
                            ],
                        )
                with gr.Row():
                    with gr.Column():
                        diagnosis = gr.CheckboxGroup(
                            label="Diagnosis",
                            choices=[
                                "Diagnosis",
                            ],
                            value=[
                                "Diagnosis",
                            ],
                        )
                    with gr.Column():
                        treatment = gr.CheckboxGroup(
                            label="Treatment",
                            choices=[
                                "Interventional Therapy",
                                "Pharmacological Therapy",
                            ],
                            value=[
                                "Interventional Therapy",
                                "Pharmacological Therapy",
                            ],
                        )
                    with gr.Column():
                        patient_outcome = gr.CheckboxGroup(
                            label="Patient Outcome",
                            choices=[
                                "Patient Outcome",
                            ],
                            value=[
                                "Patient Outcome",
                            ],
                        )
            with gr.Row():
                summarize_btn = gr.Button("Extract")
            with gr.Row():
                output_text = gr.HTML()
            # The input order must match the parameter order of `summarize`.
            summarize_btn.click(
                fn=summarize,
                inputs=[input_text, personal_info, medical_history, clinical_presentation, medical_assessment, diagnosis, treatment, patient_outcome],
                outputs=output_text,
                show_progress=True,
            )
        with gr.Tab("Help"):
            gr.Markdown("""### Personal Information
**Age**: Age of the patient.<br>
**Gender**: Gender of the patient.<br>
**Lifestyle**: Daily habits and activities of the patient (e.g. alcohol consumption, diet, smoking status).<br>
**Social Background**: Social factors of the patient (e.g. housing situation, marital status).<br>
### Medical History
**Personal**: Past medical conditions, previous surgeries or treatments of the patient.<br>
**Family Members**: Relevant medical conditions or genetic disorders in the patient's family (e.g. cancer, heart disease).<br>
### Clinical Presentation
**Symptoms**: Current symptoms of the patient.<br>
**Comorbid Conditions**: Other medical conditions of the patient that may influence the treatment.<br>
### Medical Assessment
**Diagnostic Procedures**: Description of the diagnostic tests or procedures performed (e.g. X-rays, MRIs)<br>
**Laboratory Results**: Results from laboratory tests (e.g. blood counts, electrolyte levels)<br>
**Pathology Report**: Findings from pathological examinations (e.g. biopsy results)<br>
### Diagnosis
**Diagnosis**: All levels of diagnosis mentioned in the report.<br>
### Treatment
**Interventional Therapy**: Information on surgical or non-surgical interventions performed.<br>
**Pharmacological Therapy**: Medications prescribed to the patient.<br>
### Patient Outcome
**Patient Outcome**: Evaluation of the patient's health status at the end of treatment.<br>
""")
        with gr.Tab("About"):
            gr.Markdown("""LlamaMD is a project developed as part of the "NLP for Social Good" course at TU Berlin.
The goal of this project is to perform structured information extraction from clinical reports, helping doctors to have more time for their patients.
The system is based on `meta-llama/Llama-3.2-1B-Instruct`, which has been fine-tuned on the ELMTEX dataset.
""")
# Launch the Gradio app only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()