cella110n committed on
Commit
941802a
·
verified ·
1 Parent(s): e0ed6cc

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -24
app.py CHANGED
@@ -15,8 +15,8 @@ import time
15
  import spaces # Required for @spaces.GPU
16
 
17
  import torch # Keep torch for device check in Tagger
18
- # import timm # No model loading yet
19
- # from safetensors.torch import load_file as safe_load_file # No model loading yet
20
 
21
  # MatplotlibのバックエンドをAggに設定 (Keep commented out for now)
22
  # matplotlib.use('Agg')
@@ -67,7 +67,7 @@ SAFETENSORS_FILENAME = "lora_model_0426/checkpoint_epoch_4.safetensors"
67
  METADATA_FILENAME = "lora_model_0426/checkpoint_epoch_4_metadata.json"
68
  TAG_MAPPING_FILENAME = "lora_model_0426/tag_mapping.json"
69
  CACHE_DIR = "./model_cache"
70
- # BASE_MODEL_NAME = 'eva02_large_patch14_448.mim_m38m_ft_in1k' # No model loading yet
71
 
72
  # --- Tagger Class ---
73
  class Tagger:
@@ -155,21 +155,83 @@ class Tagger:
155
  else:
156
  print("Labels already loaded.")
157
 
158
- # Add a simple test method decorated with GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  @spaces.GPU()
160
- def test_gpu_method(self):
161
- current_time = time.time()
162
- print(f"--- Tagger.test_gpu_method called on GPU worker at {current_time} ---")
163
- # Check if labels are accessible from the GPU worker context
164
- label_count = len(self.labels_data.names) if self.labels_data else -1
165
- print(f"--- (Worker) Label count: {label_count} ---")
166
- return f"Tagger method called at {current_time}. Label count: {label_count}"
167
-
168
- # --- Original predict_on_gpu (Keep commented out for this test) ---
169
- # @spaces.GPU()
170
- # def predict_on_gpu(self, image_input, gen_threshold, char_threshold, output_mode):
171
- # # ... (original prediction logic including model loading) ...
172
- # pass
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  # Instantiate the tagger class (this will download files/load labels)
175
  tagger = Tagger()
@@ -177,18 +239,28 @@ tagger = Tagger()
177
  # --- Gradio Interface Definition (Minimal) ---
178
  with gr.Blocks() as demo:
179
  gr.Markdown("""
180
- # Tagger Initialization + Minimal Button Test
181
- Instantiates Tagger, then click the button below to check if a simple `@spaces.GPU` decorated *method* is triggered.
182
- Check logs for Tagger initialization messages.
183
  """)
184
  with gr.Column():
185
- test_button = gr.Button("Test Tagger GPU Method")
 
186
  output_text = gr.Textbox(label="Output")
 
 
 
 
 
 
 
187
 
188
  test_button.click(
189
- fn=tagger.test_gpu_method, # Call the simple method on the instance
190
- inputs=[],
191
- outputs=[output_text]
 
 
192
  )
193
 
194
  # --- Main Block ---
 
15
  import spaces # Required for @spaces.GPU
16
 
17
  import torch # Keep torch for device check in Tagger
18
+ import timm # Restore timm
19
+ from safetensors.torch import load_file as safe_load_file # Restore safetensors loading
20
 
21
  # MatplotlibのバックエンドをAggに設定 (Keep commented out for now)
22
  # matplotlib.use('Agg')
 
67
  METADATA_FILENAME = "lora_model_0426/checkpoint_epoch_4_metadata.json"
68
  TAG_MAPPING_FILENAME = "lora_model_0426/tag_mapping.json"
69
  CACHE_DIR = "./model_cache"
70
+ BASE_MODEL_NAME = 'eva02_large_patch14_448.mim_m38m_ft_in1k' # Restore base model name
71
 
72
  # --- Tagger Class ---
73
  class Tagger:
 
155
  else:
156
  print("Labels already loaded.")
157
 
