|
|
|
|
|
import os
import queue
import threading
import tkinter as tk
import warnings
from datetime import datetime
from tkinter import messagebox, scrolledtext, ttk
from typing import Dict, Generator, List

warnings.filterwarnings("ignore")

# torch/transformers are optional at import time: the UI still starts without
# them and reports the missing dependency instead of crashing.
try:
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        StoppingCriteria,
        StoppingCriteriaList,
        TextIteratorStreamer,
        pipeline,
    )

    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
|
|
|
class CPULLMChatApp:
    """Tkinter chat window that runs a Hugging Face causal LM on the CPU.

    Model loading and text generation run on worker threads.  Workers hand
    results back to the Tk main thread through ``root.after`` (model loading)
    and ``response_queue`` (generation), which ``check_response_queue`` polls.
    """

    ASSISTANT_PREFIX = "🤖 Assistant: "
    USER_PREFIX = "\n👤 You: "

    def __init__(self, root):
        """Create the application state, build the widgets and report dependency status.

        Args:
            root: The ``tk.Tk`` (or toplevel) window that hosts the UI.
        """
        self.root = root
        self.root.title("CPU LLM Chat Application")
        self.root.geometry("1000x700")

        # Conversation turns as {"role": ..., "content": ...} dicts.
        self.chat_history: List[Dict[str, str]] = []

        # Model state; tokenizer/model are replaced together after a successful load.
        self.model = None
        self.tokenizer = None
        self.generator = None
        self.model_loaded = False
        self.loaded_model_name = ""

        # Generation state shared between the UI and the worker thread.
        self.generation_thread = None
        self.stop_generation = False
        self.response_queue = queue.Queue()
        self._reply_open = False  # True while a streamed reply line is unterminated

        # Generation settings (the Tk variables are bound to the sliders).
        self.max_input_length = 2048
        self.max_new_tokens = tk.IntVar(value=256)
        self.temperature = tk.DoubleVar(value=0.7)
        self.top_p = tk.DoubleVar(value=0.9)
        self.top_k = tk.IntVar(value=50)
        self.repetition_penalty = tk.DoubleVar(value=1.1)

        self.setup_ui()
        self.check_dependencies()

    # ------------------------------------------------------------------ UI

    def setup_ui(self):
        """Build the whole window: header, transcript, input, sliders, status bar, examples."""
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # Only the transcript row (row 1) absorbs extra vertical space.
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(0, weight=1)
        main_frame.rowconfigure(1, weight=1)

        self._build_title_bar(main_frame)
        self._build_chat_area(main_frame)
        self._build_input_area(main_frame)
        self._build_params_panel(main_frame)
        self._build_status_bar(main_frame)
        self._build_examples(main_frame)

    def _build_title_bar(self, parent):
        """Create the title label, the model selector and the Load Model button."""
        title_frame = ttk.Frame(parent)
        title_frame.grid(row=0, column=0, sticky=(tk.W, tk.E), pady=(0, 10))
        title_frame.columnconfigure(1, weight=1)

        ttk.Label(title_frame, text="CPU LLM Chat", font=("Arial", 16, "bold")).grid(
            row=0, column=0, sticky=tk.W
        )

        ttk.Label(title_frame, text="Model:").grid(row=0, column=2, padx=(20, 5))
        self.model_var = tk.StringVar(value="microsoft/DialoGPT-medium")
        model_combo = ttk.Combobox(title_frame, textvariable=self.model_var, width=30)
        model_combo["values"] = [
            "microsoft/DialoGPT-medium",
            "microsoft/DialoGPT-small",
            "distilgpt2",
            "gpt2",
            "facebook/blenderbot-400M-distill",
        ]
        model_combo.grid(row=0, column=3, padx=(0, 10))

        self.load_model_btn = ttk.Button(title_frame, text="Load Model", command=self.load_model)
        self.load_model_btn.grid(row=0, column=4)

    def _build_chat_area(self, parent):
        """Create the read-only transcript widget and its text tags."""
        chat_frame = ttk.Frame(parent)
        chat_frame.grid(row=1, column=0, sticky=(tk.W, tk.E, tk.N, tk.S), pady=(0, 10))
        chat_frame.columnconfigure(0, weight=1)
        chat_frame.rowconfigure(0, weight=1)

        self.chat_display = scrolledtext.ScrolledText(
            chat_frame,
            wrap=tk.WORD,
            state=tk.DISABLED,
            font=("Arial", 10),
        )
        self.chat_display.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        self.chat_display.tag_configure("user", foreground="blue", font=("Arial", 10, "bold"))
        self.chat_display.tag_configure("assistant", foreground="green", font=("Arial", 10))
        self.chat_display.tag_configure("system", foreground="gray", font=("Arial", 9, "italic"))

    def _build_input_area(self, parent):
        """Create the message entry box and the Send / Stop / Clear buttons."""
        input_frame = ttk.Frame(parent)
        input_frame.grid(row=2, column=0, sticky=(tk.W, tk.E), pady=(0, 10))
        input_frame.columnconfigure(0, weight=1)

        self.input_text = scrolledtext.ScrolledText(input_frame, height=3, wrap=tk.WORD)
        self.input_text.grid(row=0, column=0, sticky=(tk.W, tk.E), padx=(0, 10))
        self.input_text.bind("<Control-Return>", self._on_ctrl_return)

        button_frame = ttk.Frame(input_frame)
        button_frame.grid(row=0, column=1, sticky=(tk.N, tk.S))

        self.send_btn = ttk.Button(button_frame, text="Send", command=self.send_message)
        self.send_btn.pack(pady=(0, 5))

        self.stop_btn = ttk.Button(
            button_frame, text="Stop", command=self.stop_generation_func, state=tk.DISABLED
        )
        self.stop_btn.pack(pady=(0, 5))

        self.clear_btn = ttk.Button(button_frame, text="Clear", command=self.clear_chat)
        self.clear_btn.pack()

    def _build_params_panel(self, parent):
        """Create the four generation-parameter sliders with live value read-outs."""
        params_frame = ttk.LabelFrame(parent, text="Generation Parameters", padding="5")
        params_frame.grid(row=3, column=0, sticky=(tk.W, tk.E), pady=(0, 10))
        # Columns 1 and 4 hold the scales, so those are the ones that stretch.
        params_frame.columnconfigure(1, weight=1)
        params_frame.columnconfigure(4, weight=1)

        self._add_slider(params_frame, "Max Tokens:", self.max_new_tokens, (50, 512),
                         row=0, column=0, fmt="{}", value_padx=(0, 20))
        self._add_slider(params_frame, "Temperature:", self.temperature, (0.1, 2.0),
                         row=1, column=0, fmt="{:.2f}", value_padx=(0, 20))
        self._add_slider(params_frame, "Top-p:", self.top_p, (0.1, 1.0),
                         row=0, column=3, fmt="{:.2f}")
        self._add_slider(params_frame, "Top-k:", self.top_k, (1, 100),
                         row=1, column=3, fmt="{}")

    def _add_slider(self, parent, caption, variable, bounds, row, column, fmt, value_padx=0):
        """Grid a caption, a scale bound to *variable*, and a formatted value label.

        The value label is refreshed from ``variable.get()`` rather than bound via
        ``textvariable``: the scale writes floats, and ``get()`` on an ``IntVar``
        yields an int, so integer settings are displayed as integers.
        """
        low, high = bounds
        ttk.Label(parent, text=caption).grid(row=row, column=column, sticky=tk.W, padx=(0, 5))
        ttk.Scale(parent, from_=low, to=high, variable=variable, orient=tk.HORIZONTAL).grid(
            row=row, column=column + 1, sticky=(tk.W, tk.E), padx=(0, 10)
        )
        readout = ttk.Label(parent, text="")
        readout.grid(row=row, column=column + 2, padx=value_padx)

        def refresh(*_args):
            readout.config(text=fmt.format(variable.get()))

        variable.trace_add("write", refresh)
        refresh()

    def _build_status_bar(self, parent):
        """Create the status line at the bottom of the main frame."""
        self.status_var = tk.StringVar(value="Ready - Please load a model first")
        status_bar = ttk.Label(parent, textvariable=self.status_var, relief=tk.SUNKEN, anchor=tk.W)
        status_bar.grid(row=4, column=0, sticky=(tk.W, tk.E))

    def _build_examples(self, parent):
        """Create the grid of example-message buttons that prefill the input box."""
        examples_frame = ttk.LabelFrame(parent, text="Example Messages", padding="5")
        examples_frame.grid(row=5, column=0, sticky=(tk.W, tk.E), pady=(10, 0))

        examples = [
            "Hello! How are you today?",
            "Tell me a short joke.",
            "What's the weather like?",
            "Explain quantum computing in simple terms.",
        ]
        for index, example in enumerate(examples):
            # Bind the text as a default argument so each button keeps its own example.
            button = ttk.Button(
                examples_frame,
                text=example,
                command=lambda e=example: self.set_input_text(e),
            )
            button.grid(row=index // 2, column=index % 2, sticky=(tk.W, tk.E), padx=5, pady=2)

        examples_frame.columnconfigure(0, weight=1)
        examples_frame.columnconfigure(1, weight=1)

    def _on_ctrl_return(self, event=None):
        """Send the message on Ctrl+Return and suppress further handling of the key.

        Returning ``"break"`` stops the event from reaching later bindings, so no
        newline is inserted into the freshly cleared input box.
        """
        self.send_message()
        return "break"

    def check_dependencies(self):
        """Tell the user whether torch/transformers were importable; disable the UI if not."""
        if not TRANSFORMERS_AVAILABLE:
            self.add_system_message(
                "❌ Transformers library not found. Please install: pip install torch transformers"
            )
            self.send_btn.config(state=tk.DISABLED)
            self.load_model_btn.config(state=tk.DISABLED)
        else:
            self.add_system_message("✅ Dependencies loaded. Please select and load a model.")

    def set_input_text(self, text):
        """Replace the contents of the input box with *text* and focus it."""
        self.input_text.delete("1.0", tk.END)
        self.input_text.insert("1.0", text)
        self.input_text.focus()

    # ---------------------------------------------------------- transcript

    def _append_to_chat(self, text, tag):
        """Append *text* with *tag* to the read-only transcript and scroll to the end."""
        self.chat_display.config(state=tk.NORMAL)
        self.chat_display.insert(tk.END, text, tag)
        self.chat_display.config(state=tk.DISABLED)
        self.chat_display.see(tk.END)

    def add_system_message(self, message):
        """Append a timestamped system line to the transcript."""
        stamp = datetime.now().strftime('%H:%M:%S')
        self._append_to_chat(f"[{stamp}] {message}\n", "system")

    def add_user_message(self, message):
        """Append a complete user turn to the transcript."""
        self._append_to_chat(f"{self.USER_PREFIX}{message}\n", "user")

    def add_assistant_message(self, message):
        """Append a complete, newline-terminated assistant turn to the transcript."""
        self._append_to_chat(f"{self.ASSISTANT_PREFIX}{message}\n", "assistant")

    def update_assistant_message(self, additional_text):
        """Append streamed text to the assistant turn currently being written."""
        self._append_to_chat(additional_text, "assistant")

    def _begin_assistant_message(self):
        """Write the assistant label and leave the line open for streamed text."""
        self._append_to_chat(self.ASSISTANT_PREFIX, "assistant")
        self._reply_open = True

    # -------------------------------------------------------- model loading

    def load_model(self):
        """Validate the selection and start loading the chosen model on a worker thread."""
        if not TRANSFORMERS_AVAILABLE:
            messagebox.showerror("Error", "Transformers library not available")
            return

        model_name = self.model_var.get()
        if not model_name:
            messagebox.showwarning("Warning", "Please select a model")
            return

        self.load_model_btn.config(state=tk.DISABLED)
        self.send_btn.config(state=tk.DISABLED)
        self.status_var.set(f"Loading model: {model_name}...")
        # Written here, on the main thread, rather than from the worker.
        self.add_system_message(f"Loading model: {model_name}")

        thread = threading.Thread(target=self._load_model_thread, args=(model_name,))
        thread.daemon = True
        thread.start()

    def _load_model_thread(self, model_name):
        """Load tokenizer and model for *model_name*; runs on a worker thread.

        The new tokenizer/model pair is published only after both loaded, so a
        failed load leaves any previously loaded model usable.  The outcome is
        reported to the main thread via ``root.after``.
        """
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
            if tokenizer.pad_token is None:
                # Models without a pad token reuse EOS so generate() has a pad id.
                tokenizer.pad_token = tokenizer.eos_token

            # NOTE(review): device_map / low_cpu_mem_usage presumably need the
            # `accelerate` package installed - confirm for the target environment.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                device_map={"": "cpu"},
                low_cpu_mem_usage=True,
            )
            model.eval()
        except Exception as e:
            error_msg = f"Failed to load model: {str(e)}"
            self.root.after(0, self._model_load_error_callback, error_msg)
            return

        self.tokenizer = tokenizer
        self.model = model
        self.loaded_model_name = model_name
        self.model_loaded = True
        self.root.after(0, self._model_loaded_callback, model_name)

    def _model_loaded_callback(self, model_name):
        """Main-thread handler for a successful load: announce it and re-enable the buttons."""
        self.add_system_message(f"✅ Model loaded successfully: {model_name}")
        self.status_var.set(f"Model loaded: {model_name}")
        self.load_model_btn.config(state=tk.NORMAL)
        self.send_btn.config(state=tk.NORMAL)

    def _model_load_error_callback(self, error_msg):
        """Main-thread handler for a failed load: report it and restore the buttons."""
        self.add_system_message(f"❌ {error_msg}")
        self.status_var.set("Model loading failed")
        self.load_model_btn.config(state=tk.NORMAL)
        if self.model_loaded:
            # A previously loaded model is still in place, so chatting can continue.
            self.send_btn.config(state=tk.NORMAL)
        messagebox.showerror("Model Loading Error", error_msg)

    # ----------------------------------------------------------- generation

    def _snapshot_generation_params(self):
        """Read the slider variables into a plain dict of ints/floats."""
        return {
            "max_new_tokens": int(self.max_new_tokens.get()),
            "temperature": float(self.temperature.get()),
            "top_p": float(self.top_p.get()),
            "top_k": int(self.top_k.get()),
            "repetition_penalty": float(self.repetition_penalty.get()),
        }

    @staticmethod
    def _prompt_token_budget(max_input_length, context_window, max_new_tokens):
        """Return how many prompt tokens may be kept.

        Args:
            max_input_length: Application-level cap on prompt tokens.
            context_window: The model's position limit, or ``None`` if unknown.
            max_new_tokens: Tokens reserved for the reply.

        Returns:
            ``max_input_length`` when the window is unknown; otherwise the smaller
            of that cap and ``context_window - max_new_tokens``, but at least 1.
        """
        if isinstance(context_window, int) and context_window > 0:
            return max(1, min(max_input_length, context_window - max_new_tokens))
        return max_input_length

    def _encode_prompt(self, message):
        """Tokenize the prompt: recent history for DialoGPT models, else just *message*."""
        model_name = self.loaded_model_name or getattr(self.tokenizer, "name_or_path", "")
        recent_turns = self.chat_history[-5:]
        if "DialoGPT" in model_name and recent_turns:
            eos = self.tokenizer.eos_token
            encoded_turns = [
                self.tokenizer.encode(turn["content"] + eos, return_tensors='pt')
                for turn in recent_turns
            ]
            return torch.cat(encoded_turns, dim=-1)
        return self.tokenizer.encode(message, return_tensors='pt')

    def send_message(self):
        """Take the text from the input box and start generating a reply.

        Does nothing when the box is empty or while a previous generation is
        still running (the keyboard shortcut can fire while Send is disabled).
        """
        if not self.model_loaded:
            messagebox.showwarning("Warning", "Please load a model first")
            return

        if self.generation_thread is not None and self.generation_thread.is_alive():
            return

        message = self.input_text.get("1.0", tk.END).strip()
        if not message:
            return

        self.add_user_message(message)
        self.input_text.delete("1.0", tk.END)

        self.send_btn.config(state=tk.DISABLED)
        self.stop_btn.config(state=tk.NORMAL)
        self.stop_generation = False
        self.status_var.set("Generating response...")

        self.chat_history.append({"role": "user", "content": message})

        # Read the Tk variables here so the worker never has to touch them.
        params = self._snapshot_generation_params()
        self.generation_thread = threading.Thread(
            target=self._generate_response, args=(message, params)
        )
        self.generation_thread.daemon = True
        self.generation_thread.start()

        self.check_response_queue()

    def _generate_response(self, message, params=None):
        """Generate a reply on a worker thread and stream it into ``response_queue``.

        Queue items are ``(action, data)`` tuples: ``("start", "")``, then zero or
        more ``("update", text)``, then exactly one of ``("complete", full_text)``,
        ``("stopped", "")`` or ``("error", description)``.

        Args:
            message: The user's latest message.
            params: Dict as returned by ``_snapshot_generation_params``.  When
                omitted the sliders are read directly.
        """
        try:
            if params is None:
                params = self._snapshot_generation_params()

            input_ids = self._encode_prompt(message)
            budget = self._prompt_token_budget(
                self.max_input_length,
                getattr(self.model.config, "max_position_embeddings", None),
                params["max_new_tokens"],
            )
            if input_ids.shape[1] > budget:
                # Keep the most recent tokens.
                input_ids = input_ids[:, -budget:]

            app = self
            timed_out = threading.Event()

            class _StopOnRequest(StoppingCriteria):
                """Ends generate() once the user pressed Stop or the reader gave up."""

                def __call__(self, input_ids, scores, **kwargs):
                    halt = bool(app.stop_generation) or timed_out.is_set()
                    return torch.full(
                        (input_ids.shape[0],), halt, dtype=torch.bool, device=input_ids.device
                    )

            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
                timeout=30.0,
            )
            generation_kwargs = {
                'input_ids': input_ids,
                'attention_mask': torch.ones_like(input_ids),
                'max_new_tokens': params["max_new_tokens"],
                'temperature': params["temperature"],
                'top_p': params["top_p"],
                'top_k': params["top_k"],
                'repetition_penalty': params["repetition_penalty"],
                'do_sample': True,
                'pad_token_id': self.tokenizer.pad_token_id,
                'eos_token_id': self.tokenizer.eos_token_id,
                'no_repeat_ngram_size': 2,
                'streamer': streamer,
                'stopping_criteria': StoppingCriteriaList([_StopOnRequest()]),
            }

            failures = []

            def run_generate():
                try:
                    self.model.generate(**generation_kwargs)
                except Exception as exc:
                    failures.append(exc)
                    # Unblock the reader loop below, which is waiting on the streamer.
                    streamer.end()

            worker = threading.Thread(target=run_generate, daemon=True)
            worker.start()

            self.response_queue.put(("start", ""))

            pieces = []
            try:
                for new_text in streamer:
                    if self.stop_generation:
                        break
                    pieces.append(new_text)
                    self.response_queue.put(("update", new_text))
            except queue.Empty:
                # The streamer raises queue.Empty when no text arrives within its timeout.
                timed_out.set()
                raise TimeoutError("Timed out waiting for the model to produce text")

            # The stopping criterion ends generate() shortly after a stop request.
            worker.join()
            if failures:
                raise failures[0]

            if not self.stop_generation:
                self.response_queue.put(("complete", "".join(pieces)))
            else:
                self.response_queue.put(("stopped", ""))

        except Exception as e:
            self.response_queue.put(("error", str(e) or repr(e)))

    def _finish_generation(self, status):
        """Terminate the open reply line, set *status* and restore the Send/Stop buttons."""
        if self._reply_open:
            self._append_to_chat("\n", "assistant")
            self._reply_open = False
        self.status_var.set(status)
        self.send_btn.config(state=tk.NORMAL)
        self.stop_btn.config(state=tk.DISABLED)

    def check_response_queue(self):
        """Drain ``response_queue`` on the main thread and update the UI.

        Re-schedules itself every 100 ms until a terminal action ("complete",
        "stopped" or "error") has been handled.
        """
        try:
            while True:
                action, data = self.response_queue.get_nowait()

                if action == "start":
                    self._begin_assistant_message()
                elif action == "update":
                    self.update_assistant_message(data)
                elif action == "complete":
                    self.chat_history.append({"role": "assistant", "content": data})
                    self._finish_generation("Response complete")
                    return
                elif action == "stopped":
                    self.update_assistant_message(" [Generation stopped]")
                    self._finish_generation("Generation stopped")
                    return
                elif action == "error":
                    self._finish_generation("Generation failed")
                    self.add_system_message(f"❌ Generation error: {data}")
                    return
        except queue.Empty:
            pass

        self.root.after(100, self.check_response_queue)

    def stop_generation_func(self):
        """Ask the running generation to stop."""
        self.stop_generation = True
        self.status_var.set("Stopping generation...")

    def clear_chat(self):
        """Forget the conversation history and empty the transcript."""
        self.chat_history = []
        self.chat_display.config(state=tk.NORMAL)
        self.chat_display.delete("1.0", tk.END)
        self.chat_display.config(state=tk.DISABLED)
        self.add_system_message("Chat cleared")
|
|
|
def main():
    """Open the chat window centred on the screen and run the Tk event loop."""
    window = tk.Tk()
    chat_app = CPULLMChatApp(window)

    # Flush pending idle tasks before querying the window size.
    window.update_idletasks()
    spare_width = window.winfo_screenwidth() - window.winfo_width()
    spare_height = window.winfo_screenheight() - window.winfo_height()
    offset_x = spare_width // 2
    offset_y = spare_height // 2
    window.geometry(f"+{offset_x}+{offset_y}")

    window.mainloop()


if __name__ == "__main__":
    main()