whackthejacker committed
Commit ecff730 · verified · 1 Parent(s): 747abd8

Update app.py

Files changed (1):
  app.py +265 -50

app.py CHANGED
@@ -1,64 +1,279 @@
- import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
  """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-

  if __name__ == "__main__":
-     demo.launch()
+ import os
+ import time
+ import random
+ import gradio as gui
+ from gradio.themes.utils import colors
+ from dataclasses import dataclass
+ from typing import Dict, Iterator, List, Literal, Optional, TypedDict, NotRequired
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from threading import Thread
+ import torch
+
+ # Custom theme for the Gradio interface
+ custom_theme = gui.themes.Default(
+     primary_hue=colors.blue,
+     secondary_hue=colors.green,
+     neutral_hue=colors.gray,
+     font=[gui.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
+ ).set(
+     body_background_fill="#FFFFFF",
+     body_text_color="#1F2937",
+     button_primary_background_fill="#2D7FF9",
+     button_primary_background_fill_hover="#1A56F0",
+     button_secondary_background_fill="#10B981",
+     button_secondary_background_fill_hover="#059669",
+     block_title_text_color="#6B7280",
+     block_label_text_color="#6B7280",
+     background_fill_primary="#F9FAFB",
+     background_fill_secondary="#F3F4F6",
+ )
+
+ # Mirrors the shape of gr.ChatMessage (role/content/metadata/options); the
+ # handlers convert these to plain dicts before returning them to Gradio.
+ @dataclass
+ class UserMessage:
+     content: str
+     role: Literal["user", "assistant"]
+     metadata: Optional[Dict] = None
+     options: Optional[List[Dict]] = None
+
+ # Metadata keys Gradio understands on a chat message (title, status, etc.).
+ class Metadata(TypedDict):
+     title: NotRequired[str]
+     id: NotRequired[int | str]
+     parent_id: NotRequired[int | str]
+     log: NotRequired[str]
+     duration: NotRequired[float]
+     status: NotRequired[Literal["pending", "done"]]
+
+ # Hub id of the model to serve; the SmolLM2 checkpoints are published under
+ # the HuggingFaceTB organization.
+ MODEL_IDENTIFIER = "HuggingFaceTB/SmolLM2-135M-Instruct"
+
+ @torch.inference_mode()
+ def load_model():
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_IDENTIFIER)
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_IDENTIFIER,
+         torch_dtype=torch.float16,
+         device_map="auto",
+     )
+     return model, tokenizer
+
+ print("Loading model and tokenizer...")
+ model_instance, tokenizer_instance = load_model()
+ print("Model and tokenizer loaded!")
+
+ def build_conversation_prompt(current_message: str, history: List[UserMessage]) -> str:
+     # Plain-text prompt format; tokenizer_instance.apply_chat_template would be
+     # the more conventional choice for an instruct-tuned model.
+     conversation_history = [
+         f"{message.role.upper()}: {message.content}" for message in history
+     ]
+     conversation_history.append(f"USER: {current_message}")
+     conversation_history.append("ASSISTANT: ")
+     return "\n".join(conversation_history)
+
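+ # Streaming handler: generate() runs on a worker thread while this generator
+ # yields successive snapshots of the conversation, so tokens render live.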
+ def stream_chat_response(user_input: str, history: List[Dict]) -> Iterator[List[Dict]]:
+     # The "messages"-type Chatbot passes history in, and expects it back, as
+     # plain dicts, so convert to/from the UserMessage dataclass at the edges.
+     def as_dicts(messages: List[UserMessage]) -> List[Dict]:
+         return [{"role": m.role, "content": m.content, "metadata": m.metadata or {}} for m in messages]
+
+     parsed_history = [
+         UserMessage(role=m["role"], content=m["content"], metadata=m.get("metadata"))
+         for m in history
+     ]
+     prompt_text = build_conversation_prompt(user_input, parsed_history)
+     inputs = tokenizer_instance(prompt_text, return_tensors="pt").to(model_instance.device)
+
+     response_streamer = TextIteratorStreamer(
+         tokenizer_instance,
+         timeout=10.0,
+         skip_prompt=True,
+         skip_special_tokens=True,
+     )
+
+     generation_params = {
+         "input_ids": inputs.input_ids,
+         "attention_mask": inputs.attention_mask,
+         "max_new_tokens": 512,
+         "temperature": 0.7,
+         "top_p": 0.9,
+         "streamer": response_streamer,
+         "do_sample": True,
+     }
+
+     thread = Thread(target=model_instance.generate, kwargs=generation_params)
+     thread.start()
+
+     start_time = time.time()
+     thought_buffer = ""
+     updated_history = parsed_history + [UserMessage(role="user", content=user_input)]
+     updated_history.append(create_thinking_message())
+
+     yield as_dicts(updated_history)
+
+     # Simulated "thinking" updates shown while generation spins up.
+     for _ in range(random.randint(3, 6)):
+         thought_buffer = update_thoughts(thought_buffer, updated_history)
+         yield as_dicts(updated_history)
+         time.sleep(0.5)
+
+     finalize_thinking(updated_history, thought_buffer, start_time)
+     yield as_dicts(updated_history)
+
+     for text_chunk in response_streamer:
+         updated_history[-1] = UserMessage(role="assistant", content=updated_history[-1].content + text_chunk)
+         yield as_dicts(updated_history)
+         time.sleep(0.01)
+
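+ # The helpers below build the staged "thinking" display: a pending assistant
+ # message whose metadata Gradio renders as a collapsible status box.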
+ def create_thinking_message() -> UserMessage:
+     return UserMessage(
+         role="assistant",
+         content="",
+         metadata={
+             "title": "🧠 Thinking...",
+             "status": "pending",
+         },
+     )
+
+ def update_thoughts(thought_buffer: str, updated_history: List[UserMessage]) -> str:
+     thought_segments = [
+         "Analyzing the user's query...",
+         "Retrieving relevant information...",
+         "Considering different perspectives...",
+         "Formulating a coherent response...",
+         "Checking for accuracy and completeness...",
+         "Organizing thoughts in a logical structure...",
+     ]
+     thought_buffer += random.choice(thought_segments) + " "
+     updated_history[-1] = UserMessage(
+         role="assistant",
+         content=thought_buffer,
+         metadata={
+             "title": "🧠 Thinking...",
+             "status": "pending",
+         },
+     )
+     return thought_buffer
+
+ def finalize_thinking(updated_history: List[UserMessage], thought_buffer: str, start_time: float):
+     # Close out the thinking box with a final title, status, and elapsed time.
+     thinking_duration = time.time() - start_time
+     updated_history[-1] = UserMessage(
+         role="assistant",
+         content=thought_buffer,
+         metadata={
+             "title": "🧠 Thinking Process",
+             "status": "done",
+             "duration": round(thinking_duration, 2),
+         },
+     )
+     # Append an empty assistant message for the real response to stream into.
+     updated_history.append(UserMessage(role="assistant", content=""))
+
+ def reset_chat() -> List[UserMessage]:
+     return []
+
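+ # Custom CSS; the class names are referenced via elem_classes in the UI below.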
+ style_sheet = """
+ .message-user {
+     background-color: #F3F4F6 !important;
+     border-radius: 10px;
+     padding: 10px;
+     margin: 8px 0;
+ }
+
+ .message-assistant {
+     background-color: #F9FAFB !important;
+     border-radius: 10px;
+     padding: 10px;
+     margin: 8px 0;
+     border-left: 3px solid #2D7FF9;
+ }
+
+ .thinking-box {
+     background-color: #F0F9FF !important;
+     border: 1px solid #BAE6FD;
+     border-radius: 6px;
+ }
+
+ .chat-container {
+     height: calc(100vh - 230px);
+     overflow-y: auto;
+     padding: 16px;
+ }
+
+ .input-container {
+     position: sticky;
+     bottom: 0;
+     background-color: #FFFFFF;
+     padding: 16px;
+     border-top: 1px solid #E5E7EB;
+ }
+
+ @media (max-width: 640px) {
+     .chat-container {
+         height: calc(100vh - 200px);
+     }
+ }
+
+ footer {
+     display: none !important;
+ }
  """

+ with gui.Blocks(theme=custom_theme, css=style_sheet) as demo_interface:
+     gui.HTML("""
+         <div style="text-align: center; margin-bottom: 1rem">
+             <h1 style="font-size: 2.5rem; font-weight: 600; color: #1F2937">SmolLM2 Chat</h1>
+             <p style="font-size: 1.1rem; color: #6B7280">
+                 Chat with SmolLM2-135M-Instruct: A small but capable AI assistant
+             </p>
+         </div>
+     """)
+
+     chat_interface = gui.Chatbot(
+         value=[],
+         avatar_images=(None, "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot.png"),
+         show_label=False,
+         container=True,
+         height=600,
+         elem_classes="chat-container",
+         type="messages",
+     )
+
+     with gui.Row(elem_classes="input-container"):
+         with gui.Column(scale=20):
+             message_input = gui.Textbox(
+                 show_label=False,
+                 placeholder="Type your message here...",
+                 container=False,
+                 lines=2,
+             )
+
+         with gui.Column(scale=1, min_width=50):
+             send_button = gui.Button("Send", variant="primary")
+
+     with gui.Row():
+         clear_button = gui.Button("Clear Chat", variant="secondary")
+
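+     # Enter-to-submit and the Send button share one streaming handler; each
+     # chains a .then() that clears the textbox once the response has finished.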
+     message_input.submit(
+         stream_chat_response,
+         [message_input, chat_interface],
+         [chat_interface],
+         queue=True,
+     ).then(
+         lambda: "",
+         None,
+         [message_input],
+         queue=False,
+     )
+
+     send_button.click(
+         stream_chat_response,
+         [message_input, chat_interface],
+         [chat_interface],
+         queue=True,
+     ).then(
+         lambda: "",
+         None,
+         [message_input],
+         queue=False,
+     )
+
+     clear_button.click(
+         reset_chat,
+         None,
+         [chat_interface],
+         queue=False,
+     )
+
  if __name__ == "__main__":
+     demo_interface.launch(
+         server_name="0.0.0.0",
+         server_port=5000,
+         share=False,
+     )