Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import time | |
import gc | |
import threading | |
from datetime import datetime | |
import gradio as gr | |
import torch | |
from transformers import pipeline, TextIteratorStreamer | |
import spaces # Import spaces early to enable ZeroGPU support | |
# ------------------------------
# Global Cancellation Event
# ------------------------------
# Module-wide flag shared by the handlers in this file: cancel_generation()
# sets it and generate_response() clears it at the start of every request.
# NOTE(review): nothing in the visible code ever reads the flag, so setting
# it does not currently interrupt a running generation -- confirm intent.
cancel_event = threading.Event()
# ------------------------------
# Qwen3 Model Definitions
# ------------------------------
# Short model name -> Hugging Face repo id plus the description shown in the
# model dropdown. Order is significant: the UI uses the first entry as the
# default selection.
MODELS = {
    "Qwen3-8B": dict(
        repo_id="Qwen/Qwen3-8B",
        description="Qwen3-8B - Largest model with highest capabilities",
    ),
    "Qwen3-4B": dict(
        repo_id="Qwen/Qwen3-4B",
        description="Qwen3-4B - Good balance of performance and efficiency",
    ),
    "Qwen3-1.7B": dict(
        repo_id="Qwen/Qwen3-1.7B",
        description="Qwen3-1.7B - Smaller model for faster responses",
    ),
    "Qwen3-0.6B": dict(
        repo_id="Qwen/Qwen3-0.6B",
        description="Qwen3-0.6B - Ultra-lightweight model",
    ),
}
# Global cache for pipelines to avoid re-loading.
# Keyed by the short model name (a key of MODELS); values are pipeline objects.
PIPELINES = {}


def load_pipeline(model_name):
    """
    Load and cache a transformers pipeline for text generation.

    Precisions are tried in the order bfloat16, float16, float32 and the first
    one that loads is cached and returned. Each failed attempt is logged as a
    warning instead of being silently discarded. If all three fail, one final
    attempt is made without an explicit dtype; an error from that attempt
    propagates to the caller.

    Args:
        model_name: Key of MODELS (for example "Qwen3-4B").

    Returns:
        The pipeline for ``model_name``; repeated calls return the same object.

    Raises:
        KeyError: If ``model_name`` is not a key of MODELS.
        Exception: Whatever the final fallback load raises.
    """
    if model_name in PIPELINES:
        return PIPELINES[model_name]

    # Function-scope import so this block does not depend on the file header.
    import logging

    log = logging.getLogger(__name__)
    repo = MODELS[model_name]["repo_id"]
    # Arguments shared by every attempt; only the dtype varies.
    common_kwargs = dict(
        task="text-generation",
        model=repo,
        tokenizer=repo,
        trust_remote_code=True,
        device_map="auto",
    )

    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(torch_dtype=dtype, **common_kwargs)
        except Exception as exc:
            # Best effort: record why this precision failed, then try the next.
            log.warning("Loading %s with dtype %s failed: %s", repo, dtype, exc)
            continue
        PIPELINES[model_name] = pipe
        return pipe

    # Final fallback: let the library choose the dtype. Not wrapped in a
    # try/except so the real cause reaches the caller if this fails as well.
    pipe = pipeline(**common_kwargs)
    PIPELINES[model_name] = pipe
    return pipe
def format_conversation(history, system_prompt):
    """
    Flatten chat history and system prompt into a single string.

    The result is the stripped system prompt on the first line, followed by
    one "User: ..." line per user message and one "Assistant: ..." line per
    non-empty assistant message, and it always ends with an open
    "Assistant: " cue for the model to continue.

    Args:
        history: Iterable of (user_msg, assistant_msg) pairs. Either element
            may be None; a None user message or a None/empty assistant
            message contributes no line. None is treated as an empty history.
        system_prompt: Text placed first; None is treated as "".

    Returns:
        The prompt string.
    """
    lines = [(system_prompt or "").strip()]
    for user_msg, assistant_msg in history or ():
        # A pair may carry only an assistant message (user side is None).
        if user_msg is not None:
            lines.append("User: " + user_msg.strip())
        if assistant_msg:  # might be None or empty
            lines.append("Assistant: " + assistant_msg.strip())
    lines.append("Assistant: ")
    return "\n".join(lines)
