import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize model and tokenizer
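# torch_dtype="auto" loads the weights in the checkpoint's native precision;
# device_map="auto" (provided by accelerate) places layers on available devices.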
model_name = "Qwen/Qwen2.5-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
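# Rough sizing note: 3B parameters in bf16 is about 6 GB of weights alone
# (3e9 params x 2 bytes), before activations and the KV cache.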

def generate_response(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
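    """Build the conversation into a chat-template prompt and generate a reply.

    `history` uses Gradio's tuple format: a list of
    (user_message, assistant_message) pairs.
    """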
    # Prepare conversation history
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    
    messages.append({"role": "user", "content": message})

    # Convert messages to model input format
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Prepare model inputs
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Generate response
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True
    )

    # Extract generated text
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    # Yielding instead of returning makes this a generator; gr.ChatInterface
    # supports generators for streaming, though here the full reply is yielded
    # in a single chunk.
    yield response
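
# A minimal sketch of true token-by-token streaming with transformers'
# TextIteratorStreamer. It is not wired into the UI below and simply reuses
# the model/tokenizer loaded above; treat it as illustrative, not canonical.
from threading import Thread
from transformers import TextIteratorStreamer

def generate_response_streaming(message, history, system_message, max_tokens, temperature, top_p):
    # Same prompt construction as generate_response above
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # skip_prompt=True keeps the input tokens out of the streamed text
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks, so run it on a worker thread and consume the streamer here
    generation_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial  # ChatInterface re-renders the growing reply on each yield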

# Custom CSS for the Gradio interface
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
body, .gradio-container {
    font-family: 'Inter', sans-serif;
}
"""

# System message
system_message = """You are Qwen, created by Alibaba Cloud. You are a helpful assistant."""

# Gradio chat interface
demo = gr.ChatInterface(
    generate_response,
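    # Values from additional_inputs are passed to generate_response after
    # (message, history), in the order listed here.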
    additional_inputs=[
        gr.Textbox(
            value=system_message,
            visible=False,
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ],
    css=custom_css
)

# Launch the demo
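# demo.launch(share=True) would also create a temporary public share link.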
if __name__ == "__main__":
    demo.launch()