File size: 2,369 Bytes
d84bb55
cbd44c9
eec5410
6b4f26c
 
 
00f162e
6b4f26c
00f162e
6b4f26c
eec5410
7a9d8b4
 
 
00f162e
d84bb55
00f162e
a6e0ce3
00f162e
 
 
 
 
eec5410
 
6b4f26c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6e0ce3
6b4f26c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eec5410
 
a6e0ce3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import spaces
import gradio as gr

import torch
from transformers import AutoTokenizer
from huggingface_hub import login as hf_login

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from pydantic import BaseModel

# vLLM runtime switches, exported through the process environment.
# NOTE(review): vLLM has already been imported by this point; confirm that
# VLLM_LOGGING_LEVEL is still honoured when it is set after the import.
os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

# Authenticate with the Hugging Face Hub only when a token is configured.
# Passing token=None makes login() fall back to an interactive prompt, which
# cannot be answered in a headless deployment.
if os.getenv("HF_TOKEN"):
    hf_login(token=os.environ["HF_TOKEN"])

# Hugging Face Hub id of the checkpoint served by this app; the tokenizer is
# loaded from the same id further down.
model_name = "meta-llama/Llama-3.2-1B-Instruct"

# vLLM engine, created once at import time and reused by every request.
model = LLM(
    enforce_eager=True,  # presumably skips CUDA graph capture; see vLLM docs
    trust_remote_code=True,
    dtype=torch.bfloat16,
    model=model_name,
)

# Pydantic model whose JSON schema is passed to guided decoding below.
# NOTE(review): documented with comments rather than a docstring on purpose;
# pydantic appears to copy a class docstring into the generated JSON schema
# as "description", which would change the schema. Confirm before adding one.
class Info(BaseModel):
    name: str  # string field of the schema
    age: int  # integer field of the schema

# JSON schema derived from the Info model.
json_schema = Info.model_json_schema()

# Guided-decoding settings built from that schema.
# GuidedDecodingParams and SamplingParams come from vLLM and must be imported
# at the top of the file; without those imports this line raises NameError.
guided_decoding_params = GuidedDecodingParams(json=json_schema)

# Sampling configuration passed to model.generate() in summarize().
sampling_params = SamplingParams(
    temperature=0.1,
    max_tokens=2048,
    guided_decoding=guided_decoding_params,
)

# System message placed ahead of the user's text in every chat request.
prompt = "You are a helpful assistant."

# Tokenizer of the same checkpoint; summarize() uses it to render the chat
# template into a plain prompt string.
tokenizer = AutoTokenizer.from_pretrained(
    model_name, trust_remote_code=True, padding_side="right"
)

# Register a "<pad>" token when the checkpoint does not define one.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens(dict(pad_token="<pad>"))

@spaces.GPU(duration=60)
def summarize(text):
    """Generate a model response for ``text`` and return it as a string.

    The input is wrapped in a two-message chat (the module-level system
    ``prompt`` followed by the user text), rendered with the tokenizer's chat
    template, and sent to the vLLM engine with the module-level
    ``sampling_params``.

    Args:
        text: Raw text from the input textbox. ``None``, empty and
            whitespace-only values are all treated as missing input.

    Returns:
        The text of the first completion of the first request, or a short
        hint asking for input when nothing usable was supplied.
    """
    # Guard against None as well as blank strings: calling .strip() on None
    # would raise AttributeError instead of showing the hint.
    if not text or not text.strip():
        return "Please enter some text to summarize."

    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": text},
    ]

    # Render to a string (tokenize=False) because the engine is given raw
    # prompts; add_generation_prompt appends the assistant turn header.
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        # NOTE(review): extra keyword handed to the template; presumably
        # ignored by templates that do not use it. Confirm for this model.
        enable_thinking=False,
    )

    # One prompt in, so one request output with (at least) one completion.
    outputs = model.generate([input_text], sampling_params)
    return outputs[0].outputs[0].text

# Gradio UI: an input box and an output box side by side, with a button
# underneath that runs summarize() on the input.
with gr.Blocks() as demo:
    # The heading emoji had been stored as mis-decoded bytes; restored here.
    gr.Markdown("## 📝 Summarization for News, SciTLDR and Dialog Texts")

    with gr.Row():
        input_text = gr.Textbox(
            label="Input Text",
            autoscroll=False,
            lines=15,
            max_lines=15,
            placeholder="Paste your article or paragraph here...",
        )
        output_text = gr.Textbox(
            label="Summary",
            autoscroll=False,
            lines=15,
            max_lines=15,
            show_copy_button=True,
        )

    with gr.Row():
        summarize_btn = gr.Button("Summarize")
        # Wire the button: textbox value in, generated text out.
        summarize_btn.click(
            fn=summarize,
            inputs=input_text,
            outputs=output_text,
            show_progress=True,
        )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()