cella110n committed
Commit 2382b82 · verified · 1 Parent(s): 3eb9532

Upload 2 files

Files changed (2)
  1. app.py +597 -599
  2. requirements.txt +10 -9
app.py CHANGED
@@ -1,600 +1,598 @@
import gradio as gr
import spaces
import onnxruntime as ort
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import json
import os
import io
import requests
import matplotlib.pyplot as plt
import matplotlib
from huggingface_hub import hf_hub_download
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple

# Set the Matplotlib backend to Agg (for environments without a GUI)
matplotlib.use('Agg')

# --- Ported from onnx_predict.py ---

@dataclass
class LabelData:
    names: list[str]
    rating: list[np.int64]
    general: list[np.int64]
    artist: list[np.int64]
    character: list[np.int64]
    copyright: list[np.int64]
    meta: list[np.int64]
    quality: list[np.int64]

def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    if image.mode == "RGBA":
        background = Image.new("RGB", image.size, (255, 255, 255))
        background.paste(image, mask=image.split()[3])
        image = background
    return image

def pil_pad_square(image: Image.Image) -> Image.Image:
    width, height = image.size
    if width == height:
        return image
    new_size = max(width, height)
    new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
    paste_position = ((new_size - width) // 2, (new_size - height) // 2)
    new_image.paste(image, paste_position)
    return new_image

def load_tag_mapping(mapping_path):
    with open(mapping_path, 'r', encoding='utf-8') as f:
        tag_mapping_data = json.load(f)

    # Support both the old and the new mapping format
    if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
        # Old format (a dict containing idx_to_tag and tag_to_category)
        idx_to_tag_dict = tag_mapping_data["idx_to_tag"]
        tag_to_category_dict = tag_mapping_data["tag_to_category"]
        # The keys may be strings, so convert them to int
        idx_to_tag = {int(k): v for k, v in idx_to_tag_dict.items()}
        tag_to_category = tag_to_category_dict
    elif isinstance(tag_mapping_data, dict):
        # New format (a dict keyed by index)
        tag_mapping_data = {int(k): v for k, v in tag_mapping_data.items()}
        idx_to_tag = {}
        tag_to_category = {}
        for idx, data in tag_mapping_data.items():
            tag = data['tag']
            category = data['category']
            idx_to_tag[idx] = tag
            tag_to_category[tag] = category
    else:
        raise ValueError("Unsupported tag mapping format")


    names = [None] * (max(idx_to_tag.keys()) + 1)
    rating = []
    general = []
    artist = []
    character = []
    copyright = []
    meta = []
    quality = []

    for idx, tag in idx_to_tag.items():
        if idx >= len(names): # Grow the names list when it is too small
            names.extend([None] * (idx - len(names) + 1))
        names[idx] = tag
        category = tag_to_category.get(tag, 'Unknown') # Fallback when the category is not found

        if category == 'Rating':
            rating.append(idx)
        elif category == 'General':
            general.append(idx)
        elif category == 'Artist':
            artist.append(idx)
        elif category == 'Character':
            character.append(idx)
        elif category == 'Copyright':
            copyright.append(idx)
        elif category == 'Meta':
            meta.append(idx)
        elif category == 'Quality':
            quality.append(idx)
        # The Unknown category is ignored

    label_data = LabelData(
        names=names,
        rating=np.array(rating, dtype=np.int64),
        general=np.array(general, dtype=np.int64),
        artist=np.array(artist, dtype=np.int64),
        character=np.array(character, dtype=np.int64),
        copyright=np.array(copyright, dtype=np.int64),
        meta=np.array(meta, dtype=np.int64),
        quality=np.array(quality, dtype=np.int64)
    )

    return label_data, idx_to_tag, tag_to_category


def preprocess_image(image: Image.Image, target_size=(448, 448)):
    image = pil_ensure_rgb(image)
    image = pil_pad_square(image)
    image_resized = image.resize(target_size, Image.BICUBIC)
    img_array = np.array(image_resized, dtype=np.float32) / 255.0
    img_array = img_array.transpose(2, 0, 1) # HWC -> CHW
    # RGB -> BGR (for models that expect BGR input - WD Tagger v3 is BGR)
    # Note that WD Tagger V2/V1 expect RGB
    img_array = img_array[::-1, :, :]
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
    img_array = (img_array - mean) / std
    img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
    return image, img_array # Return original PIL image and processed numpy array

def get_tags(probs, labels: LabelData, gen_threshold, char_threshold):
    result = {
        "rating": [], "general": [], "character": [],
        "copyright": [], "artist": [], "meta": [], "quality": []
    }

    # Rating (select the max)
    if labels.rating.size > 0:
        rating_probs = probs[labels.rating]
        if rating_probs.size > 0:
            rating_idx = np.argmax(rating_probs)
            # Check if the index is valid for names list
            if labels.rating[rating_idx] < len(labels.names):
                rating_name = labels.names[labels.rating[rating_idx]]
                rating_conf = float(rating_probs[rating_idx])
                result["rating"].append((rating_name, rating_conf))
            else:
                print(f"Warning: Rating index {labels.rating[rating_idx]} out of bounds for names list (size {len(labels.names)}).")


    # Quality (select the max)
    if labels.quality.size > 0:
        quality_probs = probs[labels.quality]
        if quality_probs.size > 0:
            quality_idx = np.argmax(quality_probs)
            if labels.quality[quality_idx] < len(labels.names):
                quality_name = labels.names[labels.quality[quality_idx]]
                quality_conf = float(quality_probs[quality_idx])
                result["quality"].append((quality_name, quality_conf))
            else:
                print(f"Warning: Quality index {labels.quality[quality_idx]} out of bounds for names list (size {len(labels.names)}).")


    category_map = {
        "general": (labels.general, gen_threshold),
        "character": (labels.character, char_threshold),
        "copyright": (labels.copyright, char_threshold),
        "artist": (labels.artist, char_threshold),
        "meta": (labels.meta, gen_threshold)
    }

    for category, (indices, threshold) in category_map.items():
        if indices.size > 0:
            # Filter indices to be within the bounds of probs and labels.names
            valid_indices = indices[(indices < len(probs)) & (indices < len(labels.names))]
            if valid_indices.size > 0:
                category_probs = probs[valid_indices]
                mask = category_probs >= threshold
                selected_indices = valid_indices[mask]
                selected_probs = category_probs[mask]
                for idx, prob in zip(selected_indices, selected_probs):
                    result[category].append((labels.names[idx], float(prob)))


    # Sort by probability
    for k in result:
        result[k] = sorted(result[k], key=lambda x: x[1], reverse=True)

    return result

def visualize_predictions(image: Image.Image, predictions, threshold=0.45):
    # Filter out unwanted meta tags
    filtered_meta = []
    excluded_meta_patterns = ['id', 'commentary', 'request', 'mismatch']
    for tag, prob in predictions["meta"]:
        if not any(pattern in tag.lower() for pattern in excluded_meta_patterns):
            filtered_meta.append((tag, prob))
    predictions["meta"] = filtered_meta # Replace with filtered

    # Create plot
    fig = plt.figure(figsize=(20, 12), dpi=100)
    gs = fig.add_gridspec(1, 2, width_ratios=[1.2, 1])
    ax_img = fig.add_subplot(gs[0, 0])
    ax_img.imshow(image)
    ax_img.set_title("Original Image")
    ax_img.axis('off')
    ax_tags = fig.add_subplot(gs[0, 1])

    all_tags = []
    all_probs = []
    all_colors = []
    color_map = {'rating': 'red', 'character': 'blue', 'copyright': 'purple',
                 'artist': 'orange', 'general': 'green', 'meta': 'gray', 'quality': 'yellow'}

    for cat, prefix, color in [('rating', 'R', 'red'), ('character', 'C', 'blue'),
                               ('copyright', '©', 'purple'), ('artist', 'A', 'orange'),
                               ('general', 'G', 'green'), ('meta', 'M', 'gray'), ('quality', 'Q', 'yellow')]:
        for tag, prob in predictions[cat]:
            all_tags.append(f"[{prefix}] {tag}")
            all_probs.append(prob)
            all_colors.append(color)

    if not all_tags:
        ax_tags.text(0.5, 0.5, "No tags found above threshold", ha='center', va='center')
        ax_tags.set_title(f"Tags (threshold={threshold})")
        ax_tags.axis('off')
        plt.tight_layout()
        # Save figure to a BytesIO object
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100)
        plt.close(fig)
        buf.seek(0)
        return Image.open(buf)


    sorted_indices = sorted(range(len(all_probs)), key=lambda i: all_probs[i], reverse=True)
    all_tags = [all_tags[i] for i in sorted_indices]
    all_probs = [all_probs[i] for i in sorted_indices]
    all_colors = [all_colors[i] for i in sorted_indices]

    all_tags.reverse()
    all_probs.reverse()
    all_colors.reverse()

    num_tags = len(all_tags)
    bar_height = 0.8
    if num_tags > 30: bar_height = 0.8 * (30 / num_tags)
    y_positions = np.arange(num_tags)

    bars = ax_tags.barh(y_positions, all_probs, height=bar_height, color=all_colors)
    ax_tags.set_yticks(y_positions)
    ax_tags.set_yticklabels(all_tags)

    fontsize = 10
    if num_tags > 60: fontsize = 6 # check the larger count first, otherwise this branch is unreachable
    elif num_tags > 40: fontsize = 8
    for label in ax_tags.get_yticklabels(): label.set_fontsize(fontsize)

    for i, (bar, prob) in enumerate(zip(bars, all_probs)):
        ax_tags.text(min(prob + 0.02, 0.98), y_positions[i], f"{prob:.3f}",
                     va='center', fontsize=fontsize)

    ax_tags.set_xlim(0, 1)
    ax_tags.set_title(f"Tags (threshold={threshold})")

    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor=color, label=cat.capitalize()) for cat, color in color_map.items()]
    ax_tags.legend(handles=legend_elements, loc='lower right', fontsize=8)

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.05)

    # Save figure to a BytesIO object
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

# --- Gradio App Logic ---

# Constants
REPO_ID = "cella110n/cl_tagger"
# MODEL_FILENAME = "cl_eva02_tagger_v1_250426/model_optimized.onnx"
MODEL_FILENAME = "cl_eva02_tagger_v1_250426/model.onnx" # Use non-optimized if needed
TAG_MAPPING_FILENAME = "cl_eva02_tagger_v1_250426/tag_mapping.json"
CACHE_DIR = "./model_cache"

# Global variables (cache the model and labels)
onnx_session = None
labels_data = None
tag_to_category_map = None

def download_model_files():
    """Download the model and tag mapping from the Hugging Face Hub."""
    print("Downloading model files...")
    # Read the HF token from the environment (for private repositories)
    hf_token = os.environ.get("HF_TOKEN")
    try:
        model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, cache_dir=CACHE_DIR, token=hf_token)
        tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token)
        print(f"Model downloaded to: {model_path}")
        print(f"Tag mapping downloaded to: {tag_mapping_path}")
        return model_path, tag_mapping_path
    except Exception as e:
        print(f"Error downloading files: {e}")
        # Give a clearer error message when the token is missing
        if "401 Client Error" in str(e) or "Repository not found" in str(e):
            raise gr.Error(f"Could not download files from {REPO_ID}. "
                           f"If this is a private repository, make sure to set the HF_TOKEN secret in your Space settings.")
        else:
            raise gr.Error(f"Error downloading files: {e}")


def initialize_model():
    """Initialize (and cache) the model and label data."""
    global onnx_session, labels_data, tag_to_category_map
    if onnx_session is None:
        model_path, tag_mapping_path = download_model_files()
        print("Loading model and labels...")

        # --- Added Logging ---
        print("--- Environment Check ---")
        try:
            import torch
            print(f"PyTorch version: {torch.__version__}")
            if torch.cuda.is_available():
                print(f"PyTorch CUDA available: True")
                print(f"PyTorch CUDA version: {torch.version.cuda}")
                print(f"Detected GPU: {torch.cuda.get_device_name(0)}")
                if torch.backends.cudnn.is_available():
                    print(f"PyTorch cuDNN available: True")
                    print(f"PyTorch cuDNN version: {torch.backends.cudnn.version()}")
                else:
                    print("PyTorch cuDNN available: False")
            else:
                print("PyTorch CUDA available: False")
        except ImportError:
            print("PyTorch not found.")
        except Exception as e:
            print(f"Error during PyTorch check: {e}")

        try:
            print(f"ONNX Runtime build info: {ort.get_buildinfo()}")
        except Exception as e:
            print(f"Error getting ONNX Runtime build info: {e}")
        print("-------------------------")
        # --- End Added Logging ---

        # Initialize the ONNX session (prefer GPU)
        available_providers = ort.get_available_providers()
        print(f"Available ONNX Runtime providers: {available_providers}")
        providers = []
        if 'CUDAExecutionProvider' in available_providers:
            providers.append('CUDAExecutionProvider')
        # elif 'DmlExecutionProvider' in available_providers: # DirectML (Windows)
        #     providers.append('DmlExecutionProvider')
        providers.append('CPUExecutionProvider') # Always include CPU as fallback

        try:
            onnx_session = ort.InferenceSession(model_path, providers=providers)
            print(f"Using ONNX Runtime provider: {onnx_session.get_providers()[0]}")
        except Exception as e:
            print(f"Error initializing ONNX session with providers {providers}: {e}")
            print("Falling back to CPUExecutionProvider only.")
            onnx_session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])

        labels_data, _, tag_to_category_map = load_tag_mapping(tag_mapping_path)
        print("Model and labels loaded.")

@spaces.GPU()
def predict(image_input, gen_threshold, char_threshold, output_mode):
    """Prediction function for the Gradio interface."""
    initialize_model() # Load the model if it has not been loaded yet

    if image_input is None:
        return "Please upload an image.", None

    print(f"Processing image with thresholds: gen={gen_threshold}, char={char_threshold}")

    # Make sure we have a PIL Image object
    if not isinstance(image_input, Image.Image):
        try:
            # URL input
            if isinstance(image_input, str) and image_input.startswith("http"):
                response = requests.get(image_input)
                response.raise_for_status()
                image = Image.open(io.BytesIO(response.content))
            # File path (unusual with Gradio, but handled just in case)
            elif isinstance(image_input, str) and os.path.exists(image_input):
                image = Image.open(image_input)
            # Numpy array (input from the Gradio Image component)
            elif isinstance(image_input, np.ndarray):
                image = Image.fromarray(image_input)
            else:
                raise ValueError("Unsupported image input type")
        except Exception as e:
            print(f"Error loading image: {e}")
            return f"Error loading image: {e}", None
    else:
        image = image_input


    # Preprocessing
    original_pil_image, input_data = preprocess_image(image)

    # Match the input dtype to what the model expects (usually float32)
    input_name = onnx_session.get_inputs()[0].name
    expected_type = onnx_session.get_inputs()[0].type
    if expected_type == 'tensor(float16)':
        input_data = input_data.astype(np.float16)
    else:
        input_data = input_data.astype(np.float32) # Default to float32

    # Inference
    start_time = time.time()
    outputs = onnx_session.run(None, {input_name: input_data})[0]
    inference_time = time.time() - start_time
    print(f"Inference completed in {inference_time:.3f} seconds")

    # Convert logits to probabilities with the sigmoid function
    probs = 1 / (1 + np.exp(-outputs[0])) # Apply sigmoid to the first batch item

    # Extract tags
    predictions = get_tags(probs, labels_data, gen_threshold, char_threshold)

    # Format the tags
    output_tags = []
    # Add Rating and Quality first
    if predictions["rating"]:
        output_tags.append(predictions["rating"][0][0].replace("_", " "))
    if predictions["quality"]:
        output_tags.append(predictions["quality"][0][0].replace("_", " "))

    # Add the remaining categories in alphabetical order (optional)
    for category in ["artist", "character", "copyright", "general", "meta"]:
        tags = [tag.replace("_", " ") for tag, prob in predictions[category]
                if not (category == "meta" and any(p in tag.lower() for p in ['id', 'commentary', 'mismatch']))] # Meta tag filtering
        output_tags.extend(tags)

    output_text = ", ".join(output_tags)

    if output_mode == "Tags Only":
-        # Return the text and an update to make the visualization invisible
-        return output_text, gr.update(value=None, visible=False)
+        return output_text, None
    else: # Visualization
        viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold)
-        # Return the text and an update to show the visualization
-        return output_text, gr.update(value=viz_image, visible=True)
+        return output_text, viz_image

# --- Gradio Interface Definition ---
import time

# CSS for styling
css = """
.gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
footer { display: none !important; }
.gr-prose { max-width: 100% !important; }
"""
# Custom JS for image pasting and URL handling
js = """
async function paste_image(blob, gen_thresh, char_thresh, out_mode) {
    const data = await fetch(blob)
    const image_data = await data.blob()
    const file = new File([image_data], "pasted_image.png", { type: image_data.type })
    const dt = new DataTransfer()
    dt.items.add(file)
    const element = document.querySelector('#input-image input[type="file"]')
    element.files = dt.files
    // Trigger the change event manually
    const event = new Event('change', { bubbles: true })
    element.dispatchEvent(event)
    // Wait a bit for Gradio to process the change, then trigger predict if needed
    // await new Promise(resolve => setTimeout(resolve, 100)); // Optional delay
    // You might need to manually trigger the prediction or rely on Gradio's auto-triggering
    return [file, gen_thresh, char_thresh, out_mode]; // Return input for Gradio function
}

async function paste_update(evt){
    if (!evt.clipboardData || !evt.clipboardData.items) return;
    var url = evt.clipboardData.getData('text');
    if (url) {
        // Basic check for image URL (you might want a more robust check)
        if (/\.(jpg|jpeg|png|webp|bmp)$/i.test(url)) {
            // Create a button or link to load the URL
            const url_container = document.getElementById('url-input-container');
            url_container.innerHTML = `<p>Detected URL: <button id="load-url-btn" class="gr-button gr-button-sm gr-button-secondary">${url}</button></p>`;

            document.getElementById('load-url-btn').onclick = async () => {
                // Simulate file upload from URL - Gradio's Image component handles URLs directly
                const element = document.querySelector('#input-image input[type="file"]');
                // Can't directly set URL to file input, so we pass it to Gradio fn
                // Or maybe update the image display src directly if possible?

                // Let Gradio handle the URL - user needs to click predict
                // We can pre-fill the image component if Gradio supports it via JS,
                // but it's simpler to just let the user click predict after pasting URL.
                alert("URL detected. Please ensure the image input is cleared and then press 'Predict' or re-upload the image.");
                // Clear current image preview if possible?

                // A workaround: display the URL and let the user manually trigger prediction
                // Or, try to use Gradio's JS API if available to update the Image component value
                // For now, just inform the user.
            };
            return; // Don't process as image paste if URL is found
        }
    }

    var items = evt.clipboardData.items;
    for (var i = 0; i < items.length; i++) {
        if (items[i].type.indexOf("image") === 0) {
            var blob = items[i].getAsFile();
            var reader = new FileReader();
            reader.onload = function(event){
                // Update the Gradio Image component source directly
                const imgElement = document.querySelector('#input-image img'); // Find the img tag inside the component
                if (imgElement) {
                    imgElement.src = event.target.result;
                    // We still need to pass the blob to the Gradio function
                    // Use Gradio's JS API or hidden components if possible
                    // For now, let's use a simple alert and rely on manual trigger
                    alert("Image pasted. The preview should update. Please press 'Predict'.");
                    // Trigger paste_image function - requires Gradio JS interaction
                    // This part is tricky without official Gradio JS API for updates
                }
            };
            reader.readAsDataURL(blob);
            // Prevent default paste handling
            evt.preventDefault();
            break;
        }
    }
}

document.addEventListener('paste', paste_update);
"""

with gr.Blocks(css=css, js=js) as demo:
    gr.Markdown("# WD EVA02 LoRA ONNX Tagger")
    gr.Markdown("Upload an image or paste an image URL to predict tags using the fine-tuned WD EVA02 Tagger model (ONNX format).")
    gr.Markdown(f"Model Repository: [{REPO_ID}](https://huggingface.co/{REPO_ID})")

    with gr.Row():
        with gr.Column(scale=1):
            # Use elem_id for JS targeting
            image_input = gr.Image(type="pil", label="Input Image", elem_id="input-image")
            # Container for URL paste message
            gr.HTML("<div id='url-input-container'></div>")

            gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General Tag Threshold")
            char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
            output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
            predict_button = gr.Button("Predict", variant="primary")

        with gr.Column(scale=1):
            output_tags = gr.Textbox(label="Predicted Tags", lines=10)
            output_visualization = gr.Image(type="pil", label="Prediction Visualization")

    # Examples
    gr.Examples(
        examples=[
            ["https://pbs.twimg.com/media/GXBXsRvbQAAg1kp.jpg", 0.55, 0.5, "Tags + Visualization"],
            ["https://pbs.twimg.com/media/GjlX0gibcAA4EJ4.jpg", 0.5, 0.5, "Tags Only"],
-            ["https://pbs.twimg.com/media/Gj4nQbjbEAATeoH.jpgg", 0.55, 0.5, "Tags + Visualization"],
+            ["https://pbs.twimg.com/media/Gj4nQbjbEAATeoH.jpg", 0.55, 0.5, "Tags + Visualization"],
            ["https://pbs.twimg.com/media/GkbtX0GaoAMlUZt.jpg", 0.45, 0.45, "Tags + Visualization"]
        ],
        inputs=[image_input, gen_threshold, char_threshold, output_mode],
        outputs=[output_tags, output_visualization],
        fn=predict,
        cache_examples=False # Slows down startup if True and large examples
    )

    predict_button.click(
        fn=predict,
        inputs=[image_input, gen_threshold, char_threshold, output_mode],
        outputs=[output_tags, output_visualization]
    )

    # Add listener for image input changes (e.g., from pasting)
    # This might trigger prediction automatically or require the button click
    # image_input.change(
    #     fn=predict,
    #     inputs=[image_input, gen_threshold, char_threshold, output_mode],
    #     outputs=[output_tags, output_visualization]
    # )


if __name__ == "__main__":
    # Warn when the HF_TOKEN environment variable is not set (needed for private repositories)
    if not os.environ.get("HF_TOKEN"):
        print("Warning: HF_TOKEN environment variable not set. Downloads from private repositories may fail.")
    # Initialize model on startup to avoid delay on first prediction
    # initialize_model() # Removed startup initialization
    demo.launch()
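
For reference, the tagging path above (download, preprocess, ONNX run, sigmoid, threshold) can be exercised without the Gradio UI. The following is a minimal sketch, not part of the commit: it assumes the helpers above are importable from app.py and that the repository files are accessible, and "input.jpg" is a hypothetical local file.

# Illustrative only (not part of this commit): headless use of the pipeline above.
import numpy as np
import onnxruntime as ort
from PIL import Image
from huggingface_hub import hf_hub_download
from app import preprocess_image, load_tag_mapping, get_tags

# Same repo and filenames as the constants in app.py
model_path = hf_hub_download("cella110n/cl_tagger", "cl_eva02_tagger_v1_250426/model.onnx")
mapping_path = hf_hub_download("cella110n/cl_tagger", "cl_eva02_tagger_v1_250426/tag_mapping.json")

labels, _, _ = load_tag_mapping(mapping_path)
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

_, batch = preprocess_image(Image.open("input.jpg"))  # hypothetical local file
input_name = session.get_inputs()[0].name
logits = session.run(None, {input_name: batch.astype(np.float32)})[0]
probs = 1 / (1 + np.exp(-logits[0]))  # sigmoid, as in predict()
print(get_tags(probs, labels, gen_threshold=0.55, char_threshold=0.60))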
requirements.txt CHANGED
@@ -1,10 +1,11 @@
--extra-index-url https://download.pytorch.org/whl/cu118
torch
torchvision
torchaudio
- onnxruntime-gpu==1.17 # or onnxruntime>=1.16.0 (CPU-only)
+ onnxruntime-gpu==1.18.0 # or onnxruntime>=1.16.0 (CPU-only)
numpy
Pillow
matplotlib
requests
+ gradio>=4.44.0
huggingface_hub
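
A quick sanity check for the onnxruntime-gpu==1.18.0 pin, using the same API that app.py logs at startup; this is illustrative, not part of the commit:

# Illustrative check (not part of this commit): list the available providers.
import onnxruntime as ort
print(ort.get_available_providers())
# On GPU hardware, 'CUDAExecutionProvider' should appear before 'CPUExecutionProvider';
# otherwise app.py falls back to CPU-only inference.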