Spaces:
Sleeping
Sleeping
File size: 2,708 Bytes
ef37daa 3d08dbc e1ff28f 3d08dbc 0ce6fc9 3d08dbc fe44201 3d08dbc fe44201 3d08dbc fe44201 3d08dbc b55e187 3d08dbc e1ff28f 3d08dbc e1ff28f 3d08dbc e1ff28f 3d08dbc e1ff28f 56d5550 e0b816f 56d5550 3d08dbc e1ff28f 3d08dbc fe44201 3d08dbc fe44201 3d08dbc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Checkpoint served by this demo (Hugging Face Hub id).
model_name = "Qwen/Qwen2.5-3B-Instruct"

# Load the causal LM once at import time; dtype and device placement are
# delegated to the library via the "auto" settings.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

# Tokenizer matching the checkpoint above.
tokenizer = AutoTokenizer.from_pretrained(model_name)
def _content_to_text(content):
    """Return the plain-text portion of a chat message's ``content`` value.

    A string is returned unchanged. A list is treated as a sequence of content
    blocks and the string ``"text"`` entries of its dict blocks are joined with
    newlines. Anything else yields an empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # NOTE(review): newer Gradio releases appear to deliver content as a
        # list of {"type": "text", "text": ...} blocks -- confirm against the
        # installed version.
        return "\n".join(
            block["text"]
            for block in content
            if isinstance(block, dict) and isinstance(block.get("text"), str)
        )
    return ""


def _history_to_messages(history):
    """Convert chat history into a list of ``{"role", "content"}`` dicts.

    Accepts both ``(user, assistant)`` pairs and entries that are already
    role/content dicts. Empty turns are skipped, and ``None`` history is
    treated as empty.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Messages-format entry: forward user/assistant text only.
            role = turn.get("role")
            text = _content_to_text(turn.get("content"))
            if role in ("user", "assistant") and text:
                messages.append({"role": role, "content": text})
            continue
        # Pair-format entry: (user message, assistant reply).
        user_msg, assistant_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def generate_response(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate one assistant reply for the chat UI.

    Args:
        message: The latest user message.
        history: Earlier turns, either as ``(user, assistant)`` pairs or as
            ``{"role": ..., "content": ...}`` dicts; ``None`` is treated as
            empty.
        system_message: Text placed in the leading ``system`` message.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Yields:
        The decoded reply text, exactly once.
    """
    # Build the conversation: system prompt, prior turns, then the new message.
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    # Render the conversation into the model's prompt text.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize and move the tensors to the device the model lives on.
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        # UI sliders may hand over a float; a token budget must be an int.
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # The output repeats the prompt tokens first; keep only the continuation.
    prompt_length = model_inputs.input_ids.shape[1]
    response = tokenizer.decode(
        generated_ids[0][prompt_length:],
        skip_special_tokens=True,
    )
    yield response
# Stylesheet handed to the Gradio interface: pulls in the Inter web font and
# applies it to the page body and the Gradio container.
custom_css = (
    "\n"
    "@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');\n"
    "body, .gradio-container {\n"
    "font-family: 'Inter', sans-serif;\n"
    "}\n"
)

# Default system prompt sent as the first message of every conversation.
system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
# Chat UI definition. The extra inputs are forwarded to generate_response
# after (message, history), in this order: system prompt, max new tokens,
# temperature, top-p.
demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        # System prompt is fixed for the session, so its textbox stays hidden.
        gr.Textbox(value=system_message, visible=False),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    css=custom_css,
)
# Start the Gradio server only when this file is executed as a script, so
# importing the module does not launch the UI.
if __name__ == "__main__":
    demo.launch()