#!/usr/bin/env python3
"""CPU-optimized LLM chat.

A lightweight Gradio web interface for running small Hugging Face language
models (DialoGPT and GPT-2 variants) entirely on CPU.

Dependencies: torch, transformers, gradio (plus accelerate, which transformers
needs for the low_cpu_mem_usage/device_map options used when loading).
"""
import warnings
from threading import Thread
from typing import Iterator, List

warnings.filterwarnings("ignore")

# Try to import required libraries
try:
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TextIteratorStreamer
    )
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False

try:
    import gradio as gr
    GRADIO_AVAILABLE = True
except ImportError:
    GRADIO_AVAILABLE = False


class CPULLMChat:
    def __init__(self):
        self.models = {
            "microsoft/DialoGPT-medium": "DialoGPT Medium (Recommended for chat)",
            "microsoft/DialoGPT-small": "DialoGPT Small (Faster)",
            "distilgpt2": "DistilGPT2 (Very fast)",
            "gpt2": "GPT2 (Standard)",
        }
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None
        self.model_loaded = False

        # Configuration
        self.max_input_length = 1024  # GPT-2-family context window
        self.device = "cpu"

    def load_model(self, model_name: str, progress=None) -> str:
        """Load the selected model.

        `progress` is an optional gr.Progress tracker; defaulting it to None
        keeps this class importable and usable even when Gradio is missing.
        """
        if not TRANSFORMERS_AVAILABLE:
            return "❌ Error: transformers library not installed. Run: pip install torch transformers"

        if model_name == self.current_model_name and self.model_loaded:
            return f"✅ Model {model_name} is already loaded!"

        try:
            if progress is not None:
                progress(0.1, desc="Loading tokenizer...")

            # Load tokenizer
            self.current_tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                padding_side="left"
            )
            if self.current_tokenizer.pad_token is None:
                self.current_tokenizer.pad_token = self.current_tokenizer.eos_token

            if progress is not None:
                progress(0.5, desc="Loading model...")

            # Load model with CPU optimizations
            self.current_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,  # Use float32 for CPU
                device_map={"": self.device},
                low_cpu_mem_usage=True
            )

            # Set to evaluation mode
            self.current_model.eval()

            self.current_model_name = model_name
            self.model_loaded = True

            if progress is not None:
                progress(1.0, desc="Model loaded successfully!")
            return f"✅ Successfully loaded: {model_name}"

        except Exception as e:
            self.model_loaded = False
            return f"❌ Failed to load model {model_name}: {str(e)}"

    def generate_response(
        self,
        message: str,
        chat_history: List[List[str]],
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
    ) -> Iterator[str]:
        """Generate a response, yielding the partial text as it streams."""
        if not self.model_loaded:
            yield "❌ Please load a model first!"
            return

        if not message.strip():
            yield "Please enter a message."
            return

        try:
            # Prepare conversation context
            conversation_text = ""

            # Add chat history (last 5 exchanges to manage memory)
            recent_history = chat_history[-5:] if len(chat_history) > 5 else chat_history

            if "DialoGPT" in self.current_model_name:
                # For DialoGPT, format as an eos-separated conversation
                chat_history_ids = None

                # Build conversation from history
                for user_msg, bot_msg in recent_history:
                    if user_msg:
                        user_input_ids = self.current_tokenizer.encode(
                            user_msg + self.current_tokenizer.eos_token,
                            return_tensors='pt'
                        )
                        if chat_history_ids is not None:
                            chat_history_ids = torch.cat([chat_history_ids, user_input_ids], dim=-1)
                        else:
                            chat_history_ids = user_input_ids

                    if bot_msg:
                        bot_input_ids = self.current_tokenizer.encode(
                            bot_msg + self.current_tokenizer.eos_token,
                            return_tensors='pt'
                        )
                        if chat_history_ids is not None:
                            chat_history_ids = torch.cat([chat_history_ids, bot_input_ids], dim=-1)
                        else:
                            chat_history_ids = bot_input_ids

                # Add current message
                new_user_input_ids = self.current_tokenizer.encode(
                    message + self.current_tokenizer.eos_token,
                    return_tensors='pt'
                )

                if chat_history_ids is not None:
                    input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
                else:
                    input_ids = new_user_input_ids
            else:
                # For other models, create a plain-text context from history
                for user_msg, bot_msg in recent_history:
                    if user_msg and bot_msg:
                        conversation_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"

                conversation_text += f"User: {message}\nAssistant:"
                input_ids = self.current_tokenizer.encode(conversation_text, return_tensors='pt')

            # Limit the prompt length, leaving room for the new tokens
            # within the model's context window
            max_prompt_length = max(self.max_input_length - max_new_tokens, 1)
            if input_ids.shape[1] > max_prompt_length:
                input_ids = input_ids[:, -max_prompt_length:]

            # Set up streaming
            streamer = TextIteratorStreamer(
                self.current_tokenizer,
                timeout=60.0,
                skip_prompt=True,
                skip_special_tokens=True
            )

            generation_kwargs = {
                'input_ids': input_ids,
                'streamer': streamer,
                'max_new_tokens': max_new_tokens,
                'temperature': temperature,
                'top_p': top_p,
                'top_k': top_k,
                'repetition_penalty': repetition_penalty,
                'do_sample': True,
                'pad_token_id': self.current_tokenizer.pad_token_id,
                'eos_token_id': self.current_tokenizer.eos_token_id,
                'no_repeat_ngram_size': 2,
            }

            # Run generation in a separate thread so we can stream from here
            generation_thread = Thread(
                target=self.current_model.generate,
                kwargs=generation_kwargs
            )
            generation_thread.start()

            # Stream the response
            partial_response = ""
            for new_text in streamer:
                partial_response += new_text
                yield partial_response

            generation_thread.join()

        except Exception as e:
            yield f"❌ Generation error: {str(e)}"
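
# A minimal sketch of using CPULLMChat programmatically, without the web UI.
# It assumes torch/transformers are installed and that downloading
# "distilgpt2" from the Hugging Face Hub is acceptable; the model name and
# prompt here are only illustrative.
#
#     chat = CPULLMChat()
#     print(chat.load_model("distilgpt2"))
#     response = ""
#     for response in chat.generate_response("Hello!", chat_history=[]):
#         pass  # each iteration yields the partial response so far
#     print(response)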

def create_interface():
    """Create the Gradio interface"""
    if not GRADIO_AVAILABLE:
        print("❌ Error: gradio library not installed. Run: pip install gradio")
        return None

    if not TRANSFORMERS_AVAILABLE:
        print("❌ Error: transformers library not installed. Run: pip install torch transformers")
        return None

    # Initialize the chat system
    chat_system = CPULLMChat()

    # Custom CSS for better styling
    css = """
    .gradio-container {
        max-width: 1200px;
        margin: auto;
    }
    .chat-message {
        padding: 10px;
        margin: 5px 0;
        border-radius: 10px;
    }
    .user-message {
        background-color: #e3f2fd;
        margin-left: 20%;
    }
    .bot-message {
        background-color: #f1f8e9;
        margin-right: 20%;
    }
    """

    with gr.Blocks(css=css, title="CPU LLM Chat") as demo:
        gr.Markdown("# 🤖 CPU-Optimized LLM Chat")
        gr.Markdown("*A lightweight chat interface for running language models on CPU*")

        with gr.Row():
            with gr.Column(scale=2):
                model_dropdown = gr.Dropdown(
                    choices=list(chat_system.models.keys()),
                    value="microsoft/DialoGPT-medium",
                    label="Select Model",
                    info="Choose a model to load. DialoGPT models work best for chat."
                )
                load_btn = gr.Button("🔄 Load Model", variant="primary")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="No model loaded",
                    interactive=False
                )

            with gr.Column(scale=1):
                gr.Markdown("### 💡 Model Info")
                gr.Markdown("""
                - **DialoGPT Medium**: Best quality, slower
                - **DialoGPT Small**: Good balance
                - **DistilGPT2**: Fastest option
                - **GPT2**: General purpose
                """)

        # Chat interface
        chatbot = gr.Chatbot(
            label="Chat History",
            height=400,
            show_label=True,
            container=True
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here... (Press Ctrl+Enter to send)",
                lines=3,
                max_lines=10,
                show_label=False
            )
            send_btn = gr.Button("📤 Send", variant="primary")

        # Parameters section
        with gr.Accordion("⚙️ Generation Parameters", open=False):
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=50, maximum=512, value=256, step=10,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate"
                )
                temperature = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                    label="Temperature",
                    info="Higher values = more creative, lower = more focused"
                )

            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top-p",
                    info="Nucleus sampling parameter"
                )
                top_k = gr.Slider(
                    minimum=1, maximum=100, value=50, step=1,
                    label="Top-k",
                    info="Top-k sampling parameter"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                    label="Repetition Penalty",
                    info="Penalty for repeating tokens"
                )

        # Example messages
        with gr.Accordion("💬 Example Messages", open=False):
            examples = [
                "Hello! How are you today?",
                "Tell me a short story about a robot.",
                "What's the difference between AI and machine learning?",
                "Can you help me write a poem about nature?",
                "Explain quantum computing in simple terms.",
            ]
            example_buttons = []
            for example in examples:
                btn = gr.Button(example, variant="secondary")
                example_buttons.append(btn)

        # Clear chat button
        clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
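
        # Note on streaming: respond() below is a generator. With the queue
        # enabled (see demo.queue() in main()), Gradio treats each yield as an
        # incremental update, so the chatbot shows the reply piece by piece as
        # generate_response() streams it.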

        # Event handlers
        def respond(message, history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
            if not chat_system.model_loaded:
                history.append([message, "❌ Please load a model first!"])
                yield history, ""
                return

            history.append([message, ""])
            # Pass the history *without* the just-appended pair so the current
            # message is not duplicated in the prompt
            for partial_response in chat_system.generate_response(
                message, history[:-1],
                max_new_tokens, temperature, top_p, top_k, repetition_penalty
            ):
                history[-1][1] = partial_response
                yield history, ""

        def load_model_handler(model_name, progress=gr.Progress()):
            return chat_system.load_model(model_name, progress)

        def set_example(example_text):
            return example_text

        def clear_chat():
            return [], ""

        # Wire up events
        load_btn.click(load_model_handler, inputs=[model_dropdown], outputs=[model_status])

        msg.submit(respond,
                   inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k, repetition_penalty],
                   outputs=[chatbot, msg])

        send_btn.click(respond,
                       inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k, repetition_penalty],
                       outputs=[chatbot, msg])

        clear_btn.click(clear_chat, outputs=[chatbot, msg])

        # Example buttons
        for btn, example in zip(example_buttons, examples):
            btn.click(set_example, inputs=[gr.State(example)], outputs=[msg])

        # Footer
        gr.Markdown("""
        ---
        ### 📋 Instructions:
        1. **Select and load a model** using the dropdown and "Load Model" button
        2. **Wait for the model to load** (may take 1-2 minutes on first load)
        3. **Start chatting** once you see the "✅ Successfully loaded" message
        4. **Adjust parameters** if needed for different response styles

        ### 💻 System Requirements:
        - CPU with at least 4GB RAM available
        - Python 3.8+ with torch and transformers installed

        ### ⚡ Performance Tips:
        - Use DialoGPT-small for fastest responses
        - Keep max tokens under 300 for better speed
        - Lower temperature (0.3-0.7) for more consistent responses
        """)

    return demo
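

# A minimal alternative entry point (a sketch, not called anywhere by default):
# launch the interface with a temporary public Gradio share link instead of
# binding to 0.0.0.0, which can be handy for quick remote demos.
def launch_shared() -> None:
    demo = create_interface()
    if demo:
        demo.queue(max_size=10).launch(share=True, show_error=True)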


def main():
    """Main function to run the application"""
    print("===== CPU LLM Chat Application =====")
    print("Checking dependencies...")

    if not GRADIO_AVAILABLE:
        print("❌ Gradio not found. Install with: pip install gradio")
        return

    if not TRANSFORMERS_AVAILABLE:
        print("❌ Transformers not found. Install with: pip install torch transformers")
        return

    print("✅ All dependencies found!")
    print("Starting web interface...")

    try:
        demo = create_interface()
        if demo:
            # Launch with appropriate settings
            demo.queue(max_size=10).launch(
                server_name="0.0.0.0",   # Allow external access
                server_port=7860,        # Default Gradio port
                share=False,             # Set to True if you want a public link
                show_error=True,
                inbrowser=False          # Don't try to open a browser in a headless env
            )
    except KeyboardInterrupt:
        print("\n👋 Application stopped by user")
    except Exception as e:
        print(f"❌ Error starting application: {e}")


if __name__ == "__main__":
    main()