"""Gradio Space: structured information extraction from clinical reports.

Loads meta-llama/Llama-3.2-1B-Instruct, applies a PEFT adapter (Hub name
suggests it was fine-tuned for medical report summarization), constrains
decoding with xgrammar to the ``Person`` JSON schema, and serves a small
Gradio UI around ``summarize``.
"""

import os
# Keep `spaces` imported before torch: ZeroGPU needs to hook CUDA setup
# before torch initializes it.
import spaces
import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login as hf_login
import xgrammar as xgr
from pydantic import BaseModel

# Only log in when a token is configured; login(token=None) falls back to an
# interactive prompt, which nobody can answer inside a Space.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    hf_login(token=hf_token)

model_name = "meta-llama/Llama-3.2-1B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    trust_remote_code=True,
)

checkpoint = "gregorlied/Llama-3.2-1B-Instruct-Medical-Report-Summarization"
model = PeftModel.from_pretrained(model, checkpoint)


# Output schema enforced by xgrammar during decoding. Every field is a free
# string; the prompt asks for "N/A" when nothing relevant is found.
# (Documented with comments rather than a docstring so the generated
# Pydantic JSON schema stays unchanged.)
class Person(BaseModel):
    life_style: str
    family_history: str
    social_history: str
    medical_surgical_history: str
    signs_symptoms: str
    comorbidities: str
    diagnostic_techniques_procedures: str
    diagnosis: str
    laboratory_values: str
    pathology: str
    pharmacological_therapy: str
    interventional_therapy: str
    patient_outcome_assessment: str
    age: str
    gender: str


config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Compile the JSON-schema grammar once; it is reusable across requests.
# (The logits processor built from it is NOT reusable - see summarize().)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
    tokenizer, vocab_size=config.vocab_size
)
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_json_schema(Person)

# NOTE(review): the "multile" typo and the non-JSON (single-quoted, trailing
# comma) schema sketch are kept verbatim in case the adapter was trained on
# this exact prompt - confirm before editing the wording.
prompt = """You are a text extraction system for clinical reports.
Please extract relevant clinical information from the report.

### Instructions

- Use the JSON Schema given below.
- Return only a valid JSON object – no markdown, no comments.
- If no relevant facts are given for a field, set its value to "N/A".
- If multile relevant facts are given for a field, separate them with "; ".

### JSON Schema

{
  'life_style': '',
  'family_history': '',
  'social_history': '',
  'medical_surgical_history': '',
  'signs_symptoms': '',
  'comorbidities': '',
  'diagnostic_techniques_procedures': '',
  'diagnosis': '',
  'laboratory_values': '',
  'pathology': '',
  'pharmacological_therapy': '',
  'interventional_therapy': '',
  'patient_outcome_assessment': '',
  'age': '',
  'gender': '',
}

### Clinical Report
"""


@spaces.GPU(duration=60)
def summarize(text):
    """Extract structured clinical fields from a free-text report.

    Args:
        text: The clinical report pasted by the user.

    Returns:
        The generated JSON object (decoding constrained to the ``Person``
        schema) as a string, or a short hint message when ``text`` is empty
        or whitespace-only.
    """
    if not text.strip():
        return "Please enter some text to summarize."

    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": text},
    ]
    # Render the chat template so the model actually sees the system prompt
    # and the assistant header. Llama 3 chat templates already emit the BOS
    # token, so the tokenizer must not add special tokens a second time.
    chat_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer(
        [chat_text], return_tensors="pt", add_special_tokens=False
    ).to(device)

    # The xgrammar HF logits processor keeps per-generation matcher state and
    # is single-use, so build a fresh one for every generate() call.
    logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)

    generated_ids = model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        max_new_tokens=2048,
        logits_processor=[logits_processor],
    )

    # Keep only the newly generated tokens (drop the echoed prompt).
    prompt_length = model_inputs["input_ids"].shape[1]
    new_tokens = generated_ids[:, prompt_length:]
    # Single-item batch: return the string itself, not a one-element list.
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]


with gr.Blocks() as demo:
    gr.Markdown("## 📝 Structured Information Extraction for Clinical Reports")

    with gr.Row():
        input_text = gr.Textbox(
            label="Clinical Report",
            autoscroll=False,
            lines=15,
            max_lines=15,
            placeholder="Paste your clinical report here...",
        )

    with gr.Row():
        # NOTE(review): this selection is not connected to summarize(); the
        # output always contains every schema field.
        gr.CheckboxGroup(
            label="Extracted Fields",
            info="Fields covered by the extraction schema.",
            choices=[
                "Life Style",
                "Family History",
                "Social History",
                "Medical Surgical History",
                "Signs and Symptoms",
                "Comorbidities",
                "Diagnostic Techniques and Procedures",
                "Diagnosis",
                "Laboratory Values",
                "Pathology",
                "Pharmacological Therapy",
                "Interventional Therapy",
                "Patient Outcome Assessment",
                "Age",
                "Gender",
            ],
        )

    with gr.Row():
        summarize_btn = gr.Button("Summarize")

    with gr.Row():
        output_text = gr.Textbox(
            label="Summary",
            autoscroll=False,
            lines=15,
            max_lines=15,
            show_copy_button=True,
        )

    with gr.Row():
        examples = gr.Examples(
            label="Examples",
            examples=[
                "A 57-year-old male presented with fever (38.9°C), chest pain, cough, and progressive dyspnea. The patient exhibited tachypnea (34 breaths/min) and tachycardia (134 bpm). Auscultation revealed decreased breath sounds in both lung bases, with crackles on the left. A chest X-ray revealed bilateral pleural opacities and enlargement of the cardiac silhouette ( A). Echocardiography showed moderate pericardial effusion affecting the entire cardiac silhouette. Pericardiocentesis yielded 250 mL of exudative fluid. A CT scan of the chest showed pneumonia in the left lower lobe, bilateral pleural effusion, and moderate pericardial effusion ( B). Thoracentesis was performed and yielded 1,050 mL of exudative fluid. Laboratory tests yielded the following data: white blood cell count, 11.78 × 109 cells/L (84.3% neutrophils, 4.3% lymphocytes, and 9.1% monocytes); platelet count, 512 × 109/L; serum C-reactive protein, 31.27 mg/dL; serum creatinine, 0.94 mg/dL; serum sodium, 133 mEq/L; and serum potassium, 3.72 mEq/L. Examination of the pleural fluid showed a pH of 7.16, a glucose level of 4.5 mg/dL, proteins at 49.1 g/L, and an LDH content of 1,385 U/L. A urinary pneumococcal antigen test was positive. Pleural fluid culture was positive for S. pneumoniae. The patient was treated for four weeks with amoxicillin-clavulanate (2.2 g/8 h, i.v.) plus levofloxacin (500 mg twice a day), together with a nonsteroidal anti-inflammatory drug (ibuprofen, 800 mg/day), after which there was nearly complete resolution of the alterations seen on the chest X-ray and CT scan."
            ],
            fn=summarize,
            inputs=input_text,
            outputs=output_text,
            cache_examples="lazy",
        )

    summarize_btn.click(
        fn=summarize,
        inputs=input_text,
        outputs=output_text,
        show_progress=True,
    )


if __name__ == "__main__":
    demo.launch()