Ocks committed · verified
Commit eb19008 · 1 Parent(s): cac97c2

Update app.py

Files changed (1):
  1. app.py +364 -387

app.py CHANGED
@@ -1,13 +1,12 @@
 #!/usr/bin/env python3
 
-import tkinter as tk
-from tkinter import ttk, scrolledtext, messagebox
-import threading
-import queue
 import os
-from datetime import datetime
-from typing import List, Dict, Generator
 import warnings
+from collections.abc import Iterator
+from threading import Thread
+from typing import List, Dict, Optional, Tuple
+import time
+
 warnings.filterwarnings("ignore")
 
 # Try to import required libraries
@@ -16,452 +15,430 @@ try:
     from transformers import (
         AutoModelForCausalLM,
         AutoTokenizer,
-        TextIteratorStreamer,
-        pipeline
+        TextIteratorStreamer
     )
     TRANSFORMERS_AVAILABLE = True
 except ImportError:
     TRANSFORMERS_AVAILABLE = False
 
-class CPULLMChatApp:
-    def __init__(self, root):
-        self.root = root
-        self.root.title("CPU LLM Chat Application")
-        self.root.geometry("1000x700")
-
-        # Chat history
-        self.chat_history: List[Dict[str, str]] = []
-
-        # Model variables
-        self.model = None
-        self.tokenizer = None
-        self.generator = None
+try:
+    import gradio as gr
+    GRADIO_AVAILABLE = True
+except ImportError:
+    GRADIO_AVAILABLE = False
+
+class CPULLMChat:
+    def __init__(self):
+        self.models = {
+            "microsoft/DialoGPT-medium": "DialoGPT Medium (Recommended for chat)",
+            "microsoft/DialoGPT-small": "DialoGPT Small (Faster)",
+            "distilgpt2": "DistilGPT2 (Very fast)",
+            "gpt2": "GPT2 (Standard)",
+            "facebook/blenderbot-400M-distill": "BlenderBot (Conversational)"
+        }
+
+        self.current_model = None
+        self.current_tokenizer = None
+        self.current_model_name = None
         self.model_loaded = False
 
-        # Threading
-        self.generation_thread = None
-        self.stop_generation = False
-        self.response_queue = queue.Queue()
-
         # Configuration
         self.max_input_length = 2048
-        self.max_new_tokens = tk.IntVar(value=256)  # Reduced for CPU
-        self.temperature = tk.DoubleVar(value=0.7)
-        self.top_p = tk.DoubleVar(value=0.9)
-        self.top_k = tk.IntVar(value=50)
-        self.repetition_penalty = tk.DoubleVar(value=1.1)
-
-        self.setup_ui()
-        self.check_dependencies()
-
-    def setup_ui(self):
-        # Create main frame
-        main_frame = ttk.Frame(self.root, padding="10")
-        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))
-
-        # Configure grid weights
-        self.root.columnconfigure(0, weight=1)
-        self.root.rowconfigure(0, weight=1)
-        main_frame.columnconfigure(0, weight=1)
-        main_frame.rowconfigure(1, weight=1)
-
-        # Title and model selection
-        title_frame = ttk.Frame(main_frame)
-        title_frame.grid(row=0, column=0, sticky=(tk.W, tk.E), pady=(0, 10))
-        title_frame.columnconfigure(1, weight=1)
-
-        ttk.Label(title_frame, text="CPU LLM Chat", font=("Arial", 16, "bold")).grid(row=0, column=0, sticky=tk.W)
-
-        # Model selection
-        ttk.Label(title_frame, text="Model:").grid(row=0, column=2, padx=(20, 5))
-        self.model_var = tk.StringVar(value="microsoft/DialoGPT-medium")
-        model_combo = ttk.Combobox(title_frame, textvariable=self.model_var, width=30)
-        model_combo['values'] = [
-            "microsoft/DialoGPT-medium",
-            "microsoft/DialoGPT-small",
-            "distilgpt2",
-            "gpt2",
-            "facebook/blenderbot-400M-distill"
-        ]
-        model_combo.grid(row=0, column=3, padx=(0, 10))
-
-        self.load_model_btn = ttk.Button(title_frame, text="Load Model", command=self.load_model)
-        self.load_model_btn.grid(row=0, column=4)
-
-        # Chat area
-        chat_frame = ttk.Frame(main_frame)
-        chat_frame.grid(row=1, column=0, sticky=(tk.W, tk.E, tk.N, tk.S), pady=(0, 10))
-        chat_frame.columnconfigure(0, weight=1)
-        chat_frame.rowconfigure(0, weight=1)
-
-        # Chat history display
-        self.chat_display = scrolledtext.ScrolledText(
-            chat_frame,
-            wrap=tk.WORD,
-            state=tk.DISABLED,
-            font=("Arial", 10)
-        )
-        self.chat_display.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))
-
-        # Configure tags for styling
-        self.chat_display.tag_configure("user", foreground="blue", font=("Arial", 10, "bold"))
-        self.chat_display.tag_configure("assistant", foreground="green", font=("Arial", 10))
-        self.chat_display.tag_configure("system", foreground="gray", font=("Arial", 9, "italic"))
-
-        # Input area
-        input_frame = ttk.Frame(main_frame)
-        input_frame.grid(row=2, column=0, sticky=(tk.W, tk.E), pady=(0, 10))
-        input_frame.columnconfigure(0, weight=1)
-
-        # Input text
-        self.input_text = scrolledtext.ScrolledText(input_frame, height=3, wrap=tk.WORD)
-        self.input_text.grid(row=0, column=0, sticky=(tk.W, tk.E), padx=(0, 10))
-        self.input_text.bind("<Control-Return>", lambda e: self.send_message())
-
-        # Send button
-        button_frame = ttk.Frame(input_frame)
-        button_frame.grid(row=0, column=1, sticky=(tk.N, tk.S))
-
-        self.send_btn = ttk.Button(button_frame, text="Send", command=self.send_message)
-        self.send_btn.pack(pady=(0, 5))
-
-        self.stop_btn = ttk.Button(button_frame, text="Stop", command=self.stop_generation_func, state=tk.DISABLED)
-        self.stop_btn.pack(pady=(0, 5))
-
-        self.clear_btn = ttk.Button(button_frame, text="Clear", command=self.clear_chat)
-        self.clear_btn.pack()
+        self.device = "cpu"
 
-        # Parameters panel
-        params_frame = ttk.LabelFrame(main_frame, text="Generation Parameters", padding="5")
-        params_frame.grid(row=3, column=0, sticky=(tk.W, tk.E), pady=(0, 10))
-        params_frame.columnconfigure(1, weight=1)
-        params_frame.columnconfigure(3, weight=1)
-
-        # Max tokens
-        ttk.Label(params_frame, text="Max Tokens:").grid(row=0, column=0, sticky=tk.W, padx=(0, 5))
-        ttk.Scale(params_frame, from_=50, to=512, variable=self.max_new_tokens, orient=tk.HORIZONTAL).grid(row=0, column=1, sticky=(tk.W, tk.E), padx=(0, 10))
-        ttk.Label(params_frame, textvariable=self.max_new_tokens).grid(row=0, column=2, padx=(0, 20))
-
-        # Temperature
-        ttk.Label(params_frame, text="Temperature:").grid(row=1, column=0, sticky=tk.W, padx=(0, 5))
-        ttk.Scale(params_frame, from_=0.1, to=2.0, variable=self.temperature, orient=tk.HORIZONTAL).grid(row=1, column=1, sticky=(tk.W, tk.E), padx=(0, 10))
-        temp_label = ttk.Label(params_frame, text="")
-        temp_label.grid(row=1, column=2, padx=(0, 20))
-
-        # Top-p
-        ttk.Label(params_frame, text="Top-p:").grid(row=0, column=3, sticky=tk.W, padx=(0, 5))
-        ttk.Scale(params_frame, from_=0.1, to=1.0, variable=self.top_p, orient=tk.HORIZONTAL).grid(row=0, column=4, sticky=(tk.W, tk.E), padx=(0, 10))
-        top_p_label = ttk.Label(params_frame, text="")
-        top_p_label.grid(row=0, column=5)
-
-        # Top-k
-        ttk.Label(params_frame, text="Top-k:").grid(row=1, column=3, sticky=tk.W, padx=(0, 5))
-        ttk.Scale(params_frame, from_=1, to=100, variable=self.top_k, orient=tk.HORIZONTAL).grid(row=1, column=4, sticky=(tk.W, tk.E), padx=(0, 10))
-        ttk.Label(params_frame, textvariable=self.top_k).grid(row=1, column=5)
-
-        # Update parameter labels
-        def update_temp_label(*args):
-            temp_label.config(text=f"{self.temperature.get():.2f}")
-        def update_top_p_label(*args):
-            top_p_label.config(text=f"{self.top_p.get():.2f}")
-
-        self.temperature.trace('w', update_temp_label)
-        self.top_p.trace('w', update_top_p_label)
-        update_temp_label()
-        update_top_p_label()
-
-        # Status bar
-        self.status_var = tk.StringVar(value="Ready - Please load a model first")
-        status_bar = ttk.Label(main_frame, textvariable=self.status_var, relief=tk.SUNKEN, anchor=tk.W)
-        status_bar.grid(row=4, column=0, sticky=(tk.W, tk.E))
-
-        # Add example messages
-        examples_frame = ttk.LabelFrame(main_frame, text="Example Messages", padding="5")
-        examples_frame.grid(row=5, column=0, sticky=(tk.W, tk.E), pady=(10, 0))
-
-        examples = [
-            "Hello! How are you today?",
-            "Tell me a short joke.",
-            "What's the weather like?",
-            "Explain quantum computing in simple terms."
-        ]
-
-        for i, example in enumerate(examples):
-            btn = ttk.Button(examples_frame, text=example,
-                             command=lambda e=example: self.set_input_text(e))
-            btn.grid(row=i//2, column=i%2, sticky=(tk.W, tk.E), padx=5, pady=2)
-
-        examples_frame.columnconfigure(0, weight=1)
-        examples_frame.columnconfigure(1, weight=1)
-
-    def check_dependencies(self):
+    def load_model(self, model_name: str, progress=gr.Progress()) -> str:
+        """Load the selected model"""
         if not TRANSFORMERS_AVAILABLE:
-            self.add_system_message("❌ Transformers library not found. Please install: pip install torch transformers")
-            self.send_btn.config(state=tk.DISABLED)
-            self.load_model_btn.config(state=tk.DISABLED)
-        else:
-            self.add_system_message("✅ Dependencies loaded. Please select and load a model.")
-
-    def set_input_text(self, text):
-        self.input_text.delete("1.0", tk.END)
-        self.input_text.insert("1.0", text)
-        self.input_text.focus()
-
-    def add_system_message(self, message):
-        self.chat_display.config(state=tk.NORMAL)
-        self.chat_display.insert(tk.END, f"[{datetime.now().strftime('%H:%M:%S')}] {message}\n", "system")
-        self.chat_display.config(state=tk.DISABLED)
-        self.chat_display.see(tk.END)
-
-    def add_user_message(self, message):
-        self.chat_display.config(state=tk.NORMAL)
-        self.chat_display.insert(tk.END, f"\n👤 You: ", "user")
-        self.chat_display.insert(tk.END, f"{message}\n", "user")
-        self.chat_display.config(state=tk.DISABLED)
-        self.chat_display.see(tk.END)
-
-    def add_assistant_message(self, message):
-        self.chat_display.config(state=tk.NORMAL)
-        self.chat_display.insert(tk.END, f"🤖 Assistant: ", "assistant")
-        self.chat_display.insert(tk.END, f"{message}\n", "assistant")
-        self.chat_display.config(state=tk.DISABLED)
-        self.chat_display.see(tk.END)
-
-    def update_assistant_message(self, additional_text):
-        self.chat_display.config(state=tk.NORMAL)
-        self.chat_display.insert(tk.END, additional_text, "assistant")
-        self.chat_display.config(state=tk.DISABLED)
-        self.chat_display.see(tk.END)
-
-    def load_model(self):
-        if not TRANSFORMERS_AVAILABLE:
-            messagebox.showerror("Error", "Transformers library not available")
-            return
+            return "❌ Error: transformers library not installed. Run: pip install torch transformers"
 
-        model_name = self.model_var.get()
-        if not model_name:
-            messagebox.showwarning("Warning", "Please select a model")
-            return
+        if model_name == self.current_model_name and self.model_loaded:
+            return f"✅ Model {model_name} is already loaded!"
 
-        # Disable buttons during loading
-        self.load_model_btn.config(state=tk.DISABLED)
-        self.send_btn.config(state=tk.DISABLED)
-        self.status_var.set(f"Loading model: {model_name}...")
-
-        # Load model in separate thread
-        thread = threading.Thread(target=self._load_model_thread, args=(model_name,))
-        thread.daemon = True
-        thread.start()
-
-    def _load_model_thread(self, model_name):
         try:
-            self.add_system_message(f"Loading model: {model_name}")
-
-            # Force CPU usage and optimize for CPU
-            device = "cpu"
-            torch_dtype = torch.float32  # Use float32 for CPU
+            progress(0.1, desc="Loading tokenizer...")
 
             # Load tokenizer
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
+            self.current_tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                padding_side="left"
+            )
+            if self.current_tokenizer.pad_token is None:
+                self.current_tokenizer.pad_token = self.current_tokenizer.eos_token
+
+            progress(0.5, desc="Loading model...")
 
             # Load model with CPU optimizations
-            self.model = AutoModelForCausalLM.from_pretrained(
+            self.current_model = AutoModelForCausalLM.from_pretrained(
                 model_name,
-                torch_dtype=torch_dtype,
-                device_map={"": device},
+                torch_dtype=torch.float32,  # Use float32 for CPU
+                device_map={"": self.device},
                 low_cpu_mem_usage=True
             )
 
-            # Set model to evaluation mode
-            self.model.eval()
+            # Set to evaluation mode
+            self.current_model.eval()
 
+            self.current_model_name = model_name
             self.model_loaded = True
 
-            # Update UI on main thread
-            self.root.after(0, self._model_loaded_callback, model_name)
+            progress(1.0, desc="Model loaded successfully!")
+
+            return f"✅ Successfully loaded: {model_name}"
 
         except Exception as e:
-            error_msg = f"Failed to load model: {str(e)}"
-            self.root.after(0, self._model_load_error_callback, error_msg)
-
-    def _model_loaded_callback(self, model_name):
-        self.add_system_message(f"✅ Model loaded successfully: {model_name}")
-        self.status_var.set(f"Model loaded: {model_name}")
-        self.load_model_btn.config(state=tk.NORMAL)
-        self.send_btn.config(state=tk.NORMAL)
-
-    def _model_load_error_callback(self, error_msg):
-        self.add_system_message(f"❌ {error_msg}")
-        self.status_var.set("Model loading failed")
-        self.load_model_btn.config(state=tk.NORMAL)
-        messagebox.showerror("Model Loading Error", error_msg)
+            self.model_loaded = False
+            return f"❌ Failed to load model {model_name}: {str(e)}"
 
-    def send_message(self):
+    def generate_response(
+        self,
+        message: str,
+        chat_history: List[List[str]],
+        max_new_tokens: int = 256,
+        temperature: float = 0.7,
+        top_p: float = 0.9,
+        top_k: int = 50,
+        repetition_penalty: float = 1.1,
+    ) -> Iterator[str]:
+        """Generate response with streaming"""
+
        if not self.model_loaded:
-            messagebox.showwarning("Warning", "Please load a model first")
+            yield "❌ Please load a model first!"
            return

-        message = self.input_text.get("1.0", tk.END).strip()
-        if not message:
+        if not message.strip():
+            yield "Please enter a message."
            return

-        # Add user message to chat
-        self.add_user_message(message)
-        self.input_text.delete("1.0", tk.END)
-
-        # Disable send button and enable stop button
-        self.send_btn.config(state=tk.DISABLED)
-        self.stop_btn.config(state=tk.NORMAL)
-        self.stop_generation = False
-
-        # Add to chat history
-        self.chat_history.append({"role": "user", "content": message})
-
-        # Start generation thread
-        self.generation_thread = threading.Thread(target=self._generate_response, args=(message,))
-        self.generation_thread.daemon = True
-        self.generation_thread.start()
-
-        # Start checking for responses
-        self.check_response_queue()
-
-    def _generate_response(self, message):
        try:
-            self.status_var.set("Generating response...")
+            # Prepare conversation context
+            conversation_text = ""

-            # Prepare input
-            if "DialoGPT" in self.model_var.get():
-                # For DialoGPT, use conversation history
+            # Add chat history (last 5 exchanges to manage memory)
+            recent_history = chat_history[-5:] if len(chat_history) > 5 else chat_history
+
+            if "DialoGPT" in self.current_model_name:
+                # For DialoGPT, format as conversation
                chat_history_ids = None
-                for turn in self.chat_history[-5:]:  # Use last 5 turns
-                    new_user_input_ids = self.tokenizer.encode(
-                        turn["content"] + self.tokenizer.eos_token,
-                        return_tensors='pt'
-                    )
-
-                    if chat_history_ids is not None:
-                        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
-                    else:
-                        bot_input_ids = new_user_input_ids
+
+                # Build conversation from history
+                for user_msg, bot_msg in recent_history:
+                    if user_msg:
+                        user_input_ids = self.current_tokenizer.encode(
+                            user_msg + self.current_tokenizer.eos_token,
+                            return_tensors='pt'
+                        )
+                        if chat_history_ids is not None:
+                            chat_history_ids = torch.cat([chat_history_ids, user_input_ids], dim=-1)
+                        else:
+                            chat_history_ids = user_input_ids

-                    chat_history_ids = bot_input_ids
+                    if bot_msg:
+                        bot_input_ids = self.current_tokenizer.encode(
+                            bot_msg + self.current_tokenizer.eos_token,
+                            return_tensors='pt'
+                        )
+                        if chat_history_ids is not None:
+                            chat_history_ids = torch.cat([chat_history_ids, bot_input_ids], dim=-1)
+                        else:
+                            chat_history_ids = bot_input_ids

-                input_ids = chat_history_ids
+                # Add current message
+                new_user_input_ids = self.current_tokenizer.encode(
+                    message + self.current_tokenizer.eos_token,
+                    return_tensors='pt'
+                )
+
+                if chat_history_ids is not None:
+                    input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
+                else:
+                    input_ids = new_user_input_ids
+
            else:
-                # For other models, use simple encoding
-                input_ids = self.tokenizer.encode(message, return_tensors='pt')
+                # For other models, create context from history
+                for user_msg, bot_msg in recent_history:
+                    if user_msg and bot_msg:
+                        conversation_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"
+
+                conversation_text += f"User: {message}\nAssistant:"
+                input_ids = self.current_tokenizer.encode(conversation_text, return_tensors='pt')

            # Limit input length
            if input_ids.shape[1] > self.max_input_length:
                input_ids = input_ids[:, -self.max_input_length:]

-            # Generation parameters
+            # Set up streaming
+            streamer = TextIteratorStreamer(
+                self.current_tokenizer,
+                timeout=60.0,
+                skip_prompt=True,
+                skip_special_tokens=True
+            )
+
            generation_kwargs = {
                'input_ids': input_ids,
-                'max_new_tokens': self.max_new_tokens.get(),
-                'temperature': self.temperature.get(),
-                'top_p': self.top_p.get(),
-                'top_k': self.top_k.get(),
-                'repetition_penalty': self.repetition_penalty.get(),
+                'streamer': streamer,
+                'max_new_tokens': max_new_tokens,
+                'temperature': temperature,
+                'top_p': top_p,
+                'top_k': top_k,
+                'repetition_penalty': repetition_penalty,
                'do_sample': True,
-                'pad_token_id': self.tokenizer.pad_token_id,
-                'eos_token_id': self.tokenizer.eos_token_id,
+                'pad_token_id': self.current_tokenizer.pad_token_id,
+                'eos_token_id': self.current_tokenizer.eos_token_id,
                'no_repeat_ngram_size': 2,
            }

-            # Create streamer for real-time output
-            streamer = TextIteratorStreamer(
-                self.tokenizer,
-                skip_prompt=True,
-                skip_special_tokens=True,
-                timeout=30.0
-            )
-            generation_kwargs['streamer'] = streamer
-
-            # Start generation in a separate thread
-            generation_thread = threading.Thread(
-                target=self.model.generate,
+            # Start generation in separate thread
+            generation_thread = Thread(
+                target=self.current_model.generate,
                kwargs=generation_kwargs
            )
            generation_thread.start()

            # Stream the response
-            self.response_queue.put(("start", ""))
-
-            generated_text = ""
+            partial_response = ""
            for new_text in streamer:
-                if self.stop_generation:
-                    break
-                generated_text += new_text
-                self.response_queue.put(("update", new_text))
+                partial_response += new_text
+                yield partial_response

-            if not self.stop_generation:
-                # Add to chat history
-                self.chat_history.append({"role": "assistant", "content": generated_text})
-                self.response_queue.put(("complete", generated_text))
-            else:
-                self.response_queue.put(("stopped", ""))
-
        except Exception as e:
-            self.response_queue.put(("error", str(e)))
+            yield f"❌ Generation error: {str(e)}"
+
+def create_interface():
+    """Create the Gradio interface"""

-    def check_response_queue(self):
-        try:
-            while True:
-                action, data = self.response_queue.get_nowait()
-
-                if action == "start":
-                    self.add_assistant_message("")
-                elif action == "update":
-                    self.update_assistant_message(data)
-                elif action == "complete":
-                    self.status_var.set("Response complete")
-                    self.send_btn.config(state=tk.NORMAL)
-                    self.stop_btn.config(state=tk.DISABLED)
-                    return
-                elif action == "stopped":
-                    self.update_assistant_message(" [Generation stopped]")
-                    self.status_var.set("Generation stopped")
-                    self.send_btn.config(state=tk.NORMAL)
-                    self.stop_btn.config(state=tk.DISABLED)
-                    return
-                elif action == "error":
-                    self.add_system_message(f"❌ Generation error: {data}")
-                    self.status_var.set("Generation failed")
-                    self.send_btn.config(state=tk.NORMAL)
-                    self.stop_btn.config(state=tk.DISABLED)
-                    return
-
-        except queue.Empty:
-            pass
-
-        # Schedule next check
-        self.root.after(100, self.check_response_queue)
+    if not GRADIO_AVAILABLE:
+        print("❌ Error: gradio library not installed. Run: pip install gradio")
+        return None
+
+    if not TRANSFORMERS_AVAILABLE:
+        print("❌ Error: transformers library not installed. Run: pip install torch transformers")
+        return None

-    def stop_generation_func(self):
-        self.stop_generation = True
-        self.status_var.set("Stopping generation...")
+    # Initialize the chat system
+    chat_system = CPULLMChat()

-    def clear_chat(self):
-        self.chat_history = []
-        self.chat_display.config(state=tk.NORMAL)
-        self.chat_display.delete("1.0", tk.END)
-        self.chat_display.config(state=tk.DISABLED)
-        self.add_system_message("Chat cleared")
+    # Custom CSS for better styling
+    css = """
+    .gradio-container {
+        max-width: 1200px;
+        margin: auto;
+    }
+    .chat-message {
+        padding: 10px;
+        margin: 5px 0;
+        border-radius: 10px;
+    }
+    .user-message {
+        background-color: #e3f2fd;
+        margin-left: 20%;
+    }
+    .bot-message {
+        background-color: #f1f8e9;
+        margin-right: 20%;
+    }
+    """
+
+    with gr.Blocks(css=css, title="CPU LLM Chat") as demo:
+        gr.Markdown("# 🤖 CPU-Optimized LLM Chat")
+        gr.Markdown("*A lightweight chat interface for running language models on CPU*")
+
+        with gr.Row():
+            with gr.Column(scale=2):
+                model_dropdown = gr.Dropdown(
+                    choices=list(chat_system.models.keys()),
+                    value="microsoft/DialoGPT-medium",
+                    label="Select Model",
+                    info="Choose a model to load. DialoGPT models work best for chat."
+                )
+                load_btn = gr.Button("🔄 Load Model", variant="primary")
+                model_status = gr.Textbox(
+                    label="Model Status",
+                    value="No model loaded",
+                    interactive=False
+                )
+
+            with gr.Column(scale=1):
+                gr.Markdown("### 💡 Model Info")
+                gr.Markdown("""
+                - **DialoGPT Medium**: Best quality, slower
+                - **DialoGPT Small**: Good balance
+                - **DistilGPT2**: Fastest option
+                - **GPT2**: General purpose
+                - **BlenderBot**: Conversational AI
+                """)
+
+        # Chat interface
+        chatbot = gr.Chatbot(
+            label="Chat History",
+            height=400,
+            show_label=True,
+            container=True
+        )
+
+        with gr.Row():
+            msg = gr.Textbox(
+                label="Your Message",
+                placeholder="Type your message here... (Press Ctrl+Enter to send)",
+                lines=3,
+                max_lines=10,
+                show_label=False
+            )
+            send_btn = gr.Button("📤 Send", variant="primary")
+
+        # Parameters section
+        with gr.Accordion("⚙️ Generation Parameters", open=False):
+            with gr.Row():
+                max_tokens = gr.Slider(
+                    minimum=50,
+                    maximum=512,
+                    value=256,
+                    step=10,
+                    label="Max New Tokens",
+                    info="Maximum number of tokens to generate"
+                )
+                temperature = gr.Slider(
+                    minimum=0.1,
+                    maximum=2.0,
+                    value=0.7,
+                    step=0.1,
+                    label="Temperature",
+                    info="Higher values = more creative, lower = more focused"
+                )
+
+            with gr.Row():
+                top_p = gr.Slider(
+                    minimum=0.1,
+                    maximum=1.0,
+                    value=0.9,
+                    step=0.05,
+                    label="Top-p",
+                    info="Nucleus sampling parameter"
+                )
+                top_k = gr.Slider(
+                    minimum=1,
+                    maximum=100,
+                    value=50,
+                    step=1,
+                    label="Top-k",
+                    info="Top-k sampling parameter"
+                )
+                repetition_penalty = gr.Slider(
+                    minimum=1.0,
+                    maximum=2.0,
+                    value=1.1,
+                    step=0.05,
+                    label="Repetition Penalty",
+                    info="Penalty for repeating tokens"
+                )
+
+        # Example messages
+        with gr.Accordion("💬 Example Messages", open=False):
+            examples = [
+                "Hello! How are you today?",
+                "Tell me a short story about a robot.",
+                "What's the difference between AI and machine learning?",
+                "Can you help me write a poem about nature?",
+                "Explain quantum computing in simple terms.",
+            ]
+
+            example_buttons = []
+            for example in examples:
+                btn = gr.Button(example, variant="secondary")
+                example_buttons.append(btn)
+
+        # Clear chat button
+        clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
+
+        # Event handlers
+        def respond(message, history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
+            if not chat_system.model_loaded:
+                history.append([message, "❌ Please load a model first!"])
+                return history, ""
+
+            history.append([message, ""])
+
+            for partial_response in chat_system.generate_response(
+                message, history, max_new_tokens, temperature, top_p, top_k, repetition_penalty
+            ):
+                history[-1][1] = partial_response
+                yield history, ""
+
+        def load_model_handler(model_name, progress=gr.Progress()):
+            return chat_system.load_model(model_name, progress)
+
+        def set_example(example_text):
+            return example_text
+
+        def clear_chat():
+            return [], ""
+
+        # Wire up events
+        load_btn.click(load_model_handler, inputs=[model_dropdown], outputs=[model_status])
+
+        msg.submit(respond, inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, msg])
+        send_btn.click(respond, inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, msg])
+
+        clear_btn.click(clear_chat, outputs=[chatbot, msg])
+
+        # Example buttons
+        for btn, example in zip(example_buttons, examples):
+            btn.click(set_example, inputs=[gr.State(example)], outputs=[msg])
+
+        # Footer
+        gr.Markdown("""
+        ---
+        ### 📋 Instructions:
+        1. **Select and load a model** using the dropdown and "Load Model" button
+        2. **Wait for the model to load** (may take 1-2 minutes on first load)
+        3. **Start chatting** once you see "✅ Successfully loaded" message
+        4. **Adjust parameters** if needed for different response styles
+
+        ### 💻 System Requirements:
+        - CPU with at least 4GB RAM available
+        - Python 3.8+ with torch and transformers installed
+
+        ### ⚡ Performance Tips:
+        - Use DialoGPT-small for fastest responses
+        - Keep max tokens under 300 for better speed
+        - Lower temperature (0.3-0.7) for more consistent responses
+        """)
+
+    return demo

 def main():
-    root = tk.Tk()
-    app = CPULLMChatApp(root)
+    """Main function to run the application"""
+
+    print("===== CPU LLM Chat Application =====")
+    print("Checking dependencies...")
+
+    if not GRADIO_AVAILABLE:
+        print("❌ Gradio not found. Install with: pip install gradio")
+        return

-    # Center the window
-    root.update_idletasks()
-    x = (root.winfo_screenwidth() - root.winfo_width()) // 2
-    y = (root.winfo_screenheight() - root.winfo_height()) // 2
-    root.geometry(f"+{x}+{y}")
+    if not TRANSFORMERS_AVAILABLE:
+        print("❌ Transformers not found. Install with: pip install torch transformers")
+        return

-    root.mainloop()
+    print("✅ All dependencies found!")
+    print("Starting web interface...")
+
+    try:
+        demo = create_interface()
+        if demo:
+            # Launch with appropriate settings
+            demo.queue(max_size=10).launch(
+                server_name="0.0.0.0",  # Allow external access
+                server_port=7860,  # Default Gradio port
+                share=False,  # Set to True if you want a public link
+                show_error=True,
+                show_tips=True,
+                inbrowser=False  # Don't try to open browser in headless env
+            )
+    except KeyboardInterrupt:
+        print("\n👋 Application stopped by user")
+    except Exception as e:
+        print(f"❌ Error starting application: {e}")

 if __name__ == "__main__":
     main()
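
The core of this change swaps the old tkinter queue-polling loop for a generator driving transformers' TextIteratorStreamer. A minimal standalone sketch of that streaming idiom, assuming only torch and transformers are installed (it is not part of the commit; "distilgpt2" is taken from the app's model list purely because it is small, and the prompt mirrors the app's non-DialoGPT branch):

# Standalone sketch of the streaming pattern used by generate_response() above.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
model.eval()

# Same "User: ... Assistant:" prompt format the app builds for non-DialoGPT models.
input_ids = tokenizer.encode("User: Hello!\nAssistant:", return_tensors="pt")

# The streamer exposes generate() as an iterator of decoded text chunks.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs on a worker thread while the
# caller consumes chunks; this is what lets the Gradio handler yield a
# progressively longer string into the Chatbot component.
Thread(target=model.generate, kwargs={
    "input_ids": input_ids,
    "streamer": streamer,
    "max_new_tokens": 64,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}).start()

partial = ""
for chunk in streamer:
    partial += chunk
    print(partial)  # the app yields this growing string instead of printing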