import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import spaces
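
# `spaces` supplies the @spaces.GPU decorator used below: on ZeroGPU Spaces it attaches a GPU
# to the process only while the decorated function runs; on other hardware it is a no-op.
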
# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"
SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema
When recommending movies, always explain WHY the movie fits their preferences."""
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."
# --- Global Model Cache ---
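# Loaded models and tokenizers are kept at module level so each ~15 GB bf16 checkpoint is
# downloaded and initialized only once per process rather than on every chat request.
# The sentinel string "error" marks a load attempt that has already failed.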
_models_cache = {
    "base": None,
    "finetuned": None,
    "tokenizer_base": None,
    "tokenizer_ft": None,
}


def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
    """Loads a model and tokenizer if not already in cache."""
    if _models_cache[model_key] is not None and _models_cache[tokenizer_key] is not None:
        print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
        return _models_cache[model_key], _models_cache[tokenizer_key]

    print(f"Loading {model_key} model ({model_identifier})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_identifier,
            trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        model.eval()
        # Some causal-LM tokenizers ship without a pad token; fall back to EOS so generate() can pad.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        _models_cache[model_key] = model
        _models_cache[tokenizer_key] = tokenizer
        print(f"βœ… Successfully loaded {model_key} model!")
        return model, tokenizer
    except Exception as e:
        print(f"❌ ERROR loading {model_key} model ({model_identifier}): {e}")
        # FALLBACK: Use the base model if the fine-tuned model fails to load
        if model_key == "finetuned" and model_identifier != BASE_MODEL_ID:
            print("πŸ”„ FALLBACK: Loading the base model in place of the fine-tuned model...")
            try:
                tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
                model = AutoModelForCausalLM.from_pretrained(
                    BASE_MODEL_ID,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    trust_remote_code=True
                )
                model.eval()
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
                if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
                    tokenizer.pad_token_id = tokenizer.eos_token_id
                _models_cache[model_key] = model
                _models_cache[tokenizer_key] = tokenizer
                print("βœ… FALLBACK successful! Using base model with CineGuide prompt.")
                return model, tokenizer
            except Exception as fallback_e:
                print(f"❌ FALLBACK also failed: {fallback_e}")
        _models_cache[model_key] = "error"
        _models_cache[tokenizer_key] = "error"
        raise
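

# For reference, the helper below converts Gradio's default tuple-pair history, e.g.
# [("Hi!", "Hello! What are you in the mood for?")], into chat-template messages:
# [{"role": "user", "content": "Hi!"},
#  {"role": "assistant", "content": "Hello! What are you in the mood for?"}]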
def convert_gradio_history_to_messages(history):
    """Convert Gradio ChatInterface history format to messages format."""
    messages = []
    for exchange in history:
        if isinstance(exchange, (list, tuple)) and len(exchange) == 2:
            user_msg, assistant_msg = exchange
            if user_msg:  # Only add if not empty
                messages.append({"role": "user", "content": str(user_msg)})
            if assistant_msg:  # Only add if not empty
                messages.append({"role": "assistant", "content": str(assistant_msg)})
    return messages


@spaces.GPU
def generate_chat_response(message: str, history: list, model_type_to_load: str):
    """Generate a response using the specified model type ('base' or 'finetuned')."""
    model, tokenizer = None, None
    system_prompt = ""

    if model_type_to_load == "base":
        if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
            yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type_to_load == "finetuned":
        if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
            print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID}")
            yield "Error: Fine-tuned model ID is not configured correctly."
            return
        if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
            yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    if model is None or tokenizer is None:
        yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting load."
        return

    # Prepare conversation
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    # Convert and add chat history
    formatted_history = convert_gradio_history_to_messages(history)
    conversation.extend(formatted_history)
    conversation.append({"role": "user", "content": message})

    try:
        # Build the prompt with the model's chat template and tokenize it
        prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)

        # Prepare EOS tokens: make sure <|im_end|>, Qwen's end-of-turn marker, stops
        # generation even if it differs from the tokenizer's default eos_token.
        eos_tokens_ids = [tokenizer.eos_token_id]
        im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        if im_end_id is not None and im_end_id != getattr(tokenizer, "unk_token_id", None):
            eos_tokens_ids.append(im_end_id)
        eos_tokens_ids = list(set(eos_tokens_ids))

        # Generate
        with torch.no_grad():
            generated_token_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=eos_tokens_ids
            )

        # Keep only the newly generated tokens (everything after the prompt)
        new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

        # Simulate streaming in the UI: generation above is not streamed, so the finished
        # response is yielded incrementally, character by character.
        full_response = ""
        for char in response_text:
            full_response += char
            time.sleep(0.005)
            yield full_response
    except Exception as e:
        print(f"Error during generation: {e}")
        yield f"Error during text generation: {str(e)}"


def respond_base(message, history):
    """Handle base model response for Gradio ChatInterface."""
    try:
        response_gen = generate_chat_response(message, history, "base")
        for response in response_gen:
            yield response
    except Exception as e:
        print(f"Error in respond_base: {e}")
        yield f"Error: {str(e)}"


def respond_ft(message, history):
    """Handle fine-tuned model response for Gradio ChatInterface."""
    try:
        response_gen = generate_chat_response(message, history, "finetuned")
        for response in response_gen:
            yield response
    except Exception as e:
        print(f"Error in respond_ft: {e}")
        yield f"Error: {str(e)}"


# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="🎬 CineGuide Comparison") as demo:
    gr.Markdown(
        f"""
# 🎬 CineGuide vs. Base Model Comparison
Compare your fine-tuned CineGuide movie recommender with the base {BASE_MODEL_ID.split('/')[-1]} model.

**Base Model:** `{BASE_MODEL_ID}` (Standard Assistant)
**Fine-tuned Model:** `{FINETUNED_MODEL_ID}` (CineGuide - Specialized for Movies)

Type your movie-related query below and see how fine-tuning improves movie recommendations!

⚠️ **Note:** Models are loaded on first use and may take 30-60 seconds initially.
πŸ’‘ **Fallback:** If the fine-tuned model fails to load, the base model is used with the specialized CineGuide prompt.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## πŸ—£οΈ Base Model")
            gr.Markdown(f"*{BASE_MODEL_ID.split('/')[-1]}*")
            chatbot_base = gr.ChatInterface(
                respond_base,
                textbox=gr.Textbox(placeholder="Ask about movies...", container=False, scale=7),
                title="",
                description="",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False
            )

        with gr.Column(scale=1):
            gr.Markdown("## 🎬 CineGuide (Fine-tuned)")
            gr.Markdown("*Specialized movie recommendation model*")
            chatbot_ft = gr.ChatInterface(
                respond_ft,
                textbox=gr.Textbox(placeholder="Ask CineGuide about movies...", container=False, scale=7),
                title="",
                description="",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False
            )


if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch()
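
# Note: to run this demo locally, `python app.py` should work, assuming torch, transformers,
# gradio, and the Hugging Face `spaces` package are installed and a GPU with enough memory
# for a 7B model in bf16 (roughly 15 GB or more) is available.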