serhany committed on
Commit
6935641
·
verified ·
1 Parent(s): f1ea8a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -174
app.py CHANGED
@@ -1,223 +1,246 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
4
  import time
5
- import os
6
-
7
- # Attempt to import the spaces GPU decorator.
8
- # This is a common pattern, but the exact import might vary or be injected.
9
- try:
10
- import spaces # This might make spaces.GPU available
11
- except ImportError:
12
- spaces = None # Define it as None if import fails, so we can check later
13
- print("WARNING: 'spaces' module not found. @spaces.GPU decorator might not be available or work as expected.")
14
-
15
 
16
# --- Configuration ---
# Hugging Face Hub repository IDs of the two checkpoints that are compared.
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"  # Confirmed by you as correct

# System prompt for the fine-tuned assistant, assembled line by line so each
# instruction is easy to edit on its own.
SYSTEM_PROMPT_CINEGUIDE = "\n".join(
    (
        "You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:",
        "1. Provide personalized movie recommendations based on user preferences",
        "2. Give brief, compelling rationales for why you recommend each movie",
        "3. Ask thoughtful follow-up questions to better understand user tastes",
        "4. Maintain an enthusiastic but not overwhelming tone about cinema",
        "When recommending movies, always explain WHY the movie fits their preferences.",
    )
)
# Neutral system prompt used for the base model.
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."
27
 
28
# --- Global Model Storage (placeholders) ---
# Process-wide cache filled lazily by the functions that run inside a GPU
# context. Each slot starts out as None: the "*_model" / "*_tokenizer" slots
# receive the loaded objects, the "*_load_error" slots receive the message of
# a failed load so that later requests can report it instead of retrying.
MODELS_LOADED = dict.fromkeys(
    (
        "base_model",
        "base_tokenizer",
        "ft_model",
        "ft_tokenizer",
        "base_load_error",
        "ft_load_error",
    )
)
41
 
42
- # --- Core Model Loading and Inference Logic (to be wrapped by @spaces.GPU) ---
43
def _load_and_infer(message: str, chat_history: list, model_id_to_load: str, system_prompt: str, model_kind: str):
    """Load the requested model on first use and stream a reply to ``message``.

    Designed to be called from a function decorated with ``@spaces.GPU``.

    Args:
        message: The new user message.
        chat_history: Earlier turns as ``{"role": ..., "content": ...}`` dicts.
        model_id_to_load: Hub repository ID passed to ``from_pretrained``.
        system_prompt: Optional system message; skipped when empty.
        model_kind: Prefix of the ``MODELS_LOADED`` slots to use ("base" or "ft").

    Yields:
        Progressively longer prefixes of the reply, or a single error message
        when loading or generation fails.
    """
    max_prompt_tokens = 1800

    model_key = f"{model_kind}_model"
    tokenizer_key = f"{model_kind}_tokenizer"
    error_key = f"{model_kind}_load_error"

    # A recorded load failure is reported instead of retrying the load.
    if MODELS_LOADED[error_key]:
        yield f"Previous attempt to load {model_kind} model ({model_id_to_load}) failed: {MODELS_LOADED[error_key]}"
        return

    # Load model and tokenizer if not already cached.
    if MODELS_LOADED[model_key] is None or MODELS_LOADED[tokenizer_key] is None:
        print(f"Attempting to load {model_kind} model: {model_id_to_load} (Type: {type(model_id_to_load)})")
        if not model_id_to_load or not isinstance(model_id_to_load, str):
            MODELS_LOADED[error_key] = f"Invalid model ID: {model_id_to_load}"
            yield f"Error: {model_kind} model ID is not configured correctly ({model_id_to_load})."
            return
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id_to_load, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_id_to_load,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )
            model.eval()

            # Generation needs a pad token; fall back to EOS when none is set.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id

            MODELS_LOADED[model_key] = model
            MODELS_LOADED[tokenizer_key] = tokenizer
            print(f"Successfully loaded and cached {model_kind} model and tokenizer.")
        except Exception as e:
            MODELS_LOADED[error_key] = str(e)
            print(f"ERROR loading {model_kind} model ({model_id_to_load}): {e}")
            yield f"Error loading {model_kind} model: {e}"
            return

    model = MODELS_LOADED[model_key]
    tokenizer = MODELS_LOADED[tokenizer_key]

    if model is None or tokenizer is None:
        yield f"Model or tokenizer for {model_kind} is unexpectedly None after loading attempt."
        return

    # Prepare conversation. Truncating the rendered prompt from the right
    # (the usual tokenizer default) would cut off the newest user message and
    # the generation prompt, so the OLDEST history turns are dropped instead
    # until the prompt fits the budget.
    system_turns = [{"role": "system", "content": system_prompt}] if system_prompt else []
    user_turn = {"role": "user", "content": message}
    history = list(chat_history or [])
    while True:
        conversation = [*system_turns, *history, user_turn]
        prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt")
        if inputs["input_ids"].shape[1] <= max_prompt_tokens or not history:
            break
        del history[0]
    if inputs["input_ids"].shape[1] > max_prompt_tokens:
        # Last resort (a single over-long message): same hard truncation as before.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_prompt_tokens)
    inputs = inputs.to(model.device)

    # Stop on the tokenizer's EOS and, when the vocabulary has it, "<|im_end|>".
    eos_tokens_ids = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id != getattr(tokenizer, 'unk_token_id', None) and im_end_id not in eos_tokens_ids:
        eos_tokens_ids.append(im_end_id)
    eos_tokens_ids = list(set(eos_tokens_ids))

    try:
        generated_token_ids = model.generate(
            **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9,
            repetition_penalty=1.1, pad_token_id=tokenizer.pad_token_id, eos_token_id=eos_tokens_ids
        )
        # Keep only the tokens produced after the prompt.
        new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

        # Simulated streaming: emit a longer prefix every 5 characters and at
        # the end. Slicing avoids rebuilding the string one character at a time.
        last_index = len(response_text) - 1
        for char_idx in range(len(response_text)):
            if char_idx % 5 == 0 or char_idx == last_index:
                time.sleep(0.001)
                yield response_text[: char_idx + 1]
        if not response_text:  # Handle empty generation
            yield ""
    except Exception as e:
        print(f"Error during {model_kind} model generation: {e}")
        yield f"Error during generation: {e}"
