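"""Gradio Space: structured clinical information extraction.

Llama-3.2-1B-Instruct fills a fixed JSON schema from a pasted clinical
report; xgrammar constrains decoding so the output is always valid JSON.
"""
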
import os
import spaces
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login as hf_login
import xgrammar as xgr
from pydantic import BaseModel
os.environ["VLLM_LOGGING_LEVEL"]="DEBUG"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"]="spawn"
hf_login(token=os.getenv("HF_TOKEN"))

model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Load the model weights once at process start.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float32, device_map="auto"
)

class Person(BaseModel):
    """Target schema: one free-text field per clinical category to extract."""

    life_style: str
    family_history: str
    social_history: str
    medical_surgical_history: str
    signs_symptoms: str
    comorbidities: str
    diagnostic_techniques_procedures: str
    diagnosis: str
    laboratory_values: str
    pathology: str
    pharmacological_therapy: str
    interventional_therapy: str
    patient_outcome_assessment: str
    age: str
    gender: str
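
# Tokenizer metadata + grammar compilation: the Person schema becomes an
# xgrammar grammar that constrains decoding to valid JSON.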
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
    tokenizer, vocab_size=config.vocab_size
)
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_json_schema(Person)
prompt = """You are a text extraction system for clinical reports.
Please extract relevant clinical information from the report.
### Instructions
- Use the JSON Schema given below.
- Return only a valid JSON object – no markdown, no comments.
- If no relevant facts are given for a field, set its value to "N/A".
- If multile relevant facts are given for a field, separate them with "; ".
### JSON Schema
{
'life_style': '',
'family_history': '',
'social_history': '',
'medical_surgical_history': '',
'signs_symptoms': '',
'comorbidities': '',
'diagnostic_techniques_procedures': '',
'diagnosis': '',
'laboratory_values': '',
'pathology': '',
'pharmacological_therapy': '',
'interventional_therapy': '',
'patient_outcome_assessment': '',
'age': '',
'gender': '',
}
### Clinical Report
"""

@spaces.GPU(duration=60)
def summarize(text):
    if not text.strip():
        return "Please paste a clinical report to extract from."
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": text},
    ]
    # Render the chat template so the system prompt is actually part of the
    # model input (tokenizing the raw text alone would drop it).
    chat_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([chat_text], return_tensors="pt").to("cuda")
    # Build a fresh logits processor per request: the contrib processor keeps
    # per-generation matcher state, so it is safest not to reuse one instance
    # across generate calls.
    xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
    # Greedy decoding (no sampling) keeps the structured output deterministic.
    generated_ids = model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        max_new_tokens=2048,
        logits_processor=[xgr_logits_processor],
    )
    # Keep only the newly generated tokens, dropping the prompt.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return response[0]

with gr.Blocks() as demo:
    gr.Markdown("## 📝 Clinical Report Information Extraction")
    with gr.Row():
        input_text = gr.Textbox(
            label="Clinical Report",
            autoscroll=False,
            lines=15,
            max_lines=15,
            placeholder="Paste your clinical report here...",
        )
        output_text = gr.Textbox(
            label="Extracted JSON",
            autoscroll=False,
            lines=15,
            max_lines=15,
            show_copy_button=True,
        )
    with gr.Row():
        summarize_btn = gr.Button("Extract")
    summarize_btn.click(
        fn=summarize,
        inputs=input_text,
        outputs=output_text,
        show_progress=True,
    )

if __name__ == "__main__":
    demo.launch()