import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import spaces

# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"

SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema

When recommending movies, always explain WHY the movie fits their preferences."""

SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."

# --- Global Model Cache ---
_models_cache = {
    "base": None,
    "finetuned": None,
    "tokenizer_base": None,
    "tokenizer_ft": None,
}


def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
    """Loads a model and tokenizer if not already in cache."""
    if _models_cache[model_key] is not None and _models_cache[tokenizer_key] is not None:
        print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
        return _models_cache[model_key], _models_cache[tokenizer_key]

    print(f"Loading {model_key} model ({model_identifier})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_identifier,
            trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        model.eval()

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        _models_cache[model_key] = model
        _models_cache[tokenizer_key] = tokenizer
        print(f"✅ Successfully loaded {model_key} model!")
        return model, tokenizer
    except Exception as e:
        print(f"❌ ERROR loading {model_key} model ({model_identifier}): {e}")
        # FALLBACK: Use base model if fine-tuned model fails
        if model_key == "finetuned" and model_identifier != BASE_MODEL_ID:
            print("🔄 FALLBACK: Loading base model instead of fine-tuned model...")
            try:
                tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
                model = AutoModelForCausalLM.from_pretrained(
                    BASE_MODEL_ID,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    trust_remote_code=True
                )
                model.eval()

                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
                if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
                    tokenizer.pad_token_id = tokenizer.eos_token_id

                _models_cache[model_key] = model
                _models_cache[tokenizer_key] = tokenizer
                print("✅ FALLBACK successful! Using base model with CineGuide prompt.")
                return model, tokenizer
            except Exception as fallback_e:
                print(f"❌ FALLBACK also failed: {fallback_e}")

        # Mark this slot as failed so callers don't retry on every request.
        _models_cache[model_key] = "error"
        _models_cache[tokenizer_key] = "error"
        raise


def convert_gradio_history_to_messages(history):
    """Convert Gradio ChatInterface history format to messages format."""
    messages = []
    for exchange in history:
        if isinstance(exchange, (list, tuple)) and len(exchange) == 2:
            user_msg, assistant_msg = exchange
            if user_msg:  # Only add if not empty
                messages.append({"role": "user", "content": str(user_msg)})
            if assistant_msg:  # Only add if not empty
                messages.append({"role": "assistant", "content": str(assistant_msg)})
    return messages


@spaces.GPU
def generate_chat_response(message: str, history: list, model_type_to_load: str):
    """Generate a response using the specified model type."""
    model, tokenizer = None, None
    system_prompt = ""

    if model_type_to_load == "base":
        if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
            yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type_to_load == "finetuned":
        if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
            print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID}")
            yield "Error: Fine-tuned model ID is not configured correctly."
            return
        if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
            yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    if model is None or tokenizer is None:
        yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting load."
        return

    # Prepare conversation
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})

    # Convert and add chat history
    formatted_history = convert_gradio_history_to_messages(history)
    conversation.extend(formatted_history)
    conversation.append({"role": "user", "content": message})

    try:
        # Generate response
        prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)

        # Prepare EOS tokens (Qwen chat models also stop on <|im_end|>)
        eos_tokens_ids = [tokenizer.eos_token_id]
        im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        if im_end_id is not None and im_end_id != getattr(tokenizer, "unk_token_id", None):
            eos_tokens_ids.append(im_end_id)
        eos_tokens_ids = list(set(eos_tokens_ids))

        # Generate
        with torch.no_grad():
            generated_token_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=eos_tokens_ids
            )

        new_tokens = generated_token_ids[0, inputs["input_ids"].shape[1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

        # Stream the response character by character
        full_response = ""
        for char in response_text:
            full_response += char
            time.sleep(0.005)
            yield full_response

    except Exception as e:
        print(f"Error during generation: {e}")
        yield f"Error during text generation: {str(e)}"


def respond_base(message, history):
    """Handle base model response for Gradio ChatInterface."""
    try:
        response_gen = generate_chat_response(message, history, "base")
        for response in response_gen:
            yield response
    except Exception as e:
        print(f"Error in respond_base: {e}")
        yield f"Error: {str(e)}"


def respond_ft(message, history):
    """Handle fine-tuned model response for Gradio ChatInterface."""
    try:
        response_gen = generate_chat_response(message, history, "finetuned")
        for response in response_gen:
            yield response
    except Exception as e:
        print(f"Error in respond_ft: {e}")
        yield f"Error: {str(e)}"


# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="🎬 CineGuide Comparison") as demo:
    gr.Markdown(
        f"""
        # 🎬 CineGuide vs. Base Model Comparison
        Compare your fine-tuned CineGuide movie recommender with the base {BASE_MODEL_ID.split('/')[-1]} model.

        **Base Model:** `{BASE_MODEL_ID}` (Standard Assistant)
        **Fine-tuned Model:** `{FINETUNED_MODEL_ID}` (CineGuide - Specialized for Movies)

        Type your movie-related query below and see how fine-tuning improves movie recommendations!

        ⚠️ **Note:** Models are loaded on first use and may take 30-60 seconds initially.
        💡 **Fallback:** If the fine-tuned model fails to load, the base model is used with the specialized CineGuide prompt.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🗣️ Base Model")
            gr.Markdown(f"*{BASE_MODEL_ID.split('/')[-1]}*")
            chatbot_base = gr.ChatInterface(
                respond_base,
                textbox=gr.Textbox(placeholder="Ask about movies...", container=False, scale=7),
                title="",
                description="",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False
            )

        with gr.Column(scale=1):
            gr.Markdown("## 🎬 CineGuide (Fine-tuned)")
            gr.Markdown("*Specialized movie recommendation model*")
            chatbot_ft = gr.ChatInterface(
                respond_ft,
                textbox=gr.Textbox(placeholder="Ask CineGuide about movies...", container=False, scale=7),
                title="",
                description="",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False
            )

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch()
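
# Local run note (assumption: the dependency list and filename below are not
# part of the original Space config). The @spaces.GPU decorator targets
# Hugging Face ZeroGPU Spaces; outside a Space the `spaces` package should act
# as a no-op, so the same script can be tested on a machine with sufficient
# GPU memory, e.g. assuming this file is saved as app.py:
#   pip install gradio spaces torch transformers accelerate
#   python app.py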