# Spaces: Running on Zero (HuggingFace ZeroGPU Space)
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
class ChatInterface:
    """Streaming chat wrapper around a causal LM loaded from the Hub.

    Loads the tokenizer and model once at construction time and exposes a
    generator method suitable for use as a ``gr.ChatInterface`` callback.
    """

    def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
        """Load tokenizer and model (fp16 weights, automatic device placement)."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )

    def format_chat_prompt(self, message, history, system_message):
        """Build a model-ready prompt string from the conversation state.

        Parameters
        ----------
        message : str
            The new user message.
        history : list[tuple[str, str]]
            Prior (user, assistant) turn pairs as supplied by
            ``gr.ChatInterface`` (tuples format — NOTE(review): confirm the
            installed Gradio version still passes tuple history by default).
        system_message : str
            System prompt, placed first in the message list.

        Returns
        -------
        str
            Prompt text rendered through the tokenizer's chat template.
        """
        messages = [{"role": "system", "content": system_message}]
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})
        # Use the tokenizer's chat template so the prompt matches the format
        # the model was trained/fine-tuned with.
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    # On ZeroGPU Spaces a GPU is only attached while a @spaces.GPU-decorated
    # function runs; `spaces` was imported but never used, so generation
    # previously ran without a GPU allocation.
    @spaces.GPU
    def generate_response(
        self,
        message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
    ):
        """Yield the assistant reply incrementally as tokens are generated.

        Yields the accumulated response string after each new text chunk,
        which is the streaming contract ``gr.ChatInterface`` expects.
        """
        prompt = self.format_chat_prompt(message, history, system_message)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Streamer turns generate()'s token-by-token output into an iterator.
        streamer = TextIteratorStreamer(
            self.tokenizer,
            timeout=10.0,
            skip_prompt=True,
            skip_special_tokens=True,
        )

        # BUG FIX: the original passed the whole BatchEncoding as a single
        # `inputs=` kwarg; generate() expects a tensor there. Unpack the
        # encoding so both input_ids AND attention_mask are forwarded.
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            # Silences the "pad_token_id not set" warning for models whose
            # tokenizer defines no pad token.
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Run generation on a worker thread so this generator can consume
        # the streamer concurrently.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        response = ""
        for new_text in streamer:
            response += new_text
            yield response
        # Ensure the worker finishes before the generator is torn down.
        thread.join()
def create_demo():
    """Wire a ChatInterface model into a Gradio chat UI and return the demo."""
    chat = ChatInterface()

    # Extra controls shown beneath the chat box; values are forwarded as the
    # trailing positional arguments of generate_response.
    system_box = gr.Textbox(
        value="You are a friendly Chatbot.",
        label="System message",
    )
    max_tokens_slider = gr.Slider(
        minimum=1,
        maximum=2048,
        value=512,
        step=1,
        label="Max new tokens",
    )
    temperature_slider = gr.Slider(
        minimum=0.1,
        maximum=4.0,
        value=0.7,
        step=0.1,
        label="Temperature",
    )
    top_p_slider = gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    )

    return gr.ChatInterface(
        chat.generate_response,
        additional_inputs=[
            system_box,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
        ],
    )
if __name__ == "__main__": | |
demo = create_demo() | |
demo.launch() |