cella110n committed · Commit 723a482 · verified · 1 Parent(s): 9ea82e4

Update app.py

Files changed (1):
  1. app.py +573 -573
app.py CHANGED
@@ -1,573 +1,573 @@
import gradio as gr
import numpy as np
from PIL import Image  # PIL is used by the preprocessing helpers below
# from PIL import Image, ImageDraw, ImageFont  # No drawing yet
import json
import os
import io
import requests
import matplotlib.pyplot as plt  # For visualization
import matplotlib  # For backend setting
from huggingface_hub import hf_hub_download
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple
import time
import spaces  # Required for @spaces.GPU
import onnxruntime as ort  # Use ONNX Runtime

import torch  # Keep torch for device check in Tagger
import timm  # Restore timm
from safetensors.torch import load_file as safe_load_file  # Restore safetensors loading

# Set the Matplotlib backend to Agg (keep commented out for now)
# matplotlib.use('Agg')

# --- Data Classes and Helper Functions ---
@dataclass
class LabelData:
    names: list[str]
    rating: list[np.int64]
    general: list[np.int64]
    artist: list[np.int64]
    character: list[np.int64]
    copyright: list[np.int64]
    meta: list[np.int64]
    quality: list[np.int64]
    model: list[np.int64]

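# Note: each category field in LabelData holds *indices* into `names` for that
# category's tags; load_tag_mapping() below actually fills them with np.int64
# arrays so they can be used directly for NumPy fancy indexing into the
# model's probability vector.
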
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    if image.mode == "RGBA":
        background = Image.new("RGB", image.size, (255, 255, 255))
        background.paste(image, mask=image.split()[3])
        image = background
    return image

def pil_pad_square(image: Image.Image) -> Image.Image:
    width, height = image.size
    if width == height: return image
    new_size = max(width, height)
    new_image = Image.new(image.mode, (new_size, new_size), (255, 255, 255))  # Use image.mode
    paste_position = ((new_size - width) // 2, (new_size - height) // 2)
    new_image.paste(image, paste_position)
    return new_image

def load_tag_mapping(mapping_path):
    # Use the implementation from the original app.py as it was confirmed working
    with open(mapping_path, 'r', encoding='utf-8') as f: tag_mapping_data = json.load(f)
    # Check format compatibility (can be dict of dicts or dict with idx_to_tag/tag_to_category)
    if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
        idx_to_tag = {int(k): v for k, v in tag_mapping_data["idx_to_tag"].items()}
        tag_to_category = tag_mapping_data["tag_to_category"]
    elif isinstance(tag_mapping_data, dict):
        # Assuming the dict-of-dicts format from previous tests
        try:
            tag_mapping_data_int_keys = {int(k): v for k, v in tag_mapping_data.items()}
            idx_to_tag = {idx: data['tag'] for idx, data in tag_mapping_data_int_keys.items()}
            tag_to_category = {data['tag']: data['category'] for data in tag_mapping_data_int_keys.values()}
        except (KeyError, ValueError) as e:
            raise ValueError(f"Unsupported tag mapping format (dict): {e}. Expected int keys with 'tag' and 'category'.")
    else:
        raise ValueError("Unsupported tag mapping format: Expected a dictionary.")

    names = [None] * (max(idx_to_tag.keys()) + 1)
    rating, general, artist, character, copyright, meta, quality, model_name = [], [], [], [], [], [], [], []
    for idx, tag in idx_to_tag.items():
        if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
        names[idx] = tag
        category = tag_to_category.get(tag, 'Unknown')  # Handle missing category mapping gracefully
        idx_int = int(idx)
        if category == 'Rating': rating.append(idx_int)
        elif category == 'General': general.append(idx_int)
        elif category == 'Artist': artist.append(idx_int)
        elif category == 'Character': character.append(idx_int)
        elif category == 'Copyright': copyright.append(idx_int)
        elif category == 'Meta': meta.append(idx_int)
        elif category == 'Quality': quality.append(idx_int)
        elif category == 'Model': model_name.append(idx_int)

    return LabelData(names=names, rating=np.array(rating, dtype=np.int64), general=np.array(general, dtype=np.int64), artist=np.array(artist, dtype=np.int64),
                     character=np.array(character, dtype=np.int64), copyright=np.array(copyright, dtype=np.int64), meta=np.array(meta, dtype=np.int64), quality=np.array(quality, dtype=np.int64), model=np.array(model_name, dtype=np.int64)), idx_to_tag, tag_to_category

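# For reference, load_tag_mapping() accepts either of two JSON shapes
# (the tag names below are illustrative, not taken from the actual file):
#   {"idx_to_tag": {"0": "sensitive", ...}, "tag_to_category": {"sensitive": "Rating", ...}}
# or the dict-of-dicts form:
#   {"0": {"tag": "sensitive", "category": "Rating"}, ...}
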
def preprocess_image(image: Image.Image, target_size=(448, 448)):
    # Adapted from onnx_predict.py's version
    image = pil_ensure_rgb(image)
    image = pil_pad_square(image)
    image_resized = image.resize(target_size, Image.BICUBIC)
    img_array = np.array(image_resized, dtype=np.float32) / 255.0
    img_array = img_array.transpose(2, 0, 1)  # HWC -> CHW
    img_array = img_array[::-1, :, :]  # RGB -> BGR channel flip (re-enabled based on user feedback; the model expects BGR input)
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
    img_array = (img_array - mean) / std
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
    return image, img_array

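# Output contract of preprocess_image(): returns (padded PIL image, tensor of
# shape (1, 3, 448, 448)) where the tensor is float32, in BGR channel order,
# and normalized to roughly [-1, 1] via mean = std = 0.5 per channel.
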
# Add get_tags function (from onnx_predict.py)
def get_tags(probs, labels: LabelData, gen_threshold, char_threshold):
    result = {
        "rating": [],
        "general": [],
        "character": [],
        "copyright": [],
        "artist": [],
        "meta": [],
        "quality": [],
        "model": []
    }
    # Rating (select max)
    if len(labels.rating) > 0:
        # Ensure indices are within bounds
        valid_indices = labels.rating[labels.rating < len(probs)]
        if len(valid_indices) > 0:
            rating_probs = probs[valid_indices]
            if len(rating_probs) > 0:
                rating_idx_local = np.argmax(rating_probs)
                rating_idx_global = valid_indices[rating_idx_local]
                # Check if global index is valid for names list
                if rating_idx_global < len(labels.names) and labels.names[rating_idx_global] is not None:
                    rating_name = labels.names[rating_idx_global]
                    rating_conf = float(rating_probs[rating_idx_local])
                    result["rating"].append((rating_name, rating_conf))
                else:
                    print(f"Warning: Invalid global index {rating_idx_global} for rating tag.")
            else:
                print("Warning: rating_probs became empty after filtering.")
        else:
            print("Warning: No valid indices found for rating tags within probs length.")

    # Quality (select max)
    if len(labels.quality) > 0:
        valid_indices = labels.quality[labels.quality < len(probs)]
        if len(valid_indices) > 0:
            quality_probs = probs[valid_indices]
            if len(quality_probs) > 0:
                quality_idx_local = np.argmax(quality_probs)
                quality_idx_global = valid_indices[quality_idx_local]
                if quality_idx_global < len(labels.names) and labels.names[quality_idx_global] is not None:
                    quality_name = labels.names[quality_idx_global]
                    quality_conf = float(quality_probs[quality_idx_local])
                    result["quality"].append((quality_name, quality_conf))
                else:
                    print(f"Warning: Invalid global index {quality_idx_global} for quality tag.")
            else:
                print("Warning: quality_probs became empty after filtering.")
        else:
            print("Warning: No valid indices found for quality tags within probs length.")

    # Threshold-based categories
    category_map = {
        "general": (labels.general, gen_threshold),
        "character": (labels.character, char_threshold),
        "copyright": (labels.copyright, char_threshold),
        "artist": (labels.artist, char_threshold),
        "meta": (labels.meta, gen_threshold),
        "model": (labels.model, gen_threshold)
    }
    for category, (indices, threshold) in category_map.items():
        if len(indices) > 0:
            valid_indices = indices[(indices < len(probs))]  # Check index bounds first
            if len(valid_indices) > 0:
                category_probs = probs[valid_indices]
                mask = category_probs >= threshold
                selected_indices_local = np.where(mask)[0]
                if len(selected_indices_local) > 0:
                    selected_indices_global = valid_indices[selected_indices_local]
                    selected_probs = category_probs[selected_indices_local]
                    for idx_global, prob_val in zip(selected_indices_global, selected_probs):
                        # Check if global index is valid for names list
                        if idx_global < len(labels.names) and labels.names[idx_global] is not None:
                            result[category].append((labels.names[idx_global], float(prob_val)))
                        else:
                            print(f"Warning: Invalid global index {idx_global} for {category} tag.")
                # else: print(f"No tags found for category '{category}' above threshold {threshold}")
            # else: print(f"No valid indices found for category '{category}' within probs length.")
        # else: print(f"No indices defined for category '{category}'")

    # Sort by probability (descending)
    for k in result:
        result[k] = sorted(result[k], key=lambda x: x[1], reverse=True)
    return result

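# Worked example: with gen_threshold=0.55, a general tag at prob 0.54 is
# dropped while one at 0.56 is kept; rating and quality ignore the thresholds
# and always return the single argmax tag (if any).
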
# Add visualize_predictions function (Adapted from onnx_predict.py and previous versions)
def visualize_predictions(image: Image.Image, predictions: Dict, threshold: float):
    # Filter out unwanted meta tags (e.g., id, commentary, request, mismatch)
    filtered_meta = []
    excluded_meta_patterns = ['id', 'commentary', 'request', 'mismatch']
    for tag, prob in predictions.get("meta", []):
        if not any(pattern in tag.lower() for pattern in excluded_meta_patterns):
            filtered_meta.append((tag, prob))
    predictions["meta"] = filtered_meta  # Use filtered list for visualization

    # --- Plotting Setup ---
    plt.rcParams['font.family'] = 'DejaVu Sans'
    fig = plt.figure(figsize=(8, 12), dpi=100)
    ax_tags = fig.add_subplot(1, 1, 1)

    all_tags, all_probs, all_colors = [], [], []
    color_map = {
        'rating': 'red', 'character': 'blue', 'copyright': 'purple',
        'artist': 'orange', 'general': 'green', 'meta': 'gray', 'quality': 'yellow', 'model': 'cyan'
    }

    # Aggregate tags from predictions dictionary
    for cat, prefix, color in [
        ('rating', 'R', color_map['rating']), ('quality', 'Q', color_map['quality']),
        ('character', 'C', color_map['character']), ('copyright', '©', color_map['copyright']),
        ('artist', 'A', color_map['artist']), ('general', 'G', color_map['general']),
        ('meta', 'M', color_map['meta']), ('model', 'M', color_map['model'])
    ]:
        sorted_tags = sorted(predictions.get(cat, []), key=lambda x: x[1], reverse=True)
        for tag, prob in sorted_tags:
            all_tags.append(f"[{prefix}] {tag.replace('_', ' ')}")
            all_probs.append(prob)
            all_colors.append(color)

    if not all_tags:
        ax_tags.text(0.5, 0.5, "No tags found above threshold", ha='center', va='center')
        ax_tags.set_title(f"Tags (Threshold ≳ {threshold:.2f})")
        ax_tags.axis('off')
    else:
        sorted_indices = sorted(range(len(all_probs)), key=lambda i: all_probs[i])
        all_tags = [all_tags[i] for i in sorted_indices]
        all_probs = [all_probs[i] for i in sorted_indices]
        all_colors = [all_colors[i] for i in sorted_indices]

        num_tags = len(all_tags)
        bar_height = min(0.8, max(0.1, 0.8 * (30 / num_tags))) if num_tags > 30 else 0.8
        y_positions = np.arange(num_tags)

        bars = ax_tags.barh(y_positions, all_probs, height=bar_height, color=all_colors)
        ax_tags.set_yticks(y_positions)
        ax_tags.set_yticklabels(all_tags)

        fontsize = 10 if num_tags <= 40 else 8 if num_tags <= 60 else 6
        for lbl in ax_tags.get_yticklabels():
            lbl.set_fontsize(fontsize)

        for i, (bar, prob) in enumerate(zip(bars, all_probs)):
            text_x = min(prob + 0.02, 0.98)
            ax_tags.text(text_x, y_positions[i], f"{prob:.3f}", va='center', fontsize=fontsize)

        ax_tags.set_xlim(0, 1)
        ax_tags.set_title(f"Tags (Threshold ≳ {threshold:.2f})")

        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=color, label=cat.capitalize())
            for cat, color in color_map.items()
            if any(t.startswith(f"[{cat[0].upper() if cat != 'copyright' else '©'}]") for t in all_tags)
        ]
        if legend_elements:
            ax_tags.legend(handles=legend_elements, loc='lower right', fontsize=8)

    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

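# Bar labels are prefixed by category: [R]ating, [Q]uality, [C]haracter,
# [©] copyright, [A]rtist, [G]eneral, and [M] for both meta and model.
# Note the shared 'M' prefix also makes the legend membership check above
# ambiguous between meta and model.
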
# --- Constants ---
REPO_ID = "celstk/wd-eva02-lora-onnx"
# Model options
MODEL_OPTIONS = {
    "cl_eva02_tagger_v1_250426": "cl_eva02_tagger_v1_250426/model.onnx",
    # "cl_eva02_tagger_v1_250502": "cl_eva02_tagger_v1_250503/model.onnx",
    # "cl_eva02_tagger_v1_250509": "cl_eva02_tagger_v1_250509/model.onnx",
    "cl_eva02_tagger_v1_250511": "cl_eva02_tagger_v1_250511/model.onnx",
    # "cl_eva02_tagger_v1_250512": "cl_eva02_tagger_v1_250512/model.onnx",
    # "cl_eva02_tagger_v1_250513": "cl_eva02_tagger_v1_250513/model.onnx",
    # "cl_eva02_tagger_v1_250517": "cl_eva02_tagger_v1_250517/model.onnx",
    # "cl_eva02_tagger_v1_250518": "cl_eva02_tagger_v1_250518/model.onnx",
    # "cl_eva02_tagger_v1_250520": "cl_eva02_tagger_v1_250520/model.onnx",
    # "cl_eva02_tagger_v1_250522": "cl_eva02_tagger_v1_250522/model.onnx",
    # "cl_eva02_tagger_v1_250523": "cl_eva02_tagger_v1_250523/model.onnx",
    # "cl_eva02_tagger_v1_250702": "cl_eva02_tagger_v1_250702/model.onnx",
    # "cl_eva02_tagger_v1_250704": "cl_eva02_tagger_v1_250704/model.onnx",
    # "cl_eva02_tagger_v1_250705": "cl_eva02_tagger_v1_250705/model.onnx",
    # "cl_eva02_tagger_v1_250706": "cl_eva02_tagger_v1_250706/model.onnx",
    "cl_eva02_tagger_v1_250707": "cl_eva02_tagger_v1_250707/model.onnx",
    # "cl_eva02_tagger_v1_250708": "cl_eva02_tagger_v1_250708/model.onnx",
    # "cl_eva02_tagger_v1_250717": "cl_eva02_tagger_v1_250717/model.onnx",
    # "cl_eva02_tagger_v1_250719": "cl_eva02_tagger_v1_250719/model.onnx",
    # "cl_eva02_tagger_v1_250720": "cl_eva02_tagger_v1_250720/model.onnx",
    # "cl_eva02_tagger_v1_250725": "cl_eva02_tagger_v1_250725/model.onnx",
    # "cl_eva02_tagger_v1_250729": "cl_eva02_tagger_v1_250729/model.onnx",
    # "cl_eva02_tagger_v1_250801": "cl_eva02_tagger_v1_250801/model.onnx",
    "cl_eva02_tagger_v1_250804": "cl_eva02_tagger_v1_250804/model.onnx",
    "cl_eva02_tagger_v1_250807": "cl_eva02_tagger_v1_250807/model.onnx"
}
DEFAULT_MODEL = "cl_eva02_tagger_v1_250807"
CACHE_DIR = "./model_cache"

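# DEFAULT_MODEL must remain one of the *uncommented* keys above, since
# initialize_onnx_paths() falls back to it and indexes MODEL_OPTIONS directly;
# the commented entries appear to be earlier checkpoints kept for reference.
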
# --- Global variables for paths (initialized at startup) ---
g_onnx_model_path = None
g_tag_mapping_path = None
g_labels_data = None
g_idx_to_tag = None
g_tag_to_category = None
g_current_model = None

# --- Initialization Function ---
def initialize_onnx_paths(model_choice=DEFAULT_MODEL):
    global g_onnx_model_path, g_tag_mapping_path, g_labels_data, g_idx_to_tag, g_tag_to_category, g_current_model

    if model_choice not in MODEL_OPTIONS:
        print(f"Invalid model choice: {model_choice}, falling back to default: {DEFAULT_MODEL}")
        model_choice = DEFAULT_MODEL

    g_current_model = model_choice
    model_dir = model_choice
    onnx_filename = MODEL_OPTIONS[model_choice]
    tag_mapping_filename = f"{model_dir}/tag_mapping.json"

    print(f"Initializing ONNX paths and labels for model: {model_choice}...")
    hf_token = os.environ.get("HF_TOKEN")
    try:
        print(f"Attempting to download ONNX model: {onnx_filename}")
        g_onnx_model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
        print(f"ONNX model path: {g_onnx_model_path}")

        print(f"Attempting to download Tag mapping: {tag_mapping_filename}")
        g_tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=tag_mapping_filename, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
        print(f"Tag mapping path: {g_tag_mapping_path}")

        print("Loading labels from mapping...")
        g_labels_data, g_idx_to_tag, g_tag_to_category = load_tag_mapping(g_tag_mapping_path)
        print(f"Labels loaded. Count: {len(g_labels_data.names)}")

        return True

    except Exception as e:
        print(f"Error during initialization: {e}")
        import traceback; traceback.print_exc()
        # Reset globals to force reinitialization
        g_onnx_model_path = None
        g_tag_mapping_path = None
        g_labels_data = None
        g_idx_to_tag = None
        g_tag_to_category = None
        g_current_model = None
        # Raise a Gradio error to make it visible in the UI
        raise gr.Error(f"Initialization failed: {e}. Check logs and HF_TOKEN.")

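# Note: hf_hub_download() with force_download=False reuses files already
# present in CACHE_DIR, so switching back to a previously selected model does
# not trigger a re-download.
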
# Function to handle model change
def change_model(model_choice):
    try:
        success = initialize_onnx_paths(model_choice)
        if success:
            return f"Model changed to: {model_choice}"
        else:
            return "Failed to change model. See logs for details."
    except Exception as e:
        return f"Error changing model: {str(e)}"

# --- Main Prediction Function (ONNX) ---
@spaces.GPU()
def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, output_mode):
    print(f"--- predict_onnx function started (GPU worker) with model {model_choice} ---")

    # Ensure current model matches selected model
    global g_current_model
    if g_current_model != model_choice:
        print(f"Model mismatch! Current: {g_current_model}, Selected: {model_choice}. Reinitializing...")
        try:
            initialize_onnx_paths(model_choice)
        except Exception as e:
            return f"Error initializing model '{model_choice}': {str(e)}", None

    # --- 1. Ensure paths and labels are loaded ---
    if g_onnx_model_path is None or g_labels_data is None:
        message = "Error: Paths or labels not initialized. Check startup logs."
        print(message)
        # Return error message and None for the image output
        return message, None

    # --- 2. Load ONNX Session (inside worker) ---
    session = None
    try:
        print(f"Loading ONNX session from: {g_onnx_model_path}")
        available_providers = ort.get_available_providers()
        providers = []
        if 'CUDAExecutionProvider' in available_providers:
            providers.append('CUDAExecutionProvider')
        providers.append('CPUExecutionProvider')
        print(f"Attempting to load session with providers: {providers}")
        session = ort.InferenceSession(g_onnx_model_path, providers=providers)
        print(f"ONNX session loaded using: {session.get_providers()[0]}")
    except Exception as e:
        message = f"Error loading ONNX session in worker: {e}"
        print(message)
        import traceback; traceback.print_exc()
        return message, None

    # --- 3. Process Input Image ---
    if image_input is None:
        return "Please upload an image.", None

    print(f"Processing image with thresholds: gen={gen_threshold}, char={char_threshold}")
    try:
        # Handle different input types (PIL, numpy, URL, file path)
        if isinstance(image_input, str):
            if image_input.startswith("http"):  # URL
                response = requests.get(image_input, timeout=10)
                response.raise_for_status()
                image = Image.open(io.BytesIO(response.content))
            elif os.path.exists(image_input):  # File path
                image = Image.open(image_input)
            else:
                raise ValueError(f"Invalid image input string: {image_input}")
        elif isinstance(image_input, np.ndarray):
            image = Image.fromarray(image_input)
        elif isinstance(image_input, Image.Image):
            image = image_input  # Already a PIL image
        else:
            raise TypeError(f"Unsupported image input type: {type(image_input)}")

        # Preprocess the PIL image
        original_pil_image, input_tensor = preprocess_image(image)

        # Ensure input tensor is float32, as expected by most ONNX models
        # (even if the model internally uses float16)
        input_tensor = input_tensor.astype(np.float32)

    except Exception as e:
        message = f"Error processing input image: {e}"
        print(message)
        return message, None

    # --- 4. Run Inference ---
    try:
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name
        print(f"Running inference with input '{input_name}', output '{output_name}'")
        start_time = time.time()
        outputs = session.run([output_name], {input_name: input_tensor})[0]
        inference_time = time.time() - start_time
        print(f"Inference completed in {inference_time:.3f} seconds")

        # Check for NaN/Inf in outputs
        if np.isnan(outputs).any() or np.isinf(outputs).any():
            print("Warning: NaN or Inf detected in model output. Clamping...")
            outputs = np.nan_to_num(outputs, nan=0.0, posinf=1.0, neginf=0.0)  # Clamp to 0-1 range

        # Apply sigmoid (outputs are likely logits)
        # Use a numerically stable sigmoid implementation
        def stable_sigmoid(x):
            return 1 / (1 + np.exp(-np.clip(x, -30, 30)))  # Clip to avoid overflow
        probs = stable_sigmoid(outputs[0])  # Assuming batch size 1

    except Exception as e:
        message = f"Error during ONNX inference: {e}"
        print(message)
        import traceback; traceback.print_exc()
        return message, None
    finally:
        # Clean up the session (may reduce memory usage between clicks)
        del session

    # --- 5. Post-process and Format Output ---
    try:
        print("Post-processing results...")
        # Use the correct global variable for labels
        predictions = get_tags(probs, g_labels_data, gen_threshold, char_threshold)

        # Format output text string
        output_tags = []
        if predictions.get("rating"): output_tags.append(predictions["rating"][0][0].replace("_", " "))
        if predictions.get("quality"): output_tags.append(predictions["quality"][0][0].replace("_", " "))
        # Add other categories, respecting order and filtering meta if needed
        for category in ["artist", "character", "copyright", "general", "meta", "model"]:
            tags_in_category = predictions.get(category, [])
            for tag, prob in tags_in_category:
                # Basic meta tag filtering for text output
                if category == "meta" and any(p in tag.lower() for p in ['id', 'commentary', 'request', 'mismatch']):
                    continue
                output_tags.append(tag.replace("_", " "))
        output_text = ", ".join(output_tags)

        # Generate visualization if requested
        viz_image = None
        if output_mode == "Tags + Visualization":
            print("Generating visualization...")
            # Pass the correct threshold for the display title (can pass both if needed)
            # For simplicity, gen_threshold is passed as a representative value
            viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold)
            print("Visualization generated.")
        else:
            print("Visualization skipped.")

        print("Prediction complete.")
        return output_text, viz_image

    except Exception as e:
        message = f"Error during post-processing: {e}"
        print(message)
        import traceback; traceback.print_exc()
        return message, None

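# Design note: the ONNX session is created and torn down inside predict_onnx
# on every call; presumably because @spaces.GPU runs the function in a
# short-lived GPU worker (see the "inside worker" comments above), a
# long-lived global session is not assumed to survive between clicks.
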
# --- Gradio Interface Definition (Full ONNX Version) ---
css = """
.gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
footer { display: none !important; }
.gr-prose { max-width: 100% !important; }
"""
# js = """ /* Keep existing JS */ """  # No JS needed currently

with gr.Blocks(css=css) as demo:
    gr.Markdown("# CL EVA02 ONNX Tagger")
    gr.Markdown("Upload an image or paste an image URL to predict tags using the CL EVA02 Tagger model (ONNX), fine-tuned from [SmilingWolf/wd-eva02-large-tagger-v3](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3).")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Input Image", elem_id="input-image")
            model_choice = gr.Dropdown(
                choices=list(MODEL_OPTIONS.keys()),
                value=DEFAULT_MODEL,
                label="Model Version",
                interactive=True
            )
            gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General/Meta/Model Tag Threshold")
            char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
            output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
            predict_button = gr.Button("Predict", variant="primary")
        with gr.Column(scale=1):
            output_tags = gr.Textbox(label="Predicted Tags", lines=10, interactive=False)
            output_visualization = gr.Image(type="pil", label="Prediction Visualization", interactive=False)

    # Handle model change
    model_status = gr.Textbox(label="Model Status", interactive=False, visible=False)
    model_choice.change(
        fn=change_model,
        inputs=[model_choice],
        outputs=[model_status]
    )

    gr.Examples(
        examples=[
            ["https://pbs.twimg.com/media/GXBXsRvbQAAg1kp.jpg", DEFAULT_MODEL, 0.55, 0.55, "Tags + Visualization"],
            ["https://pbs.twimg.com/media/GjlX0gibcAA4EJ4.jpg", DEFAULT_MODEL, 0.55, 0.55, "Tags Only"],
            ["https://pbs.twimg.com/media/Gj4nQbjbEAATeoH.jpg", DEFAULT_MODEL, 0.55, 0.55, "Tags + Visualization"],
            ["https://pbs.twimg.com/media/GkbtX0GaoAMlUZt.jpg", DEFAULT_MODEL, 0.55, 0.55, "Tags + Visualization"]
        ],
        inputs=[image_input, model_choice, gen_threshold, char_threshold, output_mode],
        outputs=[output_tags, output_visualization],
        fn=predict_onnx,  # Use the ONNX prediction function
        cache_examples=False  # Disable caching for examples during testing
    )
    predict_button.click(
        fn=predict_onnx,  # Use the ONNX prediction function
        inputs=[image_input, model_choice, gen_threshold, char_threshold, output_mode],
        outputs=[output_tags, output_visualization]
    )

# --- Main Block ---
if __name__ == "__main__":
    if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
    # Initialize paths and labels at startup (with default model)
    initialize_onnx_paths(DEFAULT_MODEL)
    # Launch Gradio app
-    demo.launch(share=True)
+    demo.launch()