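# app.py for the "Ozaii/zephyr-bae" Hugging Face Space: loads a PEFT adapter
# onto its base causal LM and serves a streaming Gradio chat UI with feedback.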
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import gradio as gr
import re
import json
from datetime import datetime
from threading import Thread
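# TextIteratorStreamer plus a background Thread lets generate() run off the main
# thread while partial tokens are streamed back to the UI as they arrive.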
# Load the model and tokenizer
MODEL_PATH = "Ozaii/zephyr-bae"
print("Attempting to load Zephyr... Cross your fingers!")
try:
    # Load the PEFT config to find the base model
    peft_config = PeftConfig.from_pretrained(MODEL_PATH)
    # Load the base model
    base_model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,  # some base models ship custom modeling code
    )
    # Load the PEFT adapter on top of the base model
    model = PeftModel.from_pretrained(base_model, MODEL_PATH, is_trainable=False)
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token  # base model has no dedicated pad token
    tokenizer.padding_side = "right"
    print("Zephyr loaded successfully! Time to charm!")
except Exception as e:
    print(f"Oops! Zephyr seems to be playing hide and seek. Error: {str(e)}")
    raise
# Prepare the model for generation
model.eval()
# Feedback data (note: this won't persist across Space restarts, but the
# structure is kept for potential future use)
feedback_data = []
def clean_response(response):
    # Remove any non-Zephyr dialogue or narration
    response = re.sub(r'(Kaan|Kanan|Kan|knan):.*?(\n|$)', '', response, flags=re.IGNORECASE)
    response = re.sub(r'\*.*?\*', '', response)  # strip *actions*
    response = re.sub(r'\(.*?\)', '', response)  # strip (asides)
    # Keep only the text after a "Zephyr:" tag, up to the next speaker tag
    match = re.search(r'Zephyr:\s*(.*?)(?=$|\n[A-Za-z]+:|Kaan:)', response, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    else:
        return response.strip()
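# The sampling settings below (temperature 0.7, top_p 0.9, plus repetition
# penalties) trade determinism for varied, persona-consistent replies.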
def generate_response(prompt, max_new_tokens=128):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs.input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        streamer=streamer,
        # Stop when the model starts a "Kaan:" turn (uses the first token of "Kaan:")
        eos_token_id=tokenizer.encode("Kaan:", add_special_tokens=False)[0],
    )
    # Run generation in a background thread; the streamer yields tokens as they arrive
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        cleaned_response = clean_response(generated_text)
        if cleaned_response:
            yield cleaned_response
def chat_with_zephyr(message, history):
    conversation_history = history[-3:]  # limit to the last 3 exchanges for focused responses
    full_prompt = "\n".join([f"Kaan: {h[0]}\nZephyr: {h[1]}" for h in conversation_history])
    full_prompt += f"\nKaan: {message}\nZephyr:"
    last_response = ""
    for response in generate_response(full_prompt):
        if response != last_response:  # only re-yield when the cleaned text actually changed
            yield response
            last_response = response
def add_feedback(user_message, bot_message, rating, note):
    feedback_entry = {
        "user_message": user_message,
        "bot_message": bot_message,
        "rating": rating,
        "note": note,
        "timestamp": datetime.now().isoformat()
    }
    feedback_data.append(feedback_entry)
    return "Feedback saved successfully!"
# Gradio interface helpers. The Blocks UI below wires up gradio_chat,
# undo_last_message, and submit_feedback; minimal implementations are
# sketched here based on that wiring. (A placeholder chat_with_zephyr and a
# separate gr.ChatInterface previously shadowed the real function and UI.)
def gradio_chat(message, history):
    """Stream Zephyr's reply, updating the last chatbot turn as tokens arrive."""
    history = history or []
    history.append((message, ""))
    for response in chat_with_zephyr(message, history[:-1]):
        history[-1] = (message, response)
        yield history

def undo_last_message(history):
    """Drop the most recent exchange from the chat history."""
    return history[:-1] if history else history

def submit_feedback(rating, note, history):
    """Record a rating and note for the last exchange in the chat."""
    if not history:
        return "Nothing to rate yet -- chat with Zephyr first!"
    user_message, bot_message = history[-1]
    return add_feedback(user_message, bot_message, rating, note)
css = """
body {
    background-color: #1a1a2e;
    color: #e0e0ff;
}
#chatbot {
    height: 500px;
    overflow-y: auto;
    border: 1px solid #3a3a5e;
    border-radius: 10px;
    padding: 10px;
    background-color: #0a0a1e;
}
#chatbot .message {
    padding: 10px;
    margin-bottom: 10px;
    border-radius: 15px;
}
#chatbot .user {
    background-color: #2a2a4e;
    text-align: right;
    margin-left: 20%;
}
#chatbot .bot {
    background-color: #3a3a5e;
    text-align: left;
    margin-right: 20%;
}
#feedback-section {
    margin-top: 20px;
    padding: 15px;
    border: 1px solid #3a3a5e;
    border-radius: 10px;
    background-color: #0a0a1e;
}
"""
with gr.Blocks(css=css) as iface:
    gr.Markdown("# Chat with Zephyr: Your AI Boyfriend is Here!")
    gr.Markdown("Zephyr is an AI trained to be your virtual boyfriend. Chat with him and see where the conversation goes!")
    chatbot = gr.Chatbot(elem_id="chatbot")
    msg = gr.Textbox(placeholder="Tell Zephyr what's on your mind...", label="Your message")
    gr.Examples(
        ["Hey Zephyr, how are you feeling today?", "What's your idea of a perfect date?", "Tell me something romantic!"],
        inputs=msg,
    )
    with gr.Row():
        clear = gr.Button("Clear Chat")
        undo = gr.Button("Undo Last Message")
    msg.submit(gradio_chat, [msg, chatbot], [chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)
    undo.click(undo_last_message, chatbot, chatbot)
    gr.Markdown("## Rate Zephyr's Last Response")
    with gr.Row():
        rating = gr.Slider(minimum=1, maximum=5, step=1, label="Rating (1-5 stars)")
        feedback_note = gr.Textbox(placeholder="Tell Zephyr how he did...", label="Feedback Note")
    submit_button = gr.Button("Submit Feedback")
    feedback_output = gr.Textbox(label="Feedback Status")
    submit_button.click(submit_feedback, [rating, feedback_note, chatbot], feedback_output)

# Launch the interface (print first: launch() blocks the main thread in a script)
print("Chat interface is running. Time to finally chat with Zephyr!")
iface.launch()