132
-
133
-
134
- # --- Gradio Event Handler Wrappers (these get decorated) ---
135
def create_gpu_handler(model_id, system_prompt, model_kind_str):
    """Build a Gradio handler bound to one model.

    The returned generator function takes ``(message, chat_history)`` and
    relays everything ``_load_and_infer`` yields for the bound model. It is
    the function meant to be wrapped by ``@spaces.GPU``.
    """
    def gpu_fn(message, chat_history):
        stream = _load_and_infer(message, chat_history, model_id, system_prompt, model_kind_str)
        for partial_reply in stream:
            yield partial_reply

    return gpu_fn
141
-
142
# Wrap both handlers with the ZeroGPU decorator when the `spaces` module was
# imported and exposes `GPU`; otherwise use the undecorated handlers.
_zero_gpu_ready = spaces is not None and hasattr(spaces, "GPU")
if _zero_gpu_ready:
    print("Applying @spaces.GPU decorator.")
else:
    # Undecorated fallback: this will likely lead to "No @spaces.GPU function
    # detected" or CUDA errors on a ZeroGPU Space that expects the decorator.
    print("WARNING: @spaces.GPU decorator not applied. GPU acceleration on ZeroGPU might not work as expected.")

_wrap_handler = spaces.GPU if _zero_gpu_ready else (lambda handler: handler)
base_model_predict = _wrap_handler(create_gpu_handler(BASE_MODEL_ID, SYSTEM_PROMPT_BASE, "base"))
ft_model_predict = _wrap_handler(create_gpu_handler(FINETUNED_MODEL_ID, SYSTEM_PROMPT_CINEGUIDE, "ft"))
153
 
 
 
 
 
 
 
 
154
 
155
  # --- Gradio UI Definition ---
156
def _stream_to_chatbot(predict_fn):
    """Adapt a text-streaming handler to a ``gr.Chatbot(type="messages")`` output.

    ``predict_fn(message, history)`` yields plain strings, but a messages-type
    Chatbot must receive the whole conversation as a list of role/content
    dicts. The returned generator yields the existing history plus the new
    user turn and the partial assistant reply.
    """
    def handler(message, chat_history):
        shown = list(chat_history or [])
        # The model only needs role and content; any extra keys are left out.
        model_history = [{"role": turn["role"], "content": turn["content"]} for turn in shown]
        shown.append({"role": "user", "content": message})
        for partial_reply in predict_fn(message, model_history):
            yield shown + [{"role": "assistant", "content": partial_reply}]

    return handler


