import time
import random
from dataclasses import dataclass
from threading import Thread
from typing import Dict, Iterator, List, Literal, Optional, TypedDict, NotRequired

import gradio as gui
import torch
from gradio.themes.utils import colors
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
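# Assumed runtime dependencies (the original ships no requirements file, so
# this list is a best guess): gradio, transformers, torch, and accelerate —
# the last is needed by transformers for the device_map="auto" call below.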
# Custom theme for the Gradio interface
custom_theme = gui.themes.Default(
    primary_hue=colors.blue,
    secondary_hue=colors.green,
    neutral_hue=colors.gray,
    font=[gui.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
).set(
    body_background_fill="#FFFFFF",
    body_text_color="#1F2937",
    button_primary_background_fill="#2D7FF9",
    button_primary_background_fill_hover="#1A56F0",
    button_secondary_background_fill="#10B981",
    button_secondary_background_fill_hover="#059669",
    block_title_text_color="#6B7280",
    block_label_text_color="#6B7280",
    background_fill_primary="#F9FAFB",
    background_fill_secondary="#F3F4F6",
)
@dataclass
class UserMessage:
    """Chat message in the shape Gradio's ``type="messages"`` Chatbot accepts."""
    content: str
    role: Literal["user", "assistant"]
    metadata: Optional[Dict] = None
    options: Optional[List[Dict]] = None
class Metadata(TypedDict):
    title: NotRequired[str]
    id: NotRequired[int | str]
    parent_id: NotRequired[int | str]
    log: NotRequired[str]
    duration: NotRequired[float]
    status: NotRequired[Literal["pending", "done"]]
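# Note: Metadata mirrors the keys Gradio's ChatMessage metadata accepts; a
# "title" renders the message as a collapsible "thought" bubble, "status"
# toggles the pending spinner, and "duration" is shown once status is "done".
# It is declared here for documentation only and is never enforced below.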
MODEL_IDENTIFIER = "smol-ai/SmolLM2-135M-Instruct" | |
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_IDENTIFIER)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_IDENTIFIER,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return model, tokenizer

print("Loading model and tokenizer...")
model_instance, tokenizer_instance = load_model()
print("Model and tokenizer loaded!")
def build_conversation_prompt(current_message: str, history: List[UserMessage]) -> str:
    # Gradio hands the Chatbot history back to callbacks as plain dicts,
    # so accept both dicts and UserMessage instances here.
    conversation_history = []
    for message in history:
        role = message["role"] if isinstance(message, dict) else message.role
        content = message["content"] if isinstance(message, dict) else message.content
        conversation_history.append(f"{role.upper()}: {content}")
    conversation_history.append(f"USER: {current_message}")
    conversation_history.append("ASSISTANT: ")
    return "\n".join(conversation_history)
def stream_chat_response(user_input: str, history: List[UserMessage]) -> Iterator[List[UserMessage]]:
    prompt_text = build_conversation_prompt(user_input, history)
    inputs = tokenizer_instance(prompt_text, return_tensors="pt").to(model_instance.device)
    response_streamer = TextIteratorStreamer(
        tokenizer_instance,
        timeout=10.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_params = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "max_new_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
        "streamer": response_streamer,
        "do_sample": True,
    }
    # Run generation on a background thread so this generator can keep
    # yielding intermediate UI states while tokens are produced.
    thread = Thread(target=model_instance.generate, kwargs=generation_params)
    thread.start()
    start_time = time.time()
    thought_buffer = ""
    updated_history = list(history) + [UserMessage(role="user", content=user_input)]
    updated_history.append(create_thinking_message())
    yield updated_history
    # Simulated "thinking" phase: stream a few canned thought fragments
    # while real generation warms up in the background.
    for _ in range(random.randint(3, 6)):
        thought_buffer = update_thoughts(thought_buffer, updated_history)
        yield updated_history
        time.sleep(0.5)
    finalize_thinking(updated_history, thought_buffer, start_time)
    yield updated_history
    for text_chunk in response_streamer:
        updated_history[-1] = UserMessage(role="assistant", content=updated_history[-1].content + text_chunk)
        yield updated_history
        time.sleep(0.01)
def create_thinking_message() -> UserMessage:
    return UserMessage(
        role="assistant",
        content="",
        metadata={
            "title": "🧠 Thinking...",
            "status": "pending",
        },
    )
def update_thoughts(thought_buffer: str, updated_history: List[UserMessage]) -> str:
    thought_segments = [
        "Analyzing the user's query...",
        "Retrieving relevant information...",
        "Considering different perspectives...",
        "Formulating a coherent response...",
        "Checking for accuracy and completeness...",
        "Organizing thoughts in a logical structure...",
    ]
    thought_buffer += random.choice(thought_segments) + " "
    updated_history[-1] = UserMessage(
        role="assistant",
        content=thought_buffer,
        metadata={
            "title": "🧠 Thinking...",
            "status": "pending",
        },
    )
    return thought_buffer
def finalize_thinking(updated_history: List[UserMessage], thought_buffer: str, start_time: float):
    thinking_duration = time.time() - start_time
    updated_history[-1] = UserMessage(
        role="assistant",
        content=thought_buffer,
        metadata={
            "title": "🧠 Thinking Process",
            "status": "done",
            "duration": round(thinking_duration, 2),
        },
    )
    # Fresh empty assistant message for the token stream to fill in.
    updated_history.append(UserMessage(role="assistant", content=""))
def reset_chat() -> List[UserMessage]:
    return []
style_sheet = """ | |
.message-user { | |
background-color: #F3F4F6 !important; | |
border-radius: 10px; | |
padding: 10px; | |
margin: 8px 0; | |
} | |
.message-assistant { | |
background-color: #F9FAFB !important; | |
border-radius: 10px; | |
padding: 10px; | |
margin: 8px 0; | |
border-left: 3px solid #2D7FF9; | |
} | |
.thinking-box { | |
background-color: #F0F9FF !important; | |
border: 1px solid #BAE6FD; | |
border-radius: 6px; | |
} | |
.chat-container { | |
height: calc(100vh - 230px); | |
overflow-y: auto; | |
padding: 16px; | |
} | |
.input-container { | |
position: sticky; | |
bottom: 0; | |
background-color: #FFFFFF; | |
padding: 16px; | |
border-top: 1px solid #E5E7EB; | |
} | |
@media (max-width: 640px) { | |
.chat-container { | |
height: calc(100vh - 200px); | |
} | |
} | |
footer { | |
display: none !important; | |
} | |
""" | |
with gui.Blocks(theme=custom_theme, css=style_sheet) as demo_interface:
    gui.HTML("""
        <div style="text-align: center; margin-bottom: 1rem">
            <h1 style="font-size: 2.5rem; font-weight: 600; color: #1F2937">SmolLM2 Chat</h1>
            <p style="font-size: 1.1rem; color: #6B7280">
                Chat with SmolLM2-135M-Instruct: a small but capable AI assistant
            </p>
        </div>
    """)
    chat_interface = gui.Chatbot(
        value=[],
        avatar_images=(None, "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot.png"),
        show_label=False,
        container=True,
        height=600,
        elem_classes="chat-container",
        type="messages",
    )
    with gui.Row(elem_classes="input-container"):
        with gui.Column(scale=20):
            message_input = gui.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                container=False,
                lines=2,
            )
        with gui.Column(scale=1, min_width=50):
            send_button = gui.Button("Send", variant="primary")
    with gui.Row():
        clear_button = gui.Button("Clear Chat", variant="secondary")
    # Stream the response, then clear the input box once streaming finishes.
    message_input.submit(
        stream_chat_response,
        [message_input, chat_interface],
        [chat_interface],
        queue=True,
    ).then(
        lambda: "",
        None,
        [message_input],
        queue=False,
    )
    send_button.click(
        stream_chat_response,
        [message_input, chat_interface],
        [chat_interface],
        queue=True,
    ).then(
        lambda: "",
        None,
        [message_input],
        queue=False,
    )
    clear_button.click(
        reset_chat,
        None,
        [chat_interface],
        queue=False,
    )
if __name__ == "__main__":
    demo_interface.launch(
        server_name="0.0.0.0",
        # Hugging Face Spaces serves Gradio apps on port 7860.
        server_port=7860,
        share=False,
    )