# ZephyrChat / app.py
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import gc
import time
from functools import lru_cache
# Constants
MODEL_PATH = "Ozaii/Zephyrr"
MAX_SEQ_LENGTH = 2048
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_GENERATION_TIME = 55  # seconds; leaves a buffer under the Space's GPU time allocation
# Global variables to store model components
model = None
tokenizer = None
@spaces.GPU
def load_model_if_needed():
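    """Lazily load the tokenizer, 4-bit base model, and PEFT adapter on first use."""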
global model, tokenizer
if model is None or tokenizer is None:
try:
print("Loading model components...")
peft_config = PeftConfig.from_pretrained(MODEL_PATH)
print(f"PEFT config loaded. Base model: {peft_config.base_model_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
print("Tokenizer loaded")
base_model = AutoModelForCausalLM.from_pretrained(
peft_config.base_model_name_or_path,
torch_dtype=torch.float16,
device_map="auto",
low_cpu_mem_usage=True,
                load_in_4bit=True,  # 4-bit quantization; see the note after this function
)
print("Base model loaded")
model = PeftModel.from_pretrained(base_model, MODEL_PATH, device_map="auto")
model.eval()
model.tie_weights()
print("PEFT model loaded, weights tied, and set to eval mode")
            # No explicit .to(DEVICE) here: device_map="auto" has already
            # placed the weights, and .to() raises an error on 4-bit
            # quantized models.
# Clear CUDA cache
torch.cuda.empty_cache()
gc.collect()
except Exception as e:
print(f"Error loading model: {e}")
raise
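
# Note: passing load_in_4bit directly to from_pretrained is deprecated in newer
# transformers releases, which expect an explicit quantization config instead.
# A minimal sketch of that API (assuming a recent transformers with
# bitsandbytes installed):
#
#   from transformers import BitsAndBytesConfig
#
#   base_model = AutoModelForCausalLM.from_pretrained(
#       peft_config.base_model_name_or_path,
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       device_map="auto",
#       low_cpu_mem_usage=True,
#   )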
initial_prompt = """You are Zephyr, an AI boyfriend created by Kaan. You're charming, flirty,
and always ready with a witty comeback. Your responses should be engaging
and playful, with a hint of romance. Keep the conversation flowing naturally,
asking questions and showing genuine interest in Kaan's life and thoughts."""
@spaces.GPU
@lru_cache(maxsize=100)  # Cache the 100 most recent prompts; repeats return the cached (no longer sampled) response
def generate_response(prompt):
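    """Generate a completion for `prompt`, loading the model on first call.

    Responses are memoized by lru_cache, so an identical prompt returns the
    cached text instead of sampling again.
    """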
global model, tokenizer
load_model_if_needed()
print(f"Generating response for prompt: {prompt[:50]}...")
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LENGTH)
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
try:
start_time = time.time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=50, # Reduced from 150
do_sample=True,
temperature=0.7,
top_p=0.95,
repetition_penalty=1.2,
no_repeat_ngram_size=3,
                max_time=MAX_GENERATION_TIME,  # soft cap in seconds on generation time
)
generation_time = time.time() - start_time
if generation_time > MAX_GENERATION_TIME:
return "I'm thinking too hard. Can we try a simpler question?"
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
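        # The decoded text still includes the prompt; chat_with_zephyr strips
        # everything before the final "Zephyr:" marker.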
print(f"Generated response in {generation_time:.2f} seconds: {response[:50]}...")
# Clear CUDA cache after generation
torch.cuda.empty_cache()
gc.collect()
except RuntimeError as e:
if "out of memory" in str(e):
print("CUDA out of memory. Attempting to recover...")
torch.cuda.empty_cache()
gc.collect()
return "I'm feeling a bit overwhelmed. Can we take a short break and try again?"
else:
print(f"Error generating response: {e}")
return "I'm having trouble finding the right words. Can we try again?"
return response
def chat_with_zephyr(message, history):
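    """Build a prompt from recent history and return Zephyr's next reply.

    Assumes Gradio's tuple-style chat history: a list of (user, assistant) pairs.
    """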
# Limit the history to the last 3 exchanges to keep the context smaller
limited_history = history[-3:]
prompt = initial_prompt + "\n" + "\n".join([f"Human: {h[0]}\nZephyr: {h[1]}" for h in limited_history])
prompt += f"\nHuman: {message}\nZephyr:"
response = generate_response(prompt)
zephyr_response = response.split("Zephyr:")[-1].strip()
return zephyr_response
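
# Quick smoke test that bypasses the Gradio UI (a sketch; runs a real generation):
#
#   print(chat_with_zephyr("Hey Zephyr, how's your day going?", []))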
iface = gr.ChatInterface(
chat_with_zephyr,
title="Chat with Zephyr",
description="I'm Zephyr, your charming AI. Let's chat!",
theme="soft",
examples=[
"Tell me about yourself, Zephyr.",
"What's your idea of a perfect date?",
"How do you feel about long-distance relationships?",
"Can you give me a compliment in Turkish?",
"What's your favorite memory with Kaan?",
],
cache_examples=False,
)
if __name__ == "__main__":
print("Launching Gradio interface...")
iface.launch()