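"""Gradio Space for structured information extraction from clinical reports.

A fine-tuned Llama 3.2 1B Instruct model summarizes a free-text clinical
report into a fixed JSON schema; xgrammar constrains decoding so the output
is always valid JSON.
"""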
import os
import spaces
import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login as hf_login
import xgrammar as xgr
from pydantic import BaseModel
# Authenticate so gated models can be downloaded.
hf_login(token=os.getenv("HF_TOKEN"))

# Fine-tuned variant of meta-llama/Llama-3.2-1B-Instruct.
model_name = "gregorlied/Llama-3.2-1B-Instruct-Medical-Report-Summarization"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    trust_remote_code=True,
)

# checkpoint = "gregorlied/Llama-3.2-1B-Instruct-Medical-Report-Summarization"
# model = PeftModel.from_pretrained(model, checkpoint)
# model = model.merge_and_unload()
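
# Target schema for the extraction: every field is a free-text string,
# set to "N/A" when the report contains no relevant information.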
class Person(BaseModel):
    life_style: str
    family_history: str
    social_history: str
    medical_surgical_history: str
    signs_symptoms: str
    comorbidities: str
    diagnostic_techniques_procedures: str
    diagnosis: str
    laboratory_values: str
    pathology: str
    pharmacological_therapy: str
    interventional_therapy: str
    patient_outcome_assessment: str
    age: str
    gender: str
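
# Load the tokenizer and compile the Person schema into an xgrammar grammar;
# the resulting logits processor restricts generation to schema-conforming JSON.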
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
    tokenizer, vocab_size=config.vocab_size
)
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_json_schema(Person)
xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
prompt = """You are a text extraction system for clinical reports.
Please extract relevant clinical information from the report.
### Instructions
- Use the JSON Schema given below.
- Return only a valid JSON object – no markdown, no comments.
- If no relevant facts are given for a field, set its value to "N/A".
- If multiple relevant facts are given for a field, separate them with "; ".
### JSON Schema
{
'life_style': '',
'family_history': '',
'social_history': '',
'medical_surgical_history': '',
'signs_symptoms': '',
'comorbidities': '',
'diagnostic_techniques_procedures': '',
'diagnosis': '',
'laboratory_values': '',
'pathology': '',
'pharmacological_therapy': '',
'interventional_therapy': '',
'patient_outcome_assessment': '',
'age': '',
'gender': '',
}
### Clinical Report
"""
@spaces.GPU(duration=60)
def summarize(text):
    if not text.strip():
        return "Please enter some text to summarize."

    # Build the chat-formatted input (system prompt + report).
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": text},
    ]
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)

    generated_ids = model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        # num_beams=8,
        # top_p=0.9,
        # do_sample=True,
        # temperature=0.6,
        # min_new_tokens=50,
        max_new_tokens=2048,
        logits_processor=[xgr_logits_processor],
    )

    # Keep only the newly generated tokens, then decode them.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return response[0]

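# Gradio UI: clinical report input, category checkboxes (informational only,
# not wired into summarize), and the structured JSON output.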
with gr.Blocks() as demo:
    gr.Markdown("## 📝 Structured Information Extraction for Clinical Reports")

    with gr.Row():
        input_text = gr.Textbox(
            label="Clinical Report",
            autoscroll=False,
            lines=15,
            max_lines=15,
            placeholder="Paste your clinical report here...",
        )
    with gr.Row():
        gr.CheckboxGroup(
            label="General Information",
            choices=[
                "Age",
                "Gender",
                "Life Style",
                "Social History",
            ],
        )
        gr.CheckboxGroup(
            label="Medical History",
            choices=[
                "Family History",
                "Medical Surgical History",
            ],
        )
        gr.CheckboxGroup(
            label="Signs and Symptoms",
            choices=[
                "Signs and Symptoms",
                "Comorbidities",
            ],
        )
    with gr.Row():
        gr.CheckboxGroup(
            label="Medical Assessments",
            choices=[
                "Diagnostic Techniques and Procedures",
                "Laboratory Values",
                "Pathology",
                "Diagnosis",
            ],
        )
        gr.CheckboxGroup(
            label="Therapy and Results",
            choices=[
                "Pharmacological Therapy",
                "Interventional Therapy",
                "Patient Outcome Assessment",
            ],
        )
    with gr.Row():
        summarize_btn = gr.Button("Summarize")

    with gr.Row():
        output_text = gr.Textbox(
            label="Summary",
            autoscroll=False,
            lines=15,
            max_lines=15,
            show_copy_button=True,
        )
    with gr.Row():
        examples = gr.Examples(
            label="Examples",
            examples=[
"A 57-year-old male presented with fever (38.9Β°C), chest pain, cough, and progressive dyspnea. The patient exhibited tachypnea (34 breaths/min) and tachycardia (134 bpm). Auscultation revealed decreased breath sounds in both lung bases, with crackles on the left. A chest X-ray revealed bilateral pleural opacities and enlargement of the cardiac silhouette ( A). Echocardiography showed moderate pericardial effusion affecting the entire cardiac silhouette. Pericardiocentesis yielded 250 mL of exudative fluid. A CT scan of the chest showed pneumonia in the left lower lobe, bilateral pleural effusion, and moderate pericardial effusion ( B). Thoracentesis was performed and yielded 1,050 mL of exudative fluid. Laboratory tests yielded the following data: white blood cell count, 11.78 Γ— 109 cells/L (84.3% neutrophils, 4.3% lymphocytes, and 9.1% monocytes); platelet count, 512 Γ— 109/L; serum C-reactive protein, 31.27 mg/dL; serum creatinine, 0.94 mg/dL; serum sodium, 133 mEq/L; and serum potassium, 3.72 mEq/L. Examination of the pleural fluid showed a pH of 7.16, a glucose level of 4.5 mg/dL, proteins at 49.1 g/L, and an LDH content of 1,385 U/L. A urinary pneumococcal antigen test was positive. Pleural fluid culture was positive for S. pneumoniae. The patient was treated for four weeks with amoxicillin-clavulanate (2.2 g/8 h, i.v.) plus levofloxacin (500 mg twice a day), together with a nonsteroidal anti-inflammatory drug (ibuprofen, 800 mg/day), after which there was nearly complete resolution of the alterations seen on the chest X-ray and CT scan."
            ],
            fn=summarize,
            inputs=input_text,
            outputs=output_text,
            cache_examples="lazy",
        )
    summarize_btn.click(
        fn=summarize,
        inputs=input_text,
        outputs=output_text,
        show_progress=True,
    )

if __name__ == "__main__":
    demo.launch()