def generate_response(user_input, history, system_prompt, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty):
    """
    Generate a complete response (non-streaming).

    Builds a prompt from the system prompt, the earlier turns in ``history``
    and the new ``user_input``, runs the cached pipeline for ``model_name``,
    and records the reply.

    Args:
        user_input: The new user message to answer.
        history: List of earlier (user, assistant) pairs, NOT including
            ``user_input``. The new pair is appended to this list in place.
        system_prompt: Instruction text placed at the top of the prompt.
        model_name: Key of MODELS selecting which pipeline to use.
        max_tokens: Upper bound on newly generated tokens.
        temperature, top_k, top_p, repeat_penalty: Sampling settings passed
            through to the pipeline call.

    Returns:
        ``history`` (the same list object) with (user_input, reply) appended.
        If loading or generation raises, the reply is "Error: <message>".
    """
    # NOTE(review): the flag is cleared here but never checked in the visible
    # code, so cancel_generation() cannot stop a running call.
    # NOTE(review): `spaces` is imported for ZeroGPU, yet no function in the
    # visible file uses the spaces.GPU decorator -- confirm whether this
    # entry point should be decorated.
    cancel_event.clear()

    # The turn being answered must be part of the prompt. Formatting only the
    # previous turns would make the model respond without seeing the question.
    pending_turn = (user_input, None)
    conversation = format_conversation(list(history) + [pending_turn], system_prompt)

    try:
        pipe = load_pipeline(model_name)
        outputs = pipe(
            conversation,
            # Slider widgets may deliver whole numbers as floats; these two
            # settings must be integers.
            max_new_tokens=int(max_tokens),
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            repetition_penalty=repeat_penalty,
            return_full_text=False,
        )
        reply = outputs[0]["generated_text"]
    except Exception as e:
        # Surface the failure in the chat instead of crashing the handler.
        reply = f"Error: {e}"
    finally:
        gc.collect()

    # Return the updated history
    history.append((user_input, reply))
    return history
def cancel_generation():
    """
    Set the shared cancellation flag.

    Returns:
        The fixed status message 'Generation cancelled.'.
    """
    status = "Generation cancelled."
    cancel_event.set()
    return status
def get_default_system_prompt():
    """
    Build the default system prompt, stamped with the current local date.

    Returns:
        Three newline-separated lines: the assistant's identity, a
        "Today is YYYY-MM-DD." line, and a short style instruction.
    """
    today = f"{datetime.now():%Y-%m-%d}"
    prompt_lines = (
        "You are Qwen3, a helpful and friendly AI assistant created by Alibaba Cloud.",
        f"Today is {today}.",
        "Be concise, accurate, and helpful in your responses.",
    )
    return "\n".join(prompt_lines)
# CSS for improved visual style.
# Passed to gr.Blocks(css=...) below. The class names defined here are the
# ones the UI attaches through elem_classes or inline HTML class attributes:
# qwen-header, qwen-container, controls-container, model-select,
# button-primary, button-secondary and footer.
css = """
.gradio-container {
background-color: #f5f7fb !important;
}
.qwen-header {
background: linear-gradient(90deg, #0099FF, #0066CC);
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
text-align: center;
color: white;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.qwen-container {
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
background: white;
padding: 20px;
margin-bottom: 20px;
}
.controls-container {
background: #f0f4fa;
border-radius: 10px;
padding: 15px;
margin-bottom: 15px;
}
.model-select {
border: 2px solid #0099FF !important;
border-radius: 8px !important;
}
.button-primary {
background-color: #0099FF !important;
color: white !important;
}
.button-secondary {
background-color: #6c757d !important;
color: white !important;
}
.footer {
text-align: center;
margin-top: 20px;
font-size: 0.8em;
color: #666;
}
"""
# Dropdown entries are rendered as "<model name> - <description>"; this
# recovers the model name from such an entry.
def get_model_name(full_selection):
    """Return the text before the first " - " (the whole string if absent)."""
    name, _separator, _description = full_selection.partition(" - ")
    return name
