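"""CPU-optimized LLM chat.

A Gradio web UI for chatting with small Hugging Face causal language models
(DialoGPT, GPT-2, DistilGPT2, BlenderBot) entirely on CPU.

Dependencies: pip install torch transformers gradio
"""
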
import warnings
from collections.abc import Iterator
from threading import Thread
from typing import List, Optional

warnings.filterwarnings("ignore")
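
# Optional heavy dependencies: record availability so a missing package is
# reported with a clear message instead of crashing at import time.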
try:
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TextIteratorStreamer,
    )
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False

try:
    import gradio as gr
    GRADIO_AVAILABLE = True
except ImportError:
    GRADIO_AVAILABLE = False


class CPULLMChat:
    """Loads a small causal language model on CPU and streams chat replies."""

    def __init__(self):
        # Models small enough to run acceptably on CPU, keyed by Hugging Face model id.
        self.models = {
            "microsoft/DialoGPT-medium": "DialoGPT Medium (Recommended for chat)",
            "microsoft/DialoGPT-small": "DialoGPT Small (Faster)",
            "distilgpt2": "DistilGPT2 (Very fast)",
            "gpt2": "GPT2 (Standard)",
            "facebook/blenderbot-400M-distill": "BlenderBot (Conversational)"
        }

        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None
        self.model_loaded = False

        # Cap on prompt length in tokens; older history is truncated from the left.
        self.max_input_length = 2048
        self.device = "cpu"

    def load_model(self, model_name: str, progress: Optional["gr.Progress"] = None) -> str:
        """Load the selected model and tokenizer onto the CPU."""
        if not TRANSFORMERS_AVAILABLE:
            return "❌ Error: transformers library not installed. Run: pip install torch transformers"

        if model_name == self.current_model_name and self.model_loaded:
            return f"✅ Model {model_name} is already loaded!"

        try:
            if progress is not None:
                progress(0.1, desc="Loading tokenizer...")

            self.current_tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                padding_side="left"
            )
            if self.current_tokenizer.pad_token is None:
                # GPT-2-family models ship without a pad token; reuse EOS.
                self.current_tokenizer.pad_token = self.current_tokenizer.eos_token

            if progress is not None:
                progress(0.5, desc="Loading model...")

            self.current_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                device_map={"": self.device},
                low_cpu_mem_usage=True
            )

            self.current_model.eval()

            self.current_model_name = model_name
            self.model_loaded = True

            if progress is not None:
                progress(1.0, desc="Model loaded successfully!")

            return f"✅ Successfully loaded: {model_name}"

        except Exception as e:
            self.model_loaded = False
            return f"❌ Failed to load model {model_name}: {str(e)}"

    def generate_response(
        self,
        message: str,
        chat_history: List[List[str]],
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
    ) -> Iterator[str]:
        """Generate a response, yielding the accumulated text as it streams."""
        if not self.model_loaded:
            yield "❌ Please load a model first!"
            return

        if not message.strip():
            yield "Please enter a message."
            return

        try:
            conversation_text = ""

            # Keep only the most recent turns so the prompt stays small on CPU.
            recent_history = chat_history[-5:] if len(chat_history) > 5 else chat_history

            if "DialoGPT" in self.current_model_name:
                # DialoGPT expects alternating turns separated by EOS tokens.
                chat_history_ids = None

                for user_msg, bot_msg in recent_history:
                    if user_msg:
                        user_input_ids = self.current_tokenizer.encode(
                            user_msg + self.current_tokenizer.eos_token,
                            return_tensors='pt'
                        )
                        if chat_history_ids is not None:
                            chat_history_ids = torch.cat([chat_history_ids, user_input_ids], dim=-1)
                        else:
                            chat_history_ids = user_input_ids

                    if bot_msg:
                        bot_input_ids = self.current_tokenizer.encode(
                            bot_msg + self.current_tokenizer.eos_token,
                            return_tensors='pt'
                        )
                        if chat_history_ids is not None:
                            chat_history_ids = torch.cat([chat_history_ids, bot_input_ids], dim=-1)
                        else:
                            chat_history_ids = bot_input_ids

                new_user_input_ids = self.current_tokenizer.encode(
                    message + self.current_tokenizer.eos_token,
                    return_tensors='pt'
                )

                if chat_history_ids is not None:
                    input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
                else:
                    input_ids = new_user_input_ids

            else:
                # Other causal LMs get a plain "User:/Assistant:" text prompt.
                for user_msg, bot_msg in recent_history:
                    if user_msg and bot_msg:
                        conversation_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"

                conversation_text += f"User: {message}\nAssistant:"
                input_ids = self.current_tokenizer.encode(conversation_text, return_tensors='pt')

            # Truncate from the left so the newest context is kept.
            if input_ids.shape[1] > self.max_input_length:
                input_ids = input_ids[:, -self.max_input_length:]

            streamer = TextIteratorStreamer(
                self.current_tokenizer,
                timeout=60.0,
                skip_prompt=True,
                skip_special_tokens=True
            )

            generation_kwargs = {
                'input_ids': input_ids,
                'streamer': streamer,
                'max_new_tokens': max_new_tokens,
                'temperature': temperature,
                'top_p': top_p,
                'top_k': top_k,
                'repetition_penalty': repetition_penalty,
                'do_sample': True,
                'pad_token_id': self.current_tokenizer.pad_token_id,
                'eos_token_id': self.current_tokenizer.eos_token_id,
                'no_repeat_ngram_size': 2,
            }

            # Run generation in a background thread; the streamer feeds tokens back here.
            generation_thread = Thread(
                target=self.current_model.generate,
                kwargs=generation_kwargs
            )
            generation_thread.start()

            partial_response = ""
            for new_text in streamer:
                partial_response += new_text
                yield partial_response

        except Exception as e:
            yield f"❌ Generation error: {str(e)}"


def create_interface():
    """Create the Gradio interface."""
    if not GRADIO_AVAILABLE:
        print("❌ Error: gradio library not installed. Run: pip install gradio")
        return None

    if not TRANSFORMERS_AVAILABLE:
        print("❌ Error: transformers library not installed. Run: pip install torch transformers")
        return None

    chat_system = CPULLMChat()

    # Minimal styling for the chat layout.
    css = """
    .gradio-container {
        max-width: 1200px;
        margin: auto;
    }
    .chat-message {
        padding: 10px;
        margin: 5px 0;
        border-radius: 10px;
    }
    .user-message {
        background-color: #e3f2fd;
        margin-left: 20%;
    }
    .bot-message {
        background-color: #f1f8e9;
        margin-right: 20%;
    }
    """

    with gr.Blocks(css=css, title="CPU LLM Chat") as demo:
        gr.Markdown("# 🤖 CPU-Optimized LLM Chat")
        gr.Markdown("*A lightweight chat interface for running language models on CPU*")

        # Model selection and status.
        with gr.Row():
            with gr.Column(scale=2):
                model_dropdown = gr.Dropdown(
                    choices=list(chat_system.models.keys()),
                    value="microsoft/DialoGPT-medium",
                    label="Select Model",
                    info="Choose a model to load. DialoGPT models work best for chat."
                )
                load_btn = gr.Button("🚀 Load Model", variant="primary")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="No model loaded",
                    interactive=False
                )

            with gr.Column(scale=1):
                gr.Markdown("### 💡 Model Info")
                gr.Markdown("""
                - **DialoGPT Medium**: Best quality, slower
                - **DialoGPT Small**: Good balance
                - **DistilGPT2**: Fastest option
                - **GPT2**: General purpose
                - **BlenderBot**: Conversational AI
                """)

        # Chat area.
        chatbot = gr.Chatbot(
            label="Chat History",
            height=400,
            show_label=True,
            container=True
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here... (Press Ctrl+Enter to send)",
                lines=3,
                max_lines=10,
                show_label=False
            )
            send_btn = gr.Button("📤 Send", variant="primary")

        with gr.Accordion("⚙️ Generation Parameters", open=False):
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=512,
                    value=256,
                    step=10,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher values = more creative, lower = more focused"
                )

            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-p",
                    info="Nucleus sampling parameter"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-k",
                    info="Top-k sampling parameter"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.05,
                    label="Repetition Penalty",
                    info="Penalty for repeating tokens"
                )

        with gr.Accordion("💬 Example Messages", open=False):
            examples = [
                "Hello! How are you today?",
                "Tell me a short story about a robot.",
                "What's the difference between AI and machine learning?",
                "Can you help me write a poem about nature?",
                "Explain quantum computing in simple terms.",
            ]

            example_buttons = []
            for example in examples:
                btn = gr.Button(example, variant="secondary")
                example_buttons.append(btn)

        clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

        # Event handlers.
        def respond(message, history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
            """Stream the model's reply into the chat history."""
            if not chat_system.model_loaded:
                history.append([message, "❌ Please load a model first!"])
                yield history, ""
                return

            history.append([message, ""])

            for partial_response in chat_system.generate_response(
                message, history, max_new_tokens, temperature, top_p, top_k, repetition_penalty
            ):
                history[-1][1] = partial_response
                yield history, ""

        def load_model_handler(model_name, progress=gr.Progress()):
            return chat_system.load_model(model_name, progress)

        def set_example(example_text):
            return example_text

        def clear_chat():
            return [], ""

        load_btn.click(load_model_handler, inputs=[model_dropdown], outputs=[model_status])

        msg.submit(
            respond,
            inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k, repetition_penalty],
            outputs=[chatbot, msg],
        )
        send_btn.click(
            respond,
            inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k, repetition_penalty],
            outputs=[chatbot, msg],
        )

        clear_btn.click(clear_chat, outputs=[chatbot, msg])

        # Each example button fills the message box with its own text.
        for btn, example in zip(example_buttons, examples):
            btn.click(set_example, inputs=[gr.State(example)], outputs=[msg])

        gr.Markdown("""
        ---
        ### 📝 Instructions:
        1. **Select and load a model** using the dropdown and "Load Model" button
        2. **Wait for the model to load** (may take 1-2 minutes on first load)
        3. **Start chatting** once you see the "✅ Successfully loaded" message
        4. **Adjust parameters** if needed for different response styles

        ### 💻 System Requirements:
        - CPU with at least 4GB RAM available
        - Python 3.8+ with torch and transformers installed

        ### ⚡ Performance Tips:
        - Use DialoGPT-small for fastest responses
        - Keep max tokens under 300 for better speed
        - Lower temperature (0.3-0.7) for more consistent responses
        """)

    return demo


def main():
    """Main function to run the application."""
    print("===== CPU LLM Chat Application =====")
    print("Checking dependencies...")

    if not GRADIO_AVAILABLE:
        print("❌ Gradio not found. Install with: pip install gradio")
        return

    if not TRANSFORMERS_AVAILABLE:
        print("❌ Transformers not found. Install with: pip install torch transformers")
        return

    print("✅ All dependencies found!")
    print("Starting web interface...")

    try:
        demo = create_interface()
        if demo:
            demo.queue(max_size=10).launch(
                server_name="0.0.0.0",
                server_port=7860,
                share=False,
                show_error=True,
                inbrowser=False
            )
    except KeyboardInterrupt:
        print("\n👋 Application stopped by user")
    except Exception as e:
        print(f"❌ Error starting application: {e}")
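
# Running this file directly serves the UI on port 7860 (all network interfaces).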

if __name__ == "__main__":
    main()