import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy
from backtrack_sampler.provider.transformers_provider import TransformersProvider
import torch
import spaces

description = """## Compare Creative Writing: Standard Sampler vs. Backtrack Sampler with Creative Writing Strategy
This is a demo of the [Backtrack Sampler](https://github.com/Mihaiii/backtrack_sampler) framework using "Creative Writing Strategy".
<br />On the left is the output of the standard sampler and on the right the output produced by Backtrack Sampler.
"""

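# Load the tokenizer and model once at startup; the same model instance backs both
# the standard generate() path and the Backtrack Sampler path.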
model_name = "unsloth/Llama-3.2-1B-Instruct"
device = torch.device('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

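# Wire the model into Backtrack Sampler: TransformersProvider adapts the transformers
# model/tokenizer for the framework, and CreativeWritingStrategy decides when the
# sampler should backtrack over recently generated tokens and resample them.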
provider = TransformersProvider(model, tokenizer, device)
strategy = CreativeWritingStrategy(provider)
creative_sampler = BacktrackSampler(strategy, provider)

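# Convert Gradio's (user, assistant) history tuples plus the new prompt into the
# message-dict format expected by tokenizer.apply_chat_template.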
def create_chat_template_messages(history, prompt):
    messages = []

    for input_text, response_text in history:
        messages.append({"role": "user", "content": input_text})
        messages.append({"role": "assistant", "content": response_text})

    messages.append({"role": "user", "content": prompt})
    return messages

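# Runs on the Space's GPU via @spaces.GPU: generate a reply to the prompt with both
# the standard sampler and the Backtrack Sampler.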
@spaces.GPU
def generate_responses(prompt, history):
    messages = create_chat_template_messages(history, prompt)
    wrapped_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer.encode(wrapped_prompt, add_special_tokens=False, return_tensors="pt").to(device)

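    # Backtrack Sampler path: collect the generated token ids, then decode them once.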
    generated_list = []
    for token in creative_sampler.generate(wrapped_prompt, max_length=2048, temperature=1):
        generated_list.append(token)
    custom_output = tokenizer.decode(generated_list, skip_special_tokens=True)

    # Standard transformers sampling path; do_sample=True so temperature=1 sampling is
    # actually applied instead of greedy decoding.
    standard_output = model.generate(inputs, max_length=2048, temperature=1.0, do_sample=True)
    standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)

    return standard_response.strip(), custom_output.strip()

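# Side-by-side UI: the same prompt is sent to both samplers so their outputs can be
# compared directly.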
with gr.Blocks(theme=gr.themes.Citrus()) as demo:
    gr.Markdown(description)

    with gr.Row():
        standard_chat = gr.Chatbot(label="Standard Sampler")
        custom_chat = gr.Chatbot(label="Creative Writing Strategy")

    with gr.Row():
        prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your message here...", lines=1)

    examples = [
        "Write me a short story about a talking dog who wants to be a detective.",
        "Tell me a short tale of a dragon who is afraid of heights.",
        "Create a short story where aliens land on Earth, but they just want to throw a party."
    ]

    gr.Examples(examples=examples, inputs=prompt_input)

    submit_button = gr.Button("Submit")

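    # Shared handler for the Submit button and the textbox's Enter key: generate both
    # responses, append them to their respective chat histories, and clear the input box.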
    def update_chat(prompt, standard_history, custom_history):
        standard_response, custom_response = generate_responses(prompt, standard_history)

        standard_history = standard_history + [(prompt, standard_response)]
        custom_history = custom_history + [(prompt, custom_response)]

        return standard_history, custom_history, ""

    prompt_input.submit(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
    submit_button.click(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])

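# Enable request queuing and launch the app.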
demo.queue().launch(debug=True)