import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import spaces

# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"

SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema

When recommending movies, always explain WHY the movie fits their preferences."""
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."

# --- Global Model Cache ---
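# Each entry is None until loaded, a model/tokenizer object once loaded, or the
# sentinel string "error" if a previous load attempt failed.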
_models_cache = {
    "base": None,
    "finetuned": None,
    "tokenizer_base": None,
    "tokenizer_ft": None,
}

def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
    """Loads a model and tokenizer if not already in cache."""
    if _models_cache[model_key] is not None and _models_cache[tokenizer_key] is not None:
        print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
        return _models_cache[model_key], _models_cache[tokenizer_key]

    print(f"Loading {model_key} model ({model_identifier})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_identifier, 
            trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        model.eval()

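        # Some tokenizers define no pad token; fall back to EOS so that padding
        # during generation is well defined.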
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
        
        _models_cache[model_key] = model
        _models_cache[tokenizer_key] = tokenizer
        print(f"βœ… Successfully loaded {model_key} model!")
        return model, tokenizer
    except Exception as e:
        print(f"❌ ERROR loading {model_key} model ({model_identifier}): {e}")
        
        # FALLBACK: Use base model if fine-tuned model fails
        if model_key == "finetuned" and model_identifier != BASE_MODEL_ID:
            print(f"πŸ”„ FALLBACK: Loading base model instead for fine-tuned model...")
            try:
                tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
                model = AutoModelForCausalLM.from_pretrained(
                    BASE_MODEL_ID,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    trust_remote_code=True
                )
                model.eval()
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
                    if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
                        tokenizer.pad_token_id = tokenizer.eos_token_id
                
                _models_cache[model_key] = model
                _models_cache[tokenizer_key] = tokenizer
                print(f"βœ… FALLBACK successful! Using base model with CineGuide prompt.")
                return model, tokenizer
            except Exception as fallback_e:
                print(f"❌ FALLBACK also failed: {fallback_e}")
        
        _models_cache[model_key] = "error"
        _models_cache[tokenizer_key] = "error"
        raise

def convert_gradio_history_to_messages(history):
    """Convert Gradio ChatInterface history format to messages format."""
    messages = []
    for exchange in history:
        if isinstance(exchange, dict) and "role" in exchange:
            # Newer Gradio versions may pass messages-style dicts directly.
            if exchange.get("content"):
                messages.append({"role": exchange["role"], "content": str(exchange["content"])})
        elif isinstance(exchange, (list, tuple)) and len(exchange) == 2:
            user_msg, assistant_msg = exchange
            if user_msg:  # Only add if not empty
                messages.append({"role": "user", "content": str(user_msg)})
            if assistant_msg:  # Only add if not empty
                messages.append({"role": "assistant", "content": str(assistant_msg)})
    return messages

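# On ZeroGPU Spaces, a GPU is attached only while a @spaces.GPU-decorated function
# runs, so model loading is deferred to the first chat request inside this call.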
@spaces.GPU
def generate_chat_response(message: str, history: list, model_type_to_load: str):
    """Generate response using specified model type."""
    model, tokenizer = None, None
    system_prompt = ""

    if model_type_to_load == "base":
        if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
            yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type_to_load == "finetuned":
        if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
            print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID}")
            yield "Error: Fine-tuned model ID is not configured correctly."
            return
        if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
            yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    if model is None or tokenizer is None:
        yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting load."
        return

    # Prepare conversation
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    
    # Convert and add chat history
    formatted_history = convert_gradio_history_to_messages(history)
    conversation.extend(formatted_history)
    conversation.append({"role": "user", "content": message})

    try:
        # Generate response
        prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)
        
        # Prepare EOS tokens
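        # Qwen's ChatML template closes each turn with <|im_end|>, so treat it as
        # an extra stop token alongside the tokenizer's default EOS.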
        eos_tokens_ids = [tokenizer.eos_token_id]
        im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        if im_end_id is not None and im_end_id != getattr(tokenizer, 'unk_token_id', None):
            eos_tokens_ids.append(im_end_id)
        eos_tokens_ids = list(set(eos_tokens_ids))

        # Generate
        with torch.no_grad():
            generated_token_ids = model.generate(
                **inputs, 
                max_new_tokens=512, 
                do_sample=True, 
                temperature=0.7, 
                top_p=0.9,
                repetition_penalty=1.1, 
                pad_token_id=tokenizer.pad_token_id, 
                eos_token_id=eos_tokens_ids
            )
        
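        # Keep only the newly generated tokens (everything after the prompt).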
        new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

        # Stream the response
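        # The full reply is already generated; yield it character by character to
        # simulate streaming in the Gradio UI.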
        full_response = ""
        for char in response_text:
            full_response += char
            time.sleep(0.005)
            yield full_response
            
    except Exception as e:
        print(f"Error during generation: {e}")
        yield f"Error during text generation: {str(e)}"

def respond_base(message, history):
    """Handle base model response for Gradio ChatInterface."""
    try:
        response_gen = generate_chat_response(message, history, "base")
        for response in response_gen:
            yield response
    except Exception as e:
        print(f"Error in respond_base: {e}")
        yield f"Error: {str(e)}"

def respond_ft(message, history):
    """Handle fine-tuned model response for Gradio ChatInterface."""
    try:
        response_gen = generate_chat_response(message, history, "finetuned")
        for response in response_gen:
            yield response
    except Exception as e:
        print(f"Error in respond_ft: {e}")
        yield f"Error: {str(e)}"

# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="🎬 CineGuide Comparison") as demo:
    gr.Markdown(
        f"""
        # 🎬 CineGuide vs. Base Model Comparison
        Compare your fine-tuned CineGuide movie recommender with the base {BASE_MODEL_ID.split('/')[-1]} model.
        
        **Base Model:** `{BASE_MODEL_ID}` (Standard Assistant)
        **Fine-tuned Model:** `{FINETUNED_MODEL_ID}` (CineGuide - Specialized for Movies)
        
        Type your movie-related query below and see how fine-tuning improves movie recommendations!
        
        ⚠️ **Note:** Models are loaded on first use and may take 30-60 seconds initially.
        💡 **Fallback:** If the fine-tuned model fails to load, the base model is used with the specialized CineGuide prompt.
        """
    )
    
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"## πŸ—£οΈ Base Model")
            gr.Markdown(f"*{BASE_MODEL_ID.split('/')[-1]}*")
            chatbot_base = gr.ChatInterface(
                respond_base,
                textbox=gr.Textbox(placeholder="Ask about movies...", container=False, scale=7),
                title="",
                description="",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False
            )
            
        with gr.Column(scale=1):
            gr.Markdown(f"## 🎬 CineGuide (Fine-tuned)")
            gr.Markdown(f"*Specialized movie recommendation model*")
            chatbot_ft = gr.ChatInterface(
                respond_ft,
                textbox=gr.Textbox(placeholder="Ask CineGuide about movies...", container=False, scale=7),
                title="",
                description="",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False
            )

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch()
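
# Local usage note: install gradio, torch, transformers, and spaces; the
# @spaces.GPU decorator is only meaningful on ZeroGPU Spaces and can be dropped
# when running elsewhere. Then launch with `python app.py`.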