File size: 3,819 Bytes
d84bb55
cbd44c9
eec5410
6b4f26c
 
78f48b0
00f162e
6b4f26c
78f48b0
6b4f26c
eec5410
7a9d8b4
 
 
00f162e
d84bb55
00f162e
a6e0ce3
78f48b0
 
eec5410
 
78f48b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b4f26c
78f48b0
 
6b4f26c
 
78f48b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b4f26c
 
 
 
 
 
 
 
 
 
 
78f48b0
 
 
 
 
 
 
 
 
 
 
6b4f26c
78f48b0
 
 
 
6b4f26c
78f48b0
 
6b4f26c
a6e0ce3
6b4f26c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eec5410
 
a6e0ce3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import spaces
import gradio as gr

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login as hf_login

import xgrammar as xgr
from pydantic import BaseModel

# NOTE(review): vLLM is never imported in this file, so these two settings
# look like leftovers from an earlier vLLM-based version — confirm before
# removing them.
os.environ["VLLM_LOGGING_LEVEL"]="DEBUG"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"]="spawn"

# Authenticate against the Hugging Face Hub (required for the gated Llama
# checkpoint below). Assumes HF_TOKEN is set in the environment, e.g. as a
# Space secret; hf_login receives None and fails if it is missing.
hf_login(token=os.getenv("HF_TOKEN"))

# Instruction-tuned checkpoint used for grammar-constrained extraction.
model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Load the causal LM once at import time. device_map="auto" delegates weight
# placement to accelerate (CPU at startup on Spaces; moved to GPU inside
# @spaces.GPU-decorated calls).
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float32, device_map="auto"
)

class Person(BaseModel):
    """Target schema for structured extraction from a clinical report.

    Each field is free text. Per the system prompt, fields with no relevant
    facts are filled with "N/A" and multiple facts are joined with "; ".
    Field names and their order define the JSON object the grammar enforces.
    """

    life_style: str
    family_history: str
    social_history: str
    medical_surgical_history: str
    signs_symptoms: str
    comorbidities: str
    diagnostic_techniques_procedures: str
    diagnosis: str
    laboratory_values: str
    pathology: str
    pharmacological_therapy: str
    interventional_therapy: str
    patient_outcome_assessment: str
    age: str
    gender: str
    
# Tokenizer and config for the same checkpoint as the model above.
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# xgrammar needs the model's full vocab size from the config (it can differ
# from the tokenizer's visible vocabulary for padded embeddings).
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
    tokenizer, vocab_size=config.vocab_size
)

# Compile the Person schema into a grammar once at startup; the resulting
# logits processor constrains generation to JSON matching that schema.
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_json_schema(Person)
xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)

# System prompt for the extraction task. The "schema" shown to the model is an
# illustrative key/value template (not formal JSON Schema); actual JSON
# validity is enforced separately by the xgrammar logits processor.
# Fix: "multile" -> "multiple" in the instructions.
prompt = """You are a text extraction system for clinical reports.
Please extract relevant clinical information from the report.

### Instructions

- Use the JSON Schema given below.
- Return only a valid JSON object – no markdown, no comments.
- If no relevant facts are given for a field, set its value to "N/A".
- If multiple relevant facts are given for a field, separate them with "; ".

### JSON Schema

{
  'life_style': '',
  'family_history': '',
  'social_history': '',
  'medical_surgical_history': '',
  'signs_symptoms': '',
  'comorbidities': '',
  'diagnostic_techniques_procedures': '',
  'diagnosis': '',
  'laboratory_values': '',
  'pathology': '',
  'pharmacological_therapy': '',
  'interventional_therapy': '',
  'patient_outcome_assessment': '',
  'age': '',
  'gender': '',
}

### Clinical Report
"""

@spaces.GPU(duration=60)
def summarize(text):
    """Extract structured clinical facts from *text*, returning a JSON string.

    Builds a chat conversation (extraction system prompt + user report),
    generates with the grammar-constrained logits processor so the output is
    valid JSON for the Person schema, and returns only the generated text.

    Args:
        text: The raw clinical report pasted by the user.

    Returns:
        The model's JSON response, or a usage hint for empty input.
    """
    if not text.strip():
        return "Please enter some text to summarize."

    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": text},
    ]

    # BUG FIX: the original tokenized the raw `text` directly, so `messages`
    # (and with it the entire system prompt) was built but silently ignored.
    # Apply the model's chat template so the instruction-tuned model actually
    # sees the extraction instructions.
    chat_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Use the model's own device instead of hard-coding "cuda" so the code
    # also works when device_map="auto" left the weights on CPU.
    model_inputs = tokenizer([chat_text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        max_new_tokens=2048,
        logits_processor=[xgr_logits_processor],
    )

    # Strip the prompt tokens, keeping only the newly generated completion.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return response[0]

# Gradio UI. The original header/labels ("Summarization for News, SciTLDR and
# Dialog Texts") contradicted what the app actually does — grammar-constrained
# JSON extraction from clinical reports — and looked copied from another demo.
with gr.Blocks() as demo:
    gr.Markdown("## 📝 Structured Extraction from Clinical Reports")

    with gr.Row():
        input_text = gr.Textbox(
            label="Clinical Report",
            autoscroll=False,
            lines=15,
            max_lines=15,
            placeholder="Paste the clinical report here...",
        )
        output_text = gr.Textbox(
            label="Extracted JSON",
            autoscroll=False,
            lines=15,
            max_lines=15,
            show_copy_button=True,
        )

    with gr.Row():
        extract_btn = gr.Button("Extract")
        extract_btn.click(
            fn=summarize,
            inputs=input_text,
            outputs=output_text,
            show_progress=True,
        )

# Launch the app when run as a script (HF Spaces executes this entry point).
if __name__ == "__main__":
    demo.launch()