# ZephyrChat / app.py
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import gradio as gr
import re
import json
from datetime import datetime
from threading import Thread
# Load the model and tokenizer
MODEL_PATH = "Ozaii/zephyr-bae"
print("Attempting to load Zephyr... Cross your fingers! 🀞")
try:
    # Load the PEFT config to find out which base model the adapter was trained on
    peft_config = PeftConfig.from_pretrained(MODEL_PATH)
    # Load the base model in half precision, sharded across available devices
    base_model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )
    # Attach the fine-tuned LoRA adapter weights on top of the base model
    model = PeftModel.from_pretrained(base_model, MODEL_PATH, is_trainable=False)
    # Load the tokenizer and configure padding
    tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    print("Zephyr loaded successfully! Time to charm!")
except Exception as e:
    print(f"Oops! Zephyr seems to be playing hide and seek. Error: {str(e)}")
    raise
# Prepare the model for generation
model.eval()
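# Optional sketch, not exercised here: on a smaller GPU the base model could
# instead be loaded with 8-bit quantization via bitsandbytes. The API below
# is standard transformers, but using it for this particular model is an
# assumption:
#
#   from transformers import BitsAndBytesConfig
#   base_model = AutoModelForCausalLM.from_pretrained(
#       peft_config.base_model_name_or_path,
#       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#       device_map="auto",
#   )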
# Feedback data (Note: This won't persist in Spaces, but keeping the structure for potential future use)
feedback_data = []
def clean_response(response):
    # Remove any non-Zephyr dialogue (various spellings of the user's name) or narration
    response = re.sub(r'(Kaan|Kanan|Kan|knan):.*?(\n|$)', '', response, flags=re.IGNORECASE)
    response = re.sub(r'\*.*?\*', '', response)
    response = re.sub(r'\(.*?\)', '', response)
    # Find Zephyr's turn, stopping before the next speaker tag
    match = re.search(r'Zephyr:\s*(.*?)(?=$|\n[A-Za-z]+:|Kaan:)', response, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    return response.strip()
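# Illustrative example (hypothetical raw generation, not from the app):
#   clean_response("Zephyr: Hey you! *smiles* (laughs)\nKaan: hi")
# strips the *narration*, the (aside), and the user's turn, returning "Hey you!".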
def generate_response(prompt, max_new_tokens=128):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        streamer=streamer,
        # Stop once the model starts the user's next turn. Only the first
        # token of "Kaan:" is used, so this can trigger early on words that
        # share that token; clean_response() tidies up any leftovers.
        eos_token_id=tokenizer.encode("Kaan:", add_special_tokens=False)[0],
    )
    # Generate on a background thread so tokens can be streamed as they arrive
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        cleaned_response = clean_response(generated_text)
        if cleaned_response:
            yield cleaned_response
def chat_with_zephyr(message, history):
    # Limit to the last 3 exchanges for more focused responses
    conversation_history = history[-3:]
    full_prompt = "\n".join([f"Kaan: {h[0]}\nZephyr: {h[1]}" for h in conversation_history])
    full_prompt += f"\nKaan: {message}\nZephyr:"
    last_response = ""
    for response in generate_response(full_prompt):
        if response != last_response:
            yield response
            last_response = response
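# Assumed prompt layout (mirroring the fine-tuning transcript format):
#   Kaan: <earlier user turn>
#   Zephyr: <earlier model turn>
#   Kaan: <new message>
#   Zephyr: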
def add_feedback(user_message, bot_message, rating, note):
    feedback_entry = {
        "user_message": user_message,
        "bot_message": bot_message,
        "rating": rating,
        "note": note,
        "timestamp": datetime.now().isoformat(),
    }
    feedback_data.append(feedback_entry)
    return "Feedback saved successfully!"
# Gradio event handlers wired up in the Blocks UI below
def gradio_chat(message, history):
    # Append the new exchange and stream Zephyr's reply into its last slot
    history = (history or []) + [(message, "")]
    for partial in chat_with_zephyr(message, history[:-1]):
        history[-1] = (message, partial)
        yield history

def undo_last_message(history):
    # Drop the most recent exchange, if there is one
    return history[:-1] if history else history

def submit_feedback(rating, note, history):
    # Rate the most recent exchange in the chat history
    if not history:
        return "Nothing to rate yet -- chat with Zephyr first!"
    user_message, bot_message = history[-1]
    return add_feedback(user_message, bot_message, rating, note)
css = """
body {
background-color: #1a1a2e;
color: #e0e0ff;
}
#chatbot {
height: 500px;
overflow-y: auto;
border: 1px solid #3a3a5e;
border-radius: 10px;
padding: 10px;
background-color: #0a0a1e;
}
#chatbot .message {
padding: 10px;
margin-bottom: 10px;
border-radius: 15px;
}
#chatbot .user {
background-color: #2a2a4e;
text-align: right;
margin-left: 20%;
}
#chatbot .bot {
background-color: #3a3a5e;
text-align: left;
margin-right: 20%;
}
#feedback-section {
margin-top: 20px;
padding: 15px;
border: 1px solid #3a3a5e;
border-radius: 10px;
background-color: #0a0a1e;
}
"""
with gr.Blocks(css=css) as iface:
    gr.Markdown("# Chat with Zephyr: Your AI Boyfriend is Here! πŸ’˜")
    gr.Markdown("Zephyr is an AI trained to be your virtual boyfriend. Chat with him and see where the conversation goes!")
    chatbot = gr.Chatbot(elem_id="chatbot")
    msg = gr.Textbox(placeholder="Tell Zephyr what's on your mind...", label="Your message")
    gr.Examples(
        examples=[
            "Hey Zephyr, how are you feeling today?",
            "What's your idea of a perfect date?",
            "Tell me something romantic!",
        ],
        inputs=msg,
    )
    with gr.Row():
        clear = gr.Button("Clear Chat")
        undo = gr.Button("Undo Last Message")
    msg.submit(gradio_chat, [msg, chatbot], [chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)
    undo.click(undo_last_message, chatbot, chatbot)
    # elem_id ties this section to the #feedback-section rules in the CSS above
    with gr.Column(elem_id="feedback-section"):
        gr.Markdown("## Rate Zephyr's Last Response")
        with gr.Row():
            rating = gr.Slider(minimum=1, maximum=5, step=1, label="Rating (1-5 stars)")
            feedback_note = gr.Textbox(placeholder="Tell Zephyr how he did...", label="Feedback Note")
        submit_button = gr.Button("Submit Feedback")
        feedback_output = gr.Textbox(label="Feedback Status")
        submit_button.click(submit_feedback, [rating, feedback_note, chatbot], feedback_output)
# Launch the interface (launch() blocks, so announce before starting)
print("Chat interface is running. Time to finally chat with Zephyr! πŸ’˜")
iface.launch()