#!/usr/bin/env python3
import os
import warnings
from collections.abc import Iterator
from threading import Thread
from typing import List, Dict, Optional, Tuple
import time
warnings.filterwarnings("ignore")
# Try to import required libraries
try:
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer
)
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
try:
import gradio as gr
GRADIO_AVAILABLE = True
except ImportError:
GRADIO_AVAILABLE = False
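# Optional-dependency pattern: the heavy imports above are attempted once at module
# load, and the availability flags let the UI and CLI report a friendly error instead
# of crashing outright when torch/transformers or gradio are missing.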
class CPULLMChat:
def __init__(self):
self.models = {
"microsoft/DialoGPT-medium": "DialoGPT Medium (Recommended for chat)",
"microsoft/DialoGPT-small": "DialoGPT Small (Faster)",
"distilgpt2": "DistilGPT2 (Very fast)",
"gpt2": "GPT2 (Standard)",
"facebook/blenderbot-400M-distill": "BlenderBot (Conversational)"
}
self.current_model = None
self.current_tokenizer = None
self.current_model_name = None
self.model_loaded = False
# Configuration
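        # max_input_length is measured in tokens; longer prompts are truncated from the
        # left in generate_response() so the most recent turns are kept.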
self.max_input_length = 2048
self.device = "cpu"
    def load_model(self, model_name: str, progress=None) -> str:
"""Load the selected model"""
if not TRANSFORMERS_AVAILABLE:
return "❌ Error: transformers library not installed. Run: pip install torch transformers"
if model_name == self.current_model_name and self.model_loaded:
return f"βœ… Model {model_name} is already loaded!"
try:
            if progress is not None:
                progress(0.1, desc="Loading tokenizer...")
# Load tokenizer
self.current_tokenizer = AutoTokenizer.from_pretrained(
model_name,
padding_side="left"
)
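            # GPT-2-style tokenizers ship without a pad token, so fall back to EOS to
            # keep generate() from complaining about a missing pad_token_id.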
if self.current_tokenizer.pad_token is None:
self.current_tokenizer.pad_token = self.current_tokenizer.eos_token
            if progress is not None:
                progress(0.5, desc="Loading model...")
# Load model with CPU optimizations
self.current_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32, # Use float32 for CPU
device_map={"": self.device},
low_cpu_mem_usage=True
)
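            # float32 is the safe dtype on CPU (half-precision matmuls are slow or
            # unsupported there); device_map and low_cpu_mem_usage assume the
            # `accelerate` package is installed alongside transformers.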
# Set to evaluation mode
self.current_model.eval()
self.current_model_name = model_name
self.model_loaded = True
            if progress is not None:
                progress(1.0, desc="Model loaded successfully!")
return f"βœ… Successfully loaded: {model_name}"
except Exception as e:
self.model_loaded = False
return f"❌ Failed to load model {model_name}: {str(e)}"
def generate_response(
self,
message: str,
chat_history: List[List[str]],
max_new_tokens: int = 256,
temperature: float = 0.7,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.1,
) -> Iterator[str]:
"""Generate response with streaming"""
if not self.model_loaded:
yield "❌ Please load a model first!"
return
if not message.strip():
yield "Please enter a message."
return
try:
# Prepare conversation context
conversation_text = ""
# Add chat history (last 5 exchanges to manage memory)
recent_history = chat_history[-5:] if len(chat_history) > 5 else chat_history
if "DialoGPT" in self.current_model_name:
# For DialoGPT, format as conversation
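                # DialoGPT was trained on turns joined by the EOS token, so the prompt
                # is built by concatenating each utterance followed by eos_token.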
chat_history_ids = None
# Build conversation from history
for user_msg, bot_msg in recent_history:
if user_msg:
user_input_ids = self.current_tokenizer.encode(
user_msg + self.current_tokenizer.eos_token,
return_tensors='pt'
)
if chat_history_ids is not None:
chat_history_ids = torch.cat([chat_history_ids, user_input_ids], dim=-1)
else:
chat_history_ids = user_input_ids
if bot_msg:
bot_input_ids = self.current_tokenizer.encode(
bot_msg + self.current_tokenizer.eos_token,
return_tensors='pt'
)
if chat_history_ids is not None:
chat_history_ids = torch.cat([chat_history_ids, bot_input_ids], dim=-1)
else:
chat_history_ids = bot_input_ids
# Add current message
new_user_input_ids = self.current_tokenizer.encode(
message + self.current_tokenizer.eos_token,
return_tensors='pt'
)
if chat_history_ids is not None:
input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
else:
input_ids = new_user_input_ids
else:
# For other models, create context from history
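                # Plain GPT-2-style models simply continue text, so a minimal
                # "User: ... / Assistant:" scaffold is used as the prompt.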
for user_msg, bot_msg in recent_history:
if user_msg and bot_msg:
conversation_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"
conversation_text += f"User: {message}\nAssistant:"
input_ids = self.current_tokenizer.encode(conversation_text, return_tensors='pt')
# Limit input length
if input_ids.shape[1] > self.max_input_length:
input_ids = input_ids[:, -self.max_input_length:]
# Set up streaming
streamer = TextIteratorStreamer(
self.current_tokenizer,
timeout=60.0,
skip_prompt=True,
skip_special_tokens=True
)
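            # TextIteratorStreamer yields decoded text as tokens are produced; the
            # blocking model.generate() call runs in a background thread below so this
            # generator can stream partial output to the UI.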
generation_kwargs = {
'input_ids': input_ids,
'streamer': streamer,
'max_new_tokens': max_new_tokens,
'temperature': temperature,
'top_p': top_p,
'top_k': top_k,
'repetition_penalty': repetition_penalty,
'do_sample': True,
'pad_token_id': self.current_tokenizer.pad_token_id,
'eos_token_id': self.current_tokenizer.eos_token_id,
'no_repeat_ngram_size': 2,
}
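            # do_sample=True makes temperature/top-p/top-k take effect;
            # no_repeat_ngram_size=2 is a blunt guard against verbatim loops.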
# Start generation in separate thread
generation_thread = Thread(
target=self.current_model.generate,
kwargs=generation_kwargs
)
generation_thread.start()
# Stream the response
partial_response = ""
for new_text in streamer:
partial_response += new_text
yield partial_response
except Exception as e:
yield f"❌ Generation error: {str(e)}"
def create_interface():
"""Create the Gradio interface"""
if not GRADIO_AVAILABLE:
print("❌ Error: gradio library not installed. Run: pip install gradio")
return None
if not TRANSFORMERS_AVAILABLE:
print("❌ Error: transformers library not installed. Run: pip install torch transformers")
return None
# Initialize the chat system
chat_system = CPULLMChat()
# Custom CSS for better styling
css = """
.gradio-container {
max-width: 1200px;
margin: auto;
}
.chat-message {
padding: 10px;
margin: 5px 0;
border-radius: 10px;
}
.user-message {
background-color: #e3f2fd;
margin-left: 20%;
}
.bot-message {
background-color: #f1f8e9;
margin-right: 20%;
}
"""
with gr.Blocks(css=css, title="CPU LLM Chat") as demo:
gr.Markdown("# πŸ€– CPU-Optimized LLM Chat")
gr.Markdown("*A lightweight chat interface for running language models on CPU*")
with gr.Row():
with gr.Column(scale=2):
model_dropdown = gr.Dropdown(
choices=list(chat_system.models.keys()),
value="microsoft/DialoGPT-medium",
label="Select Model",
info="Choose a model to load. DialoGPT models work best for chat."
)
                load_btn = gr.Button("🔄 Load Model", variant="primary")
model_status = gr.Textbox(
label="Model Status",
value="No model loaded",
interactive=False
)
with gr.Column(scale=1):
gr.Markdown("### πŸ’‘ Model Info")
gr.Markdown("""
- **DialoGPT Medium**: Best quality, slower
- **DialoGPT Small**: Good balance
- **DistilGPT2**: Fastest option
- **GPT2**: General purpose
- **BlenderBot**: Conversational AI
""")
# Chat interface
chatbot = gr.Chatbot(
label="Chat History",
height=400,
show_label=True,
container=True
)
with gr.Row():
msg = gr.Textbox(
label="Your Message",
placeholder="Type your message here... (Press Ctrl+Enter to send)",
lines=3,
max_lines=10,
show_label=False
)
            send_btn = gr.Button("📤 Send", variant="primary")
# Parameters section
with gr.Accordion("βš™οΈ Generation Parameters", open=False):
with gr.Row():
max_tokens = gr.Slider(
minimum=50,
maximum=512,
value=256,
step=10,
label="Max New Tokens",
info="Maximum number of tokens to generate"
)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="Temperature",
info="Higher values = more creative, lower = more focused"
)
with gr.Row():
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.05,
label="Top-p",
info="Nucleus sampling parameter"
)
top_k = gr.Slider(
minimum=1,
maximum=100,
value=50,
step=1,
label="Top-k",
info="Top-k sampling parameter"
)
repetition_penalty = gr.Slider(
minimum=1.0,
maximum=2.0,
value=1.1,
step=0.05,
label="Repetition Penalty",
info="Penalty for repeating tokens"
)
# Example messages
with gr.Accordion("πŸ’¬ Example Messages", open=False):
examples = [
"Hello! How are you today?",
"Tell me a short story about a robot.",
"What's the difference between AI and machine learning?",
"Can you help me write a poem about nature?",
"Explain quantum computing in simple terms.",
]
example_buttons = []
for example in examples:
btn = gr.Button(example, variant="secondary")
example_buttons.append(btn)
# Clear chat button
        clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
# Event handlers
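        # respond() is a generator: yielding (history, "") repeatedly lets Gradio
        # stream the partially generated reply into the chatbot while clearing the
        # input box.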
def respond(message, history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
if not chat_system.model_loaded:
                history.append([message, "❌ Please load a model first!"])
                yield history, ""
                return
history.append([message, ""])
            # Pass everything except the placeholder turn just appended, so the current
            # message is not duplicated in the prompt.
            for partial_response in chat_system.generate_response(
                message, history[:-1], max_new_tokens, temperature, top_p, top_k, repetition_penalty
            ):
history[-1][1] = partial_response
yield history, ""
def load_model_handler(model_name, progress=gr.Progress()):
return chat_system.load_model(model_name, progress)
def set_example(example_text):
return example_text
def clear_chat():
return [], ""
# Wire up events
load_btn.click(load_model_handler, inputs=[model_dropdown], outputs=[model_status])
msg.submit(respond, inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, msg])
send_btn.click(respond, inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, msg])
clear_btn.click(clear_chat, outputs=[chatbot, msg])
# Example buttons
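        # gr.State(example) binds a fixed string to each button's click event, which
        # set_example() then copies into the message box.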
for btn, example in zip(example_buttons, examples):
btn.click(set_example, inputs=[gr.State(example)], outputs=[msg])
# Footer
gr.Markdown("""
---
        ### 📋 Instructions:
1. **Select and load a model** using the dropdown and "Load Model" button
2. **Wait for the model to load** (may take 1-2 minutes on first load)
        3. **Start chatting** once you see the "✅ Successfully loaded" message
4. **Adjust parameters** if needed for different response styles
        ### 💻 System Requirements:
- CPU with at least 4GB RAM available
- Python 3.8+ with torch and transformers installed
        ### ⚡ Performance Tips:
- Use DialoGPT-small for fastest responses
- Keep max tokens under 300 for better speed
- Lower temperature (0.3-0.7) for more consistent responses
""")
return demo
def main():
"""Main function to run the application"""
print("===== CPU LLM Chat Application =====")
print("Checking dependencies...")
if not GRADIO_AVAILABLE:
print("❌ Gradio not found. Install with: pip install gradio")
return
if not TRANSFORMERS_AVAILABLE:
print("❌ Transformers not found. Install with: pip install torch transformers")
return
print("βœ… All dependencies found!")
print("Starting web interface...")
try:
demo = create_interface()
if demo:
# Launch with appropriate settings
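            # Streaming (generator) handlers need the request queue enabled;
            # max_size caps how many requests may wait at once.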
demo.queue(max_size=10).launch(
server_name="0.0.0.0", # Allow external access
server_port=7860, # Default Gradio port
share=False, # Set to True if you want a public link
show_error=True,
show_tips=True,
inbrowser=False # Don't try to open browser in headless env
)
except KeyboardInterrupt:
print("\nπŸ‘‹ Application stopped by user")
except Exception as e:
print(f"❌ Error starting application: {e}")
if __name__ == "__main__":
main()