File size: 3,649 Bytes
ef37daa
3d08dbc
 
83e20b0
e1ff28f
3d08dbc
 
83e20b0
3d08dbc
 
 
 
 
 
83e20b0
 
 
 
 
 
 
 
 
 
 
 
 
 
0ce6fc9
3d08dbc
 
 
 
 
 
 
 
 
 
fe44201
 
3d08dbc
fe44201
3d08dbc
fe44201
3d08dbc
b55e187
3d08dbc
 
 
 
 
 
e1ff28f
83e20b0
 
 
 
 
 
 
 
 
 
 
 
 
e1ff28f
83e20b0
 
 
e1ff28f
83e20b0
56d5550
 
 
 
 
83e20b0
 
 
 
 
 
 
56d5550
e0b816f
56d5550
3d08dbc
e1ff28f
3d08dbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe44201
3d08dbc
fe44201
83e20b0
3d08dbc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import random
import re
import time

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face checkpoint served by this demo.
model_name = "Qwen/Qwen2.5-3B-Instruct"

# The weights and the tokenizer are loaded once, at import time, and are then
# used as module-level globals by generate_response.
print("Loading model and tokenizer...")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Model and tokenizer loaded!")

def simulate_typing(text, min_chars_per_sec=20, max_chars_per_sec=60):
    """Yield progressively longer prefixes of *text* to imitate live typing.

    The text is revealed one word at a time. Each step is a slice of the
    original string, so newlines, indentation and repeated spaces (Markdown
    paragraphs, lists, code blocks) are preserved and the last value yielded
    is exactly ``text``.

    Args:
        text: The complete string to reveal.
        min_chars_per_sec: Lower bound of the random rate used for each pause.
        max_chars_per_sec: Upper bound of the random rate used for each pause.

    Yields:
        The part of ``text`` revealed so far, once per word. Nothing is
        yielded when ``text`` contains no non-whitespace characters.

    Note:
        Despite the parameter names, one pause of ``1 / rate`` seconds is taken
        per word, not per character.
    """
    # Each match is one word plus the whitespace adjacent to it. Consecutive
    # matches tile the input from index 0, so slicing up to the end of a match
    # gives the exact original prefix, whitespace included.
    for word in re.finditer(r"\s*\S+\s*", text):
        # Use the stdlib RNG so the animation does not consume torch's global
        # random state.
        rate = random.uniform(min_chars_per_sec, max_chars_per_sec)
        time.sleep(1 / rate)
        yield text[:word.end()]

def _history_to_messages(history):
    """Convert chat history into a flat list of role/content dicts."""
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Role/content style history: one dict per message.
            role = turn.get("role")
            content = turn.get("content")
            if role and content:
                messages.append({"role": role, "content": content})
            continue
        # Pair style history: one (user, assistant) entry per exchange; either
        # side may be empty and is then skipped.
        user_msg, assistant_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def generate_response(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate a reply with the module-level model and reveal it word by word.

    Args:
        message: The newest user message.
        history: Earlier turns of the conversation, either as
            (user, assistant) pairs or as dicts with ``role`` and ``content``
            keys.
        system_message: Content of the system message placed first in the
            prompt.
        max_tokens: Passed to ``model.generate`` as ``max_new_tokens``.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Yields:
        Growing prefixes of the reply, as produced by ``simulate_typing``.

    Note:
        The whole reply is generated before the first value is yielded; the
        incremental output is a cosmetic animation, not token streaming.
    """
    # Conversation order: system prompt, earlier turns, then the new message.
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    # Render the conversation as a single prompt string ending with the
    # generation prompt.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    with torch.inference_mode():
        model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        # The output starts with the prompt tokens; keep only what follows.
        prompt_len = model_inputs.input_ids.shape[1]
        response = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)

    # Hand the finished reply to the typing animation.
    yield from simulate_typing(response)

# Stylesheet handed to the chat interface: loads the Inter web font, applies it
# to the page, and defines a blinking "|" cursor for the `.typing-cursor` class.
# NOTE(review): nothing in this file assigns the `typing-cursor` class to an
# element, so the cursor rule may be unused - confirm.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
body, .gradio-container {
    font-family: 'Inter', sans-serif;
}
.typing-cursor::after {
    content: '|';
    animation: blink 1s step-start infinite;
}
@keyframes blink {
    50% { opacity: 0; }
}
"""

# Default system prompt; used as the initial value of the hidden textbox in the
# chat interface's additional inputs.
system_message = """You are Qwen, created by Alibaba Cloud. You are a helpful assistant."""

# Chat UI wired to generate_response. The additional inputs are listed in the
# same order as the trailing parameters of generate_response:
# system_message, max_tokens, temperature, top_p.
demo = gr.ChatInterface(
    generate_response,
    additional_inputs=[
        # The system prompt is kept in a hidden textbox.
        gr.Textbox(value=system_message, visible=False),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    css=custom_css,
)

# Start the web app only when this file is run as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    demo.queue()  # Enable request queuing before the app starts serving
    demo.launch()  # Start the Gradio app with its default launch settings