# ------------------------------
# Gradio UI
# ------------------------------
# Layout: a narrow left column (model picker, generation settings, clear
# button) and a wide right column (chat transcript, message box, send button).
# NOTE(review): the nesting below was reconstructed from statement order
# because the original indentation was unavailable -- confirm the layout.
with gr.Blocks(title="Qwen3 Chat", css=css) as demo:
    gr.HTML("""
    <div class="qwen-header">
    <h1>🤖 Qwen3 Chat</h1>
    <p>Interact with Alibaba Cloud's Qwen3 language models</p>
    </div>
    """)

    # Dropdown entries look like "<name> - <description>"; get_model_name()
    # turns a selection back into the MODELS key. First entry is the default.
    model_choices = [f"{name} - {info['description']}" for name, info in MODELS.items()]

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group(elem_classes="qwen-container"):
                model_dd = gr.Dropdown(
                    label="Select Qwen3 Model",
                    choices=model_choices,
                    value=model_choices[0],
                    elem_classes="model-select",
                )
            with gr.Group(elem_classes="controls-container"):
                gr.Markdown("### ⚙️ Generation Parameters")
                # Pass the function itself rather than its result so the
                # "Today is ..." line is recomputed on each page load instead
                # of being frozen at the time the module was imported.
                sys_prompt = gr.Textbox(label="System Prompt", lines=5, value=get_default_system_prompt)
                with gr.Row():
                    max_tok = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens")
                with gr.Row():
                    temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
                    p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
                with gr.Row():
                    k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
                    rp = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
            clear_btn = gr.Button("Clear Chat", elem_classes="button-secondary")
        with gr.Column(scale=7):
            chatbot = gr.Chatbot()
            with gr.Row():
                txt = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here...",
                    lines=2,
                )
                submit_btn = gr.Button("Send", variant="primary", elem_classes="button-primary")
    gr.HTML("""
    <div class="footer">
    <p>Qwen3 models developed by Alibaba Cloud. Interface powered by Gradio and ZeroGPU.</p>
    </div>
    """)

    # Define event handlers
    def user_input(user_message, history):
        """
        Queue the typed message as a pending turn and clear the input box.

        Returns a pair (new textbox value, new chat history). A blank message
        leaves the history unchanged so no generation is triggered for it.
        """
        history = list(history or [])
        if not user_message or not user_message.strip():
            return "", history
        return "", history + [(user_message, None)]

    def bot_response(history, system_prompt, model_choice, max_tokens, temperature, top_k, top_p, repeat_penalty):
        """
        Fill in the assistant reply for the pending (last) turn of the chat.

        The parameters arrive positionally from the components listed in
        generation_inputs below. Returns the updated history for the chatbot.
        """
        history = list(history or [])
        # Nothing pending: empty chat, or the last turn already has a reply
        # (this happens when user_input ignored a blank message).
        if not history or history[-1][1] is not None:
            return history
        user_message = history[-1][0]
        updated = generate_response(
            user_message,
            history[:-1],
            system_prompt,
            get_model_name(model_choice),
            max_tokens,
            temperature,
            top_k,
            top_p,
            repeat_penalty,
        )
        history[-1] = (user_message, updated[-1][1])
        return history

    # Connect everything.
    # Order must match bot_response's parameters: note Top-K (k) precedes Top-P (p).
    generation_inputs = [chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp]

    # The Send button and pressing Enter in the textbox behave identically:
    # first echo the user's message immediately, then generate the reply.
    for trigger in (submit_btn.click, txt.submit):
        trigger(
            user_input,
            [txt, chatbot],
            [txt, chatbot],
            queue=False,
        ).then(
            bot_response,
            generation_inputs,
            [chatbot],
        )

    # Returning None resets the chatbot component.
    clear_btn.click(lambda: None, None, chatbot, queue=False)
# Start the Gradio app only when this file is executed as a script; importing
# the module builds `demo` but does not launch it.
if __name__ == "__main__":
    demo.launch()