158
+ # Restore model loading function
159
+ def _load_model_on_gpu(self):
160
+ # Only load if not already loaded on the correct device
161
+ if self.model is not None and next(self.model.parameters()).device == self.device:
162
+ print("Model already loaded on the correct device.")
163
+ return True # Indicate success
164
+
165
+ print("Loading PyTorch model for GPU worker...")
166
+ if not self.safetensors_path or not self.labels_data:
167
+ print("Error: Model paths or labels not initialized before loading.")
168
+ return False # Indicate failure
169
+ try:
170
+ num_classes = len(self.labels_data.names)
171
+ if num_classes <= 0: raise ValueError(f"Invalid num_classes: {num_classes}")
172
+ print(f"Creating base model: {BASE_MODEL_NAME} with {num_classes} classes")
173
+ # Load model structure (without pretrained weights initially if possible, or handle mismatch)
174
+ # Using pretrained=True might download weights we immediately overwrite
175
+ model = timm.create_model(BASE_MODEL_NAME, pretrained=True, num_classes=num_classes)
176
+
177
+ print(f"Loading state dict from: {self.safetensors_path}")
178
+ if not os.path.exists(self.safetensors_path): raise FileNotFoundError(f"File not found: {self.safetensors_path}")
179
+ state_dict = safe_load_file(self.safetensors_path)
180
+
181
+ # --- Key Adaptation Logic (Important!) ---
182
+ # Assuming direct match based on previous code structure
183
+ adapted_state_dict = state_dict
184
+ # Example if keys were prefixed with 'base_model.':
185
+ # adapted_state_dict = {k.replace('base_model.', ''): v for k, v in state_dict.items()}
186
+ # -----------------------------------------
187
+
188
+ print("Loading state dict into model...")
189
+ missing_keys, unexpected_keys = model.load_state_dict(adapted_state_dict, strict=False)
190
+ # Only print if there are actually missing/unexpected keys
191
+ if missing_keys: print(f"State dict loaded. Missing keys: {missing_keys}")
192
+ if unexpected_keys: print(f"State dict loaded. Unexpected keys: {unexpected_keys}")
193
+ if any(k.startswith('head.') for k in missing_keys): print("Warning: Head weights seem missing/mismatched!")
194
+
195
+ print(f"Moving model to device: {self.device}")
196
+ model.to(self.device)
197
+ model.eval()
198
+ self.model = model # Store loaded model
199
+ print("Model loaded successfully on GPU worker.")
200
+ return True # Indicate success
201
+ except Exception as e:
202
+ print(f"(Worker) Error loading PyTorch model: {e}")
203
+ import traceback; print(traceback.format_exc())
204
+ # raise gr.Error(f"Error loading PyTorch model: {e}") # Don't raise here, return status
205
+ return False # Indicate failure
206
+
207
+ # Restore predict_on_gpu, but modify it to ONLY test model loading
208
  @spaces.GPU()
209
def predict_on_gpu(self, image_input, gen_threshold, char_threshold, output_mode):
    """Model-load smoke test that stands in for the real prediction entry point.

    The four arguments are accepted only so the signature matches the real
    prediction function; this test version does not read them.

    Returns:
        tuple: ``(status_message, None)``. The second slot is the
        placeholder for the image output.
    """
    print("--- predict_on_gpu function started (GPU worker - TESTING MODEL LOAD) ---")

    # Attempt to load the model; bail out early with the error status.
    if not self._load_model_on_gpu():
        status = "Error: Model could not be loaded on GPU worker. Check logs."
        print(status)
        return status, None

    status = "Model loading successful on GPU worker."
    print(status)
    # Sanity check: report where the model ended up after loading.
    if self.model is None:
        print("Model object is None even after successful load reported?")
    else:
        print(f"Model device after load: {next(self.model.parameters()).device}")

    # Only the status message matters for this test; no image is produced.
    return status, None
229
+
230
+ # --- Original prediction logic (commented out for this test) ---
231
+ # if self.model is None:
232
+ # return "Error: Model could not be loaded on GPU worker.", None
233
+ # if image_input is None: return "Please upload an image.", None
234
+ # ... (image loading, preprocessing, inference, postprocessing) ...
235
 
236
  # Instantiate the tagger class (this will download files/load labels)
237
  tagger = Tagger()
 
239
  # --- Gradio Interface Definition (Minimal) ---
240
  with gr.Blocks() as demo:
241
  gr.Markdown("""
242
+ # Tagger Initialization + Model Load Test
243
+ Instantiates Tagger, then click the button below to attempt loading the model via `@spaces.GPU`.
244
+ Check logs for Tagger initialization and model loading messages.
245
  """)
246
  with gr.Column():
247
+ # Keep using the same button name for simplicity for now
248
+ test_button = gr.Button("Test Model Load on GPU")
249
  output_text = gr.Textbox(label="Output")
250
+ # Add dummy components to match the signature of the real predict_on_gpu eventually
251
+ # They are passed as the click handler's inputs/outputs below; the handler ignores their values during this load test
252
+ dummy_image = gr.Image(visible=False) # Hidden image input
253
+ dummy_gen_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.55, visible=False)
254
+ dummy_char_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.60, visible=False)
255
+ dummy_radio = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", visible=False)
256
+ dummy_vis_output = gr.Image(visible=False) # Hidden image output
257
 
258
  test_button.click(
259
+ fn=tagger.predict_on_gpu,
260
+ # Provide dummy inputs matching the function signature
261
+ # We only care about the first output (text) for this test
262
+ inputs=[dummy_image, dummy_gen_slider, dummy_char_slider, dummy_radio],
263
+ outputs=[output_text, dummy_vis_output] # Map outputs
264
  )
265
 
266
  # --- Main Block ---