with gr.Blocks(theme=gr.themes.Default()) as demo:  # Changed to Default theme, Soft can sometimes have issues
    gr.Markdown(
        f"""
        # 🎬 CineGuide vs. Base {BASE_MODEL_ID}
        Compare the fine-tuned CineGuide (`{FINETUNED_MODEL_ID}`) with the base {BASE_MODEL_ID}.
        **Note:** Models are loaded on first use within a GPU context and may take time.
        This Space attempts to use the ZeroGPU shared pool via `@spaces.GPU`.
        """
    )

    # Two chat panes side by side, one per model.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
            chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, type="messages")
        with gr.Column(scale=1):
            gr.Markdown("## 🤖 Fine-tuned CineGuide")
            chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, type="messages")

    # A single input row feeds both models.
    with gr.Row():
        shared_input_textbox = gr.Textbox(
            show_label=False, placeholder="Enter your movie query...", container=False, scale=7
        )
        submit_button = gr.Button("✉️ Send", variant="primary", scale=1)

    gr.Examples(
        examples=[
            "Hi! I'm looking for something funny to watch tonight.",
            "I love dry, witty humor more than slapstick.",
            "I'm really into complex sci-fi movies that make you think.",
            "Tell me about some good action movies from the 90s.",
            "Recommend a thought-provoking sci-fi film about AI.",
        ],
        inputs=[shared_input_textbox], label="Example Prompts"
    )

    # Event handling: the (potentially GPU-decorated) predict functions are
    # wrapped so each Chatbot receives a message list instead of a bare string.
    base_chat_handler = _stream_to_chatbot(base_model_predict)
    ft_chat_handler = _stream_to_chatbot(ft_model_predict)

    submit_button.click(
        base_chat_handler,
        [shared_input_textbox, chatbot_base],
        [chatbot_base],
        api_name="base_predict"  # Good for testing API route
    )
    submit_button.click(
        ft_chat_handler,
        [shared_input_textbox, chatbot_ft],
        [chatbot_ft],
        api_name="ft_predict"
    )
    shared_input_textbox.submit(
        base_chat_handler,
        [shared_input_textbox, chatbot_base],
        [chatbot_base]
    )
    shared_input_textbox.submit(
        ft_chat_handler,
        [shared_input_textbox, chatbot_ft],
        [chatbot_ft]
    )

    def clear_textbox_fn():
        """Return the empty string used to reset the shared textbox."""
        return ""

    # queue=False so the textbox clears immediately.
    submit_button.click(clear_textbox_fn, [], [shared_input_textbox], queue=False)
    shared_input_textbox.submit(clear_textbox_fn, [], [shared_input_textbox], queue=False)
218
-
219
 
220
if __name__ == "__main__":
    # Enable queuing for multiple users, then start the app. debug=True can
    # sometimes interfere with production Spaces, but is fine for testing.
    demo.queue().launch(debug=True)
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import time
5
+ import spaces
 
 
 
 
 
 
 
 
 
6
 
7
# --- Configuration ---
# Hugging Face Hub repository IDs of the two checkpoints that are compared.
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"

# System prompt for the fine-tuned assistant, assembled line by line so each
# instruction is easy to edit on its own. The empty entry is the blank line
# that separates the numbered goals from the closing instruction.
SYSTEM_PROMPT_CINEGUIDE = "\n".join(
    (
        "You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:",
        "1. Provide personalized movie recommendations based on user preferences",
        "2. Give brief, compelling rationales for why you recommend each movie",
        "3. Ask thoughtful follow-up questions to better understand user tastes",
        "4. Maintain an enthusiastic but not overwhelming tone about cinema",
        "",
        "When recommending movies, always explain WHY the movie fits their preferences.",
    )
)
# Neutral system prompt used for the base model.
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."
19
 
20
+ # --- Global Model Cache ---
21
+ _models_cache = {
22
+ "base": None,
23
+ "finetuned": None,
24
+ "tokenizer_base": None,
25
+ "tokenizer_ft": None,
 
 
 
 
 
 
26
  }
27
 
