import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import spaces
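# NOTE: `spaces` is the Hugging Face ZeroGPU helper; it is imported but not used
# below. On a ZeroGPU Space the generation function would typically be decorated
# with `@spaces.GPU` (an assumption about the intended deployment, not something
# this file does).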

# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"

SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema
When recommending movies, always explain WHY the movie fits their preferences."""

SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."

# --- Global Model Cache ---
_models_cache = {
    "base": None,
    "finetuned": None,
    "tokenizer_base": None,
    "tokenizer_ft": None,
}
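# Cache entries hold either a loaded object or the sentinel string "error",
# which marks a model that already failed to load so later requests fail fast.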


def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
    """Loads a model and tokenizer if not already in cache."""
    if _models_cache[model_key] is not None and _models_cache[tokenizer_key] is not None:
        print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
        return _models_cache[model_key], _models_cache[tokenizer_key]

    print(f"Loading {model_key} model ({model_identifier})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_identifier,
            trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        model.eval()
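        # Some chat tokenizers ship without a pad token; fall back to EOS so
        # padding during tokenization and generation does not raise.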
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        _models_cache[model_key] = model
        _models_cache[tokenizer_key] = tokenizer
        print(f"✅ Successfully loaded {model_key} model!")
        return model, tokenizer
    except Exception as e:
        print(f"❌ ERROR loading {model_key} model ({model_identifier}): {e}")
        # FALLBACK: use the base model if the fine-tuned model fails to load
        if model_key == "finetuned" and model_identifier != BASE_MODEL_ID:
            print("🔄 FALLBACK: Loading base model in place of the fine-tuned model...")
            try:
                tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
                model = AutoModelForCausalLM.from_pretrained(
                    BASE_MODEL_ID,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    trust_remote_code=True
                )
                model.eval()
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
                if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
                    tokenizer.pad_token_id = tokenizer.eos_token_id
                _models_cache[model_key] = model
                _models_cache[tokenizer_key] = tokenizer
                print("✅ FALLBACK successful! Using base model with the CineGuide prompt.")
                return model, tokenizer
            except Exception as fallback_e:
                print(f"❌ FALLBACK also failed: {fallback_e}")
        _models_cache[model_key] = "error"
        _models_cache[tokenizer_key] = "error"
        raise


def convert_gradio_history_to_messages(history):
    """Convert Gradio ChatInterface history format to messages format."""
    messages = []
    for exchange in history:
        if isinstance(exchange, (list, tuple)) and len(exchange) == 2:
            user_msg, assistant_msg = exchange
            if user_msg:  # Only add if not empty
                messages.append({"role": "user", "content": str(user_msg)})
            if assistant_msg:  # Only add if not empty
                messages.append({"role": "assistant", "content": str(assistant_msg)})
    return messages


def generate_chat_response(message: str, history: list, model_type_to_load: str):
    """Generate response using the specified model type."""
    model, tokenizer = None, None
    system_prompt = ""

    if model_type_to_load == "base":
        if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
            yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type_to_load == "finetuned":
        if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
            print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID}")
            yield "Error: Fine-tuned model ID is not configured correctly."
            return
        if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
            yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    if model is None or tokenizer is None:
        yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting load."
        return

    # Prepare conversation
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})

    # Convert and add chat history
    formatted_history = convert_gradio_history_to_messages(history)
    conversation.extend(formatted_history)
    conversation.append({"role": "user", "content": message})

    try:
        # Generate response
        prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)

        # Prepare EOS tokens
        eos_tokens_ids = [tokenizer.eos_token_id]
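        # Qwen-style chat templates terminate assistant turns with <|im_end|>,
        # so include it as an additional stop token when the tokenizer knows it.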
        im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        if im_end_id is not None and im_end_id != getattr(tokenizer, 'unk_token_id', None):
            eos_tokens_ids.append(im_end_id)
        eos_tokens_ids = list(set(eos_tokens_ids))

        # Generate
        with torch.no_grad():
            generated_token_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=eos_tokens_ids
            )

        # Decode only the newly generated tokens (everything after the prompt)
        new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

        # Pseudo-stream the response character by character for a typing effect
        full_response = ""
        for char in response_text:
            full_response += char
            time.sleep(0.005)
            yield full_response
    except Exception as e:
        print(f"Error during generation: {e}")
        yield f"Error during text generation: {str(e)}"
def respond_base(message, history):
    """Handle base model response for Gradio ChatInterface."""
    try:
        response_gen = generate_chat_response(message, history, "base")
        for response in response_gen:
            yield response
    except Exception as e:
        print(f"Error in respond_base: {e}")
        yield f"Error: {str(e)}"


def respond_ft(message, history):
    """Handle fine-tuned model response for Gradio ChatInterface."""
    try:
        response_gen = generate_chat_response(message, history, "finetuned")
        for response in response_gen:
            yield response
    except Exception as e:
        print(f"Error in respond_ft: {e}")
        yield f"Error: {str(e)}"


# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="🎬 CineGuide Comparison") as demo:
    gr.Markdown(
        f"""
# 🎬 CineGuide vs. Base Model Comparison
Compare your fine-tuned CineGuide movie recommender with the base {BASE_MODEL_ID.split('/')[-1]} model.

**Base Model:** `{BASE_MODEL_ID}` (Standard Assistant)
**Fine-tuned Model:** `{FINETUNED_MODEL_ID}` (CineGuide - Specialized for Movies)

Type your movie-related query below and see how fine-tuning improves movie recommendations!

⚠️ **Note:** Models are loaded on first use and may take 30-60 seconds initially.
💡 **Fallback:** If the fine-tuned model fails to load, the base model is used with the specialized CineGuide prompt.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🗣️ Base Model")
            gr.Markdown(f"*{BASE_MODEL_ID.split('/')[-1]}*")
            chatbot_base = gr.ChatInterface(
                respond_base,
                textbox=gr.Textbox(placeholder="Ask about movies...", container=False, scale=7),
                title="",
                description="",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False
            )

        with gr.Column(scale=1):
            gr.Markdown("## 🎬 CineGuide (Fine-tuned)")
            gr.Markdown("*Specialized movie recommendation model*")
            chatbot_ft = gr.ChatInterface(
                respond_ft,
                textbox=gr.Textbox(placeholder="Ask CineGuide about movies...", container=False, scale=7),
                title="",
                description="",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False
            )
if __name__ == "__main__": | |
demo.queue(max_size=20) | |
demo.launch() |