File size: 3,617 Bytes
cd5a409
769181d
 
 
 
cd5a409
769181d
 
 
 
 
 
 
 
 
 
0f4223a
769181d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import pprint
import subprocess
from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


# Print the host CPU description once at startup so it shows up in the logs.
# `lscpu` is an external program; if it cannot be executed (for example, it is
# not installed on this host) the report is skipped instead of aborting the app.
try:
    result = subprocess.run(["lscpu"], text=True, capture_output=True)
except OSError:
    result = None
    print("lscpu is not available; skipping CPU report.")
else:
    pprint.pprint(result.stdout)


# Hub identifier of the checkpoint served by this demo.
checkpoint = "suriya7/Gemma-2b-SFT"
# Tokenizer and model are loaded once at import time and shared by all requests.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# The checkpoint name indicates a Gemma model, which is decoder-only, so it is
# loaded through the causal-LM auto class rather than the seq2seq one.
model = AutoModelForCausalLM.from_pretrained(checkpoint)


def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
    """Stream a sampled model response to ``user_text``.

    The user text is placed in the "Input" slot of an Alpaca-style prompt,
    tokenized, and handed to ``model.generate`` running on a background
    thread. Decoded chunks are read from a ``TextIteratorStreamer`` and the
    accumulated text is yielded after every chunk.

    Args:
        user_text: Text entered by the user; inserted into the prompt.
        top_p: Nucleus-sampling value forwarded to ``generate``.
        temperature: Sampling temperature; converted to ``float``.
        top_k: Top-k sampling value; converted to ``int``.
        max_new_tokens: Maximum number of tokens to generate; converted to
            ``int``.

    Yields:
        The response text accumulated so far, growing with each chunk.

    Returns:
        The complete response text (as the generator's return value).
    """
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

    prompt = alpaca_prompt.format(
        "You are an AI assistant. Please ensure that the answers conclude with an end-of-sequence (EOS) token.",  # instruction
        user_text,  # input goes here
        "",  # output - leave this blank for generation!
    )
    inputs = tokenizer([prompt], return_tensors="pt")

    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
    # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        # Slider values may arrive as floats; these two must be integers.
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        top_p=top_p,
        temperature=float(temperature),
        top_k=int(top_k),
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer, and update the model output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    return model_output

def reset_textbox():
    """Build a Gradio update that sets a component's value to an empty string."""
    cleared = gr.update(value="")
    return cleared


# UI layout: a wide column with the prompt box, output box and submit button,
# and a narrow column with the sampling controls.
with gr.Blocks() as demo:

    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                label="User input",
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")

        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=1,
                maximum=1000,
                value=250,
                step=1,
                interactive=True,
                label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05,
                maximum=1.0,
                value=0.95,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=5.0,
                value=0.8,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

    # Pressing Enter in the prompt box and clicking Submit both run the same
    # generation call; the order of inputs matches run_generation's parameters.
    generation_inputs = [user_text, top_p, temperature, top_k, max_new_tokens]
    user_text.submit(run_generation, generation_inputs, model_output)
    button_submit.click(run_generation, generation_inputs, model_output)

# Launch after the Blocks context has closed. Queueing is enabled through
# queue(), so launch() needs no separate queue flag.
demo.queue(max_size=32).launch(server_name="0.0.0.0")
# For local use without a bounded queue:
# demo.launch(server_name="0.0.0.0")