28
def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
    """Return the cached model and tokenizer, loading them on first use.

    Args:
        model_identifier: Hub repository ID passed to ``from_pretrained``.
        model_key: ``_models_cache`` slot that holds the model.
        tokenizer_key: ``_models_cache`` slot that holds the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        RuntimeError: If an earlier load for these slots failed.
        Exception: Any error raised while loading is re-raised after the
            failure has been recorded in the cache.
    """
    cached_model = _models_cache[model_key]
    cached_tokenizer = _models_cache[tokenizer_key]

    # A failed load stores the string "error" in both slots. That sentinel is
    # not None, so it must be rejected here; otherwise it would be handed back
    # to the caller as if it were a usable model/tokenizer pair.
    if any(isinstance(slot, str) and slot == "error" for slot in (cached_model, cached_tokenizer)):
        raise RuntimeError(f"{model_key} model ({model_identifier}) failed to load previously.")

    if cached_model is not None and cached_tokenizer is not None:
        print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
        return cached_model, cached_tokenizer

    print(f"Loading {model_key} model ({model_identifier})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_identifier, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        model.eval()

        # Generation needs a pad token; fall back to EOS when none is set.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        _models_cache[model_key] = model
        _models_cache[tokenizer_key] = tokenizer
        print(f"Finished loading and cached {model_key} and {tokenizer_key}.")
        return model, tokenizer
    except Exception as e:
        print(f"ERROR loading {model_key} model ({model_identifier}): {e}")
        # Remember the failure so later requests do not retry the load.
        _models_cache[model_key] = "error"
        _models_cache[tokenizer_key] = "error"
        raise
59
+
60
def generate_chat_response(message: str, chat_history: list, model_type_to_load: str):
    """Stream a reply to ``message`` from the base or the fine-tuned model.

    Args:
        message: The new user message.
        chat_history: Earlier turns as ``{"role": ..., "content": ...}`` dicts.
        model_type_to_load: ``"base"`` or ``"finetuned"``.

    Yields:
        Progressively longer prefixes of the reply (one more character each
        time), an empty string when nothing was generated, or a single error
        message when the model is unavailable.
    """
    max_prompt_tokens = 1800

    model, tokenizer = None, None
    system_prompt = ""

    if model_type_to_load == "base":
        if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
            yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type_to_load == "finetuned":
        if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
            print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID}")
            yield "Error: Fine-tuned model ID is not configured correctly."
            return
        if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
            yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    if model is None or tokenizer is None:
        yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting load."
        return

    # Prepare conversation. Truncating the rendered prompt from the right
    # (the usual tokenizer default) would cut off the newest user message and
    # the generation prompt, so the OLDEST history turns are dropped instead
    # until the prompt fits the budget.
    system_turns = [{"role": "system", "content": system_prompt}] if system_prompt else []
    user_turn = {"role": "user", "content": message}
    history = list(chat_history or [])
    while True:
        conversation = [*system_turns, *history, user_turn]
        prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt")
        if inputs["input_ids"].shape[1] <= max_prompt_tokens or not history:
            break
        del history[0]
    if inputs["input_ids"].shape[1] > max_prompt_tokens:
        # Last resort (a single over-long message): same hard truncation as before.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_prompt_tokens)
    inputs = inputs.to(model.device)

    # Prepare EOS tokens: the tokenizer's EOS plus "<|im_end|>" when the
    # vocabulary has it. None ids are discarded so they never reach generate().
    candidate_eos_ids = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id != getattr(tokenizer, 'unk_token_id', None):
        candidate_eos_ids.append(im_end_id)
    eos_tokens_ids = list({token_id for token_id in candidate_eos_ids if token_id is not None})

    # Generate
    with torch.no_grad():
        generated_token_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eos_tokens_ids or None,
        )

    # Keep only the tokens produced after the prompt.
    new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

    # An empty generation must still yield once so the caller gets a value.
    if not response_text:
        yield ""
        return

    # Stream the response (simulated: the full text already exists).
    full_response = ""
    for char in response_text:
        full_response += char
        time.sleep(0.005)
        yield full_response
131
+
132
@spaces.GPU
def base_model_predict(user_message, chat_history):
    """Stream the base model's reply; runs under the ``@spaces.GPU`` decorator.

    Any exception raised while generating is printed and reported to the
    caller as a final error message instead of propagating.
    """
    try:
        yield from generate_chat_response(user_message, chat_history, "base")
    except Exception as e:
        print(f"Error in base_model_predict: {e}")
        yield f"Error generating base model response: {str(e)}"
142
 
143
@spaces.GPU
def ft_model_predict(user_message, chat_history):
    """Stream the fine-tuned model's reply; runs under the ``@spaces.GPU`` decorator.

    Any exception raised while generating is printed and reported to the
    caller as a final error message instead of propagating.
    """
    try:
        yield from generate_chat_response(user_message, chat_history, "finetuned")
    except Exception as e:
        print(f"Error in ft_model_predict: {e}")
        yield f"Error generating fine-tuned response: {str(e)}"
