File size: 2,831 Bytes
51cfed5
776e30f
59e7020
776e30f
51cfed5
776e30f
59e7020
 
 
776e30f
 
59e7020
776e30f
 
 
 
 
51cfed5
1e560c3
 
26ba426
 
1e560c3
51cfed5
776e30f
 
 
 
 
51cfed5
 
776e30f
51cfed5
776e30f
59e7020
 
 
 
 
776e30f
51cfed5
1e560c3
776e30f
 
 
1e560c3
776e30f
 
59e7020
776e30f
 
 
59e7020
776e30f
 
 
 
 
 
 
1e560c3
26ba426
776e30f
 
26ba426
 
776e30f
 
 
26ba426
51cfed5
776e30f
51cfed5
 
 
d4a4ec2
26ba426
9b4871f
776e30f
51cfed5
 
 
 
26ba426
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread

# Load model and tokenizer
model_name = "GoofyLM/gonzalez-v1"

# device_map="auto" may place the model on the CPU when no GPU is present.
# Half precision is only a sensible default on CUDA: on a CPU-only host
# float16 weights are slow and, on older PyTorch builds, some ops are not
# implemented for Half at all. Fall back to float32 there.
model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=model_dtype,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set pad token if missing: respond() passes tokenizer.pad_token_id to
# generate(), so reuse the EOS token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Define a custom chat template if the tokenizer does not ship one.
if tokenizer.chat_template is None:
    # Minimal role-tag template: every system/user/assistant message is
    # rendered as a "<|role|>" marker line followed by its content, and a
    # trailing "<|assistant|>" marker is appended when add_generation_prompt
    # is set. Messages with any other role are not rendered.
    tokenizer.chat_template = (
        "{% for message in messages %}\n"
        "{% if message['role'] == 'system' %}<|system|>\n{{ message['content'] }}\n"
        "{% elif message['role'] == 'user' %}<|user|>\n{{ message['content'] }}\n"
        "{% elif message['role'] == 'assistant' %}<|assistant|>\n{{ message['content'] }}\n"
        "{% endif %}\n"
        "{% endfor %}\n"
        "{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
    )

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream the model's reply to ``message`` for the chat UI.

    Args:
        message: The newest user message.
        history: Earlier turns of the conversation. Entries may be
            ``(user, assistant)`` pairs or ``{"role": ..., "content": ...}``
            dicts; empty parts and dict entries whose content is not a
            string are skipped.
        system_message: Optional system prompt; ignored when blank.
        max_tokens: Upper bound on newly generated tokens (coerced to int).
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding instead of sampling.
        top_p: Nucleus-sampling threshold, only used when sampling.

    Yields:
        The reply accumulated so far, once per streamed text chunk.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has ended.
    """
    # Build conversation messages
    messages = []

    # A blank system prompt carries no instruction; sending it would still
    # render an empty system turn (and some chat templates reject the system
    # role entirely), so only include it when there is real content.
    if system_message and system_message.strip():
        messages.append({"role": "system", "content": system_message})

    for turn in history or []:
        if isinstance(turn, dict):
            # "messages"-style history entry.
            role = turn.get("role")
            content = turn.get("content")
            if role and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue

        # Pair-style history entry: (user text, assistant text).
        user_msg, assistant_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": message})

    # Format prompt using chat template
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Set up streaming
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Configure generation parameters. Sampling requires a strictly positive
    # temperature, so the decision depends on temperature alone; top_p must
    # not force sampling on when temperature is 0.
    do_sample = temperature > 0
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
    )
    if do_sample:
        # Only pass sampling knobs when they are actually used.
        generation_kwargs.update(temperature=temperature, top_p=top_p)

    worker_errors = []

    def run_generation():
        # If generate() fails it never signals end-of-stream, which would
        # leave the consumer loop below blocked forever. Record the error and
        # close the stream so the loop terminates.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            streamer.end()

    # Start generation in separate thread
    thread = Thread(target=run_generation, daemon=True)
    thread.start()

    # Stream response
    response = ""
    for token in streamer:
        response += token
        yield response

    # The stream has ended, so the worker is finished or about to finish.
    thread.join()
    if worker_errors:
        raise worker_errors[0]

# Extra controls passed to respond() after (message, history), in this order:
# system_message, max_tokens, temperature, top_p.
generation_controls = [
    gr.Textbox(label="System message", value=""),
    gr.Slider(minimum=1, maximum=215, value=72, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p (nucleus sampling)"),
]

# Chat UI wired to the streaming respond() generator.
demo = gr.ChatInterface(fn=respond, additional_inputs=generation_controls)

# Only start the web server when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()