import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import gc
import time
from functools import lru_cache
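# Runtime dependencies (assumed from the calls below): torch, transformers,
# peft, gradio, and the Hugging Face `spaces` package; load_in_4bit
# additionally requires bitsandbytes, and device_map="auto" requires accelerate.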
# Constants
MODEL_PATH = "Ozaii/Zephyrr"
MAX_SEQ_LENGTH = 2048
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_GENERATION_TIME = 55  # seconds; leaves a small buffer under the Space's per-call GPU time limit
# Global variables to store model components
model = None
tokenizer = None
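# Lazy loading: on ZeroGPU Spaces, CUDA is only attached inside functions
# decorated with @spaces.GPU, which is presumably why the model is loaded
# on first request rather than at import time.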
@spaces.GPU
def load_model_if_needed():
    global model, tokenizer
    if model is None or tokenizer is None:
        try:
            print("Loading model components...")
            peft_config = PeftConfig.from_pretrained(MODEL_PATH)
            print(f"PEFT config loaded. Base model: {peft_config.base_model_name_or_path}")
            tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
            print("Tokenizer loaded")
            base_model = AutoModelForCausalLM.from_pretrained(
                peft_config.base_model_name_or_path,
                torch_dtype=torch.float16,
                device_map="auto",
                low_cpu_mem_usage=True,
                load_in_4bit=True,  # 4-bit quantization to fit in GPU memory (needs bitsandbytes)
            )
            print("Base model loaded")
            model = PeftModel.from_pretrained(base_model, MODEL_PATH, device_map="auto")
            model.eval()
            model.tie_weights()
            print("PEFT model loaded, weights tied, and set to eval mode")
            # Note: no explicit model.to(DEVICE) here. device_map="auto" already
            # places the weights, and calling .to() on a 4-bit bitsandbytes model
            # raises a ValueError.
            # Clear CUDA cache
            torch.cuda.empty_cache()
            gc.collect()
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
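# Compatibility note: recent transformers versions deprecate the bare
# load_in_4bit=True kwarg in favor of an explicit quantization config.
# A sketch of the equivalent call (same arguments otherwise):
#
#   from transformers import BitsAndBytesConfig
#   base_model = AutoModelForCausalLM.from_pretrained(
#       peft_config.base_model_name_or_path,
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       torch_dtype=torch.float16,
#       device_map="auto",
#       low_cpu_mem_usage=True,
#   )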
initial_prompt = """You are Zephyr, an AI boyfriend created by Kaan. You're charming, flirty,
and always ready with a witty comeback. Your responses should be engaging
and playful, with a hint of romance. Keep the conversation flowing naturally,
asking questions and showing genuine interest in Kaan's life and thoughts."""
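# Caching note: @lru_cache memoizes generate_response on the exact prompt
# string, so repeating an identical conversation state returns the cached
# reply instead of re-running generation. Because do_sample=True, this trades
# reply variety for latency on repeated prompts.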
@spaces.GPU
@lru_cache(maxsize=100)  # Cache the last 100 responses
def generate_response(prompt):
    global model, tokenizer
    load_model_if_needed()
    print(f"Generating response for prompt: {prompt[:50]}...")
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LENGTH)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    try:
        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=50,  # reduced from 150 to keep generation fast
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
                max_time=MAX_GENERATION_TIME,
            )
        generation_time = time.time() - start_time
        if generation_time > MAX_GENERATION_TIME:
            return "I'm thinking too hard. Can we try a simpler question?"
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated response in {generation_time:.2f} seconds: {response[:50]}...")
        # Clear CUDA cache after generation
        torch.cuda.empty_cache()
        gc.collect()
    except RuntimeError as e:
        if "out of memory" in str(e):
            print("CUDA out of memory. Attempting to recover...")
            torch.cuda.empty_cache()
            gc.collect()
            return "I'm feeling a bit overwhelmed. Can we take a short break and try again?"
        else:
            print(f"Error generating response: {e}")
            return "I'm having trouble finding the right words. Can we try again?"
    return response
def chat_with_zephyr(message, history):
    # Keep only the last 3 exchanges so the prompt stays short
    limited_history = history[-3:]
    prompt = initial_prompt + "\n" + "\n".join([f"Human: {h[0]}\nZephyr: {h[1]}" for h in limited_history])
    prompt += f"\nHuman: {message}\nZephyr:"
    response = generate_response(prompt)
    # The decoded output contains the full prompt, so take the text after the last "Zephyr:"
    zephyr_response = response.split("Zephyr:")[-1].strip()
    return zephyr_response
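# For reference, the assembled prompt has roughly this shape (hypothetical turns):
#
#   <initial_prompt>
#   Human: hey Zephyr
#   Zephyr: hey you ;)
#   Human: <new message>
#   Zephyr:
#
# which is why splitting the decoded output on "Zephyr:" and taking the last
# segment isolates the newly generated turn.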
iface = gr.ChatInterface(
    chat_with_zephyr,
    title="Chat with Zephyr",
    description="I'm Zephyr, your charming AI. Let's chat!",
    theme="soft",
    examples=[
        "Tell me about yourself, Zephyr.",
        "What's your idea of a perfect date?",
        "How do you feel about long-distance relationships?",
        "Can you give me a compliment in Turkish?",
        "What's your favorite memory with Kaan?",
    ],
    cache_examples=False,
)
if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()