153
+
154
def format_chat_history(history, message):
    """Convert Gradio chat history into a list of role/content messages.

    Two entry shapes are accepted:

    * a dict that has a ``"role"`` key, which is passed through unchanged;
    * a two-item list or tuple ``(user_text, assistant_text)``, which is
      expanded into a user message followed by an assistant message. A half
      that is ``None`` (for example a reply that has not been produced yet)
      is skipped rather than emitted with ``content=None``.

    Entries of any other shape are ignored.

    Args:
        history: Iterable of history entries, or ``None`` for no history.
        message: The current user message. Unused; kept so existing callers
            that pass it keep working.

    Returns:
        A new list of message dicts in conversation order.
    """
    formatted_history = []
    for chat in history or []:
        if isinstance(chat, dict):
            if 'role' in chat:
                formatted_history.append(chat)
        elif isinstance(chat, (list, tuple)) and len(chat) == 2:
            user_text, assistant_text = chat
            if user_text is not None:
                formatted_history.append({"role": "user", "content": user_text})
            if assistant_text is not None:
                formatted_history.append({"role": "assistant", "content": assistant_text})
    return formatted_history
166
+
167
def respond_base(message, history):
    """ChatInterface callback for the base model.

    Normalises ``history`` with ``format_chat_history`` and relays every
    partial reply yielded by ``base_model_predict``.
    """
    yield from base_model_predict(message, format_chat_history(history, message))
174
 
175
def respond_ft(message, history):
    """ChatInterface callback for the fine-tuned model.

    Normalises ``history`` with ``format_chat_history`` and relays every
    partial reply yielded by ``ft_model_predict``.
    """
    yield from ft_model_predict(message, format_chat_history(history, message))
182
 
183
  # --- Gradio UI Definition ---
184
import inspect

# Example prompts shown under both chat panes (one list, so they cannot drift apart).
_EXAMPLE_PROMPTS = [
    "Hi! I'm looking for something funny to watch tonight.",
    "I love dry, witty humor more than slapstick.",
    "I'm really into complex sci-fi movies that make you think.",
    "Can you recommend a good thriller?",
    "What's a good romantic comedy from the 2000s?",
]

# Button arguments that newer Gradio releases removed from ChatInterface.
# Only those the installed version still accepts are passed, so the UI
# builds on either version instead of failing with a TypeError.
_chat_interface_params = inspect.signature(gr.ChatInterface.__init__).parameters
_LEGACY_BUTTON_KWARGS = {
    name: value
    for name, value in {
        "retry_btn": None,
        "undo_btn": "⤴️ Undo",
        "clear_btn": "🗑️ Clear",
    }.items()
    if name in _chat_interface_params
}


def _build_chat_interface(respond_fn, placeholder):
    """Create one chat pane that streams replies from ``respond_fn``."""
    return gr.ChatInterface(
        respond_fn,
        textbox=gr.Textbox(placeholder=placeholder, container=False, scale=7),
        title="",
        description="",
        theme="soft",
        examples=list(_EXAMPLE_PROMPTS),
        cache_examples=False,
        **_LEGACY_BUTTON_KWARGS,
    )


with gr.Blocks(theme=gr.themes.Soft(), title="🎬 CineGuide Comparison") as demo:
    gr.Markdown(
        f"""
        # 🎬 CineGuide vs. Base Model Comparison
        Compare the fine-tuned CineGuide movie recommender with the base {BASE_MODEL_ID.split('/')[-1]} model.

        **Base Model:** `{BASE_MODEL_ID}`
        **Fine-tuned Model:** `{FINETUNED_MODEL_ID}`

        Type your movie-related query below and see how each model responds!

        ⚠️ **Note:** Models are loaded on first use and may take 30-60 seconds initially.
        """
    )

    # Two independent chat panes side by side, one per model.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🗣️ Base Model")
            gr.Markdown(f"*{BASE_MODEL_ID.split('/')[-1]}*")
            chatbot_base = _build_chat_interface(respond_base, "Ask about movies...")

        with gr.Column(scale=1):
            gr.Markdown("## 🎬 CineGuide (Fine-tuned)")
            gr.Markdown("*Specialized for movie recommendations*")
            chatbot_ft = _build_chat_interface(respond_ft, "Ask CineGuide about movies...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
if __name__ == "__main__":
    # Enable the request queue (at most 20 waiting requests), then start the app.
    demo.queue(max_size=20).launch()