import os

import spaces
import gradio as gr

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login as hf_login

import xgrammar as xgr
from pydantic import BaseModel
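
# vLLM debug settings; generation below goes through transformers, so these
# only take effect if a vLLM backend is swapped in.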
os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

# Authenticate to the Hugging Face Hub; the Llama 3.2 weights are gated.
hf_login(token=os.getenv("HF_TOKEN"))
model_name = "meta-llama/Llama-3.2-1B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float32, device_map="auto"
)
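
# Flat extraction schema: every field is a plain string so the model can
# answer "N/A" or join multiple findings with "; " (see the prompt below).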
class Person(BaseModel):
    life_style: str
    family_history: str
    social_history: str
    medical_surgical_history: str
    signs_symptoms: str
    comorbidities: str
    diagnostic_techniques_procedures: str
    diagnosis: str
    laboratory_values: str
    pathology: str
    pharmacological_therapy: str
    interventional_therapy: str
    patient_outcome_assessment: str
    age: str
    gender: str

config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
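
# Constrained decoding setup: compile the Pydantic schema into an xgrammar
# grammar, then wrap it as a logits processor so generate() can only emit
# JSON that validates against Person.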
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
    tokenizer, vocab_size=config.vocab_size
)
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_json_schema(Person)
xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)

prompt = """You are a text extraction system for clinical reports.
Please extract relevant clinical information from the report.

### Instructions

- Use the JSON Schema given below.
- Return only a valid JSON object - no markdown, no comments.
- If no relevant facts are given for a field, set its value to "N/A".
- If multiple relevant facts are given for a field, separate them with "; ".

### JSON Schema

{
    "life_style": "",
    "family_history": "",
    "social_history": "",
    "medical_surgical_history": "",
    "signs_symptoms": "",
    "comorbidities": "",
    "diagnostic_techniques_procedures": "",
    "diagnosis": "",
    "laboratory_values": "",
    "pathology": "",
    "pharmacological_therapy": "",
    "interventional_therapy": "",
    "patient_outcome_assessment": "",
    "age": "",
    "gender": ""
}

### Clinical Report
"""
@spaces.GPU(duration=60)
def summarize(text):
    if not text.strip():
        return "Please enter some text to summarize."

    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": text},
    ]

    # Render the chat template so the extraction instructions in the system
    # prompt actually reach the model, not just the raw report text.
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The template already includes the BOS token, so don't add it again.
    model_inputs = tokenizer(
        [input_text], return_tensors="pt", add_special_tokens=False
    ).to("cuda")

    generated_ids = model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        max_new_tokens=2048,
        logits_processor=[xgr_logits_processor],
    )
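
    # Keep only the newly generated tokens; the prompt tokens are echoed at
    # the start of each output sequence.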
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return response[0]


with gr.Blocks() as demo:
    gr.Markdown("## Structured Information Extraction for Clinical Reports")

    with gr.Row():
        input_text = gr.Textbox(
            label="Input Text",
            autoscroll=False,
            lines=15,
            max_lines=15,
            placeholder="Paste your clinical report here...",
        )
        output_text = gr.Textbox(
            label="Extracted Information",
            autoscroll=False,
            lines=15,
            max_lines=15,
            show_copy_button=True,
        )

    with gr.Row():
        summarize_btn = gr.Button("Summarize")
        summarize_btn.click(
            fn=summarize,
            inputs=input_text,
            outputs=output_text,
            show_progress=True,
        )


if __name__ == "__main__":
    demo.launch()
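
# Local smoke test (assumptions: this file is saved as app.py, HF_TOKEN is
# set, and a CUDA device is available):
#   python app.py
# then open the printed local URL and paste in a de-identified report.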