cella110n commited on
Commit
1d5f04c
·
verified ·
1 Parent(s): d4905b8

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +31 -13
  2. app.py +567 -0
  3. requirements.txt +7 -0
README.md CHANGED
@@ -1,13 +1,31 @@
1
- ---
2
- title: Cl Tagger
3
- emoji: 🏢
4
- colorFrom: purple
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 5.27.0
8
- app_file: app.py
9
- pinned: false
10
- short_description: tagger finetuned from wdtagger
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: WD EVA02 LoRA ONNX Tagger
3
+ emoji: 🖼️
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.28.3 # requirements.txt と合わせるか確認
8
+ app_file: app.py
9
+ license: apache-2.0 # または適切なライセンス
10
+ # Pinned Hardware: T4 small (GPU) or CPU upgrade (CPU)
11
+ # pinned: false # 必要に応じてTrueに
12
+ # hardware: cpu-upgrade # or cuda-t4-small
13
+ # hf_token: YOUR_HF_TOKEN # Use secrets instead!
14
+ ---
15
+
16
+ # WD EVA02 LoRA ONNX Tagger
17
+
18
+ This Space demonstrates image tagging using a fine-tuned WD EVA02 model (converted to ONNX format).
19
+
20
+ Model Repository: [celstk/wd-eva02-lora-onnx](https://huggingface.co/celstk/wd-eva02-lora-onnx)
21
+
22
+ **How to Use:**
23
+ 1. Upload an image using the upload button.
24
+ 2. Alternatively, paste an image URL into the browser (experimental paste handling).
25
+ 3. Adjust the tag thresholds if needed.
26
+ 4. Choose the output mode (Tags only or include visualization).
27
+ 5. Click the "Predict" button.
28
+
29
+ **Note:**
30
+ - This Space uses a model from a **private** repository (`celstk/wd-eva02-lora-onnx`). You might need to duplicate this space and add your Hugging Face token (`HF_TOKEN`) to the Space secrets to allow downloading the model files.
31
+ - Image pasting behavior might vary across browsers.
app.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import onnxruntime as ort
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import json
6
+ import os
7
+ import io
8
+ import requests
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib
11
+ from huggingface_hub import hf_hub_download
12
+ from dataclasses import dataclass
13
+ from typing import List, Dict, Optional, Tuple
14
+
15
+ # MatplotlibのバックエンドをAggに設定 (GUIなし環境用)
16
+ matplotlib.use('Agg')
17
+
18
+ # --- onnx_predict.pyからの移植 ---
19
+
20
+ @dataclass
21
+ class LabelData:
22
+ names: list[str]
23
+ rating: list[np.int64]
24
+ general: list[np.int64]
25
+ artist: list[np.int64]
26
+ character: list[np.int64]
27
+ copyright: list[np.int64]
28
+ meta: list[np.int64]
29
+ quality: list[np.int64]
30
+
31
+ def pil_ensure_rgb(image: Image.Image) -> Image.Image:
32
+ if image.mode not in ["RGB", "RGBA"]:
33
+ image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
34
+ if image.mode == "RGBA":
35
+ background = Image.new("RGB", image.size, (255, 255, 255))
36
+ background.paste(image, mask=image.split()[3])
37
+ image = background
38
+ return image
39
+
40
+ def pil_pad_square(image: Image.Image) -> Image.Image:
41
+ width, height = image.size
42
+ if width == height:
43
+ return image
44
+ new_size = max(width, height)
45
+ new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
46
+ paste_position = ((new_size - width) // 2, (new_size - height) // 2)
47
+ new_image.paste(image, paste_position)
48
+ return new_image
49
+
50
+ def load_tag_mapping(mapping_path):
51
+ with open(mapping_path, 'r', encoding='utf-8') as f:
52
+ tag_mapping_data = json.load(f)
53
+
54
+ # 新旧フォーマット対応
55
+ if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
56
+ # 旧フォーマット (辞書の中にidx_to_tagとtag_to_categoryがある)
57
+ idx_to_tag_dict = tag_mapping_data["idx_to_tag"]
58
+ tag_to_category_dict = tag_mapping_data["tag_to_category"]
59
+ # tag_mapping_dataが文字列キーになっている可能性があるのでintに変換
60
+ idx_to_tag = {int(k): v for k, v in idx_to_tag_dict.items()}
61
+ tag_to_category = tag_to_category_dict
62
+ elif isinstance(tag_mapping_data, dict):
63
+ # 新フォーマット (キーがインデックスの辞書)
64
+ tag_mapping_data = {int(k): v for k, v in tag_mapping_data.items()}
65
+ idx_to_tag = {}
66
+ tag_to_category = {}
67
+ for idx, data in tag_mapping_data.items():
68
+ tag = data['tag']
69
+ category = data['category']
70
+ idx_to_tag[idx] = tag
71
+ tag_to_category[tag] = category
72
+ else:
73
+ raise ValueError("Unsupported tag mapping format")
74
+
75
+
76
+ names = [None] * (max(idx_to_tag.keys()) + 1)
77
+ rating = []
78
+ general = []
79
+ artist = []
80
+ character = []
81
+ copyright = []
82
+ meta = []
83
+ quality = []
84
+
85
+ for idx, tag in idx_to_tag.items():
86
+ if idx >= len(names): # namesリストのサイズが足りない場合拡張
87
+ names.extend([None] * (idx - len(names) + 1))
88
+ names[idx] = tag
89
+ category = tag_to_category.get(tag, 'Unknown') # カテゴリが見つからない場合
90
+
91
+ if category == 'Rating':
92
+ rating.append(idx)
93
+ elif category == 'General':
94
+ general.append(idx)
95
+ elif category == 'Artist':
96
+ artist.append(idx)
97
+ elif category == 'Character':
98
+ character.append(idx)
99
+ elif category == 'Copyright':
100
+ copyright.append(idx)
101
+ elif category == 'Meta':
102
+ meta.append(idx)
103
+ elif category == 'Quality':
104
+ quality.append(idx)
105
+ # Unknownカテゴリは無視
106
+
107
+ label_data = LabelData(
108
+ names=names,
109
+ rating=np.array(rating, dtype=np.int64),
110
+ general=np.array(general, dtype=np.int64),
111
+ artist=np.array(artist, dtype=np.int64),
112
+ character=np.array(character, dtype=np.int64),
113
+ copyright=np.array(copyright, dtype=np.int64),
114
+ meta=np.array(meta, dtype=np.int64),
115
+ quality=np.array(quality, dtype=np.int64)
116
+ )
117
+
118
+ return label_data, idx_to_tag, tag_to_category
119
+
120
+
121
+ def preprocess_image(image: Image.Image, target_size=(448, 448)):
122
+ image = pil_ensure_rgb(image)
123
+ image = pil_pad_square(image)
124
+ image_resized = image.resize(target_size, Image.BICUBIC)
125
+ img_array = np.array(image_resized, dtype=np.float32) / 255.0
126
+ img_array = img_array.transpose(2, 0, 1) # HWC -> CHW
127
+ # RGB -> BGR (モデルがBGRを期待する場合 - WD Tagger v3はBGR)
128
+ # WD Tagger V2/V1はRGBなので注意
129
+ img_array = img_array[::-1, :, :]
130
+ mean = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
131
+ std = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
132
+ img_array = (img_array - mean) / std
133
+ img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
134
+ return image, img_array # Return original PIL image and processed numpy array
135
+
136
+ def get_tags(probs, labels: LabelData, gen_threshold, char_threshold):
137
+ result = {
138
+ "rating": [], "general": [], "character": [],
139
+ "copyright": [], "artist": [], "meta": [], "quality": []
140
+ }
141
+
142
+ # Rating (select the max)
143
+ if labels.rating.size > 0:
144
+ rating_probs = probs[labels.rating]
145
+ if rating_probs.size > 0:
146
+ rating_idx = np.argmax(rating_probs)
147
+ # Check if the index is valid for names list
148
+ if labels.rating[rating_idx] < len(labels.names):
149
+ rating_name = labels.names[labels.rating[rating_idx]]
150
+ rating_conf = float(rating_probs[rating_idx])
151
+ result["rating"].append((rating_name, rating_conf))
152
+ else:
153
+ print(f"Warning: Rating index {labels.rating[rating_idx]} out of bounds for names list (size {len(labels.names)}).")
154
+
155
+
156
+ # Quality (select the max)
157
+ if labels.quality.size > 0:
158
+ quality_probs = probs[labels.quality]
159
+ if quality_probs.size > 0:
160
+ quality_idx = np.argmax(quality_probs)
161
+ if labels.quality[quality_idx] < len(labels.names):
162
+ quality_name = labels.names[labels.quality[quality_idx]]
163
+ quality_conf = float(quality_probs[quality_idx])
164
+ result["quality"].append((quality_name, quality_conf))
165
+ else:
166
+ print(f"Warning: Quality index {labels.quality[quality_idx]} out of bounds for names list (size {len(labels.names)}).")
167
+
168
+
169
+ category_map = {
170
+ "general": (labels.general, gen_threshold),
171
+ "character": (labels.character, char_threshold),
172
+ "copyright": (labels.copyright, char_threshold),
173
+ "artist": (labels.artist, char_threshold),
174
+ "meta": (labels.meta, gen_threshold)
175
+ }
176
+
177
+ for category, (indices, threshold) in category_map.items():
178
+ if indices.size > 0:
179
+ # Filter indices to be within the bounds of probs and labels.names
180
+ valid_indices = indices[(indices < len(probs)) & (indices < len(labels.names))]
181
+ if valid_indices.size > 0:
182
+ category_probs = probs[valid_indices]
183
+ mask = category_probs >= threshold
184
+ selected_indices = valid_indices[mask]
185
+ selected_probs = category_probs[mask]
186
+ for idx, prob in zip(selected_indices, selected_probs):
187
+ result[category].append((labels.names[idx], float(prob)))
188
+
189
+
190
+ # Sort by probability
191
+ for k in result:
192
+ result[k] = sorted(result[k], key=lambda x: x[1], reverse=True)
193
+
194
+ return result
195
+
196
+ def visualize_predictions(image: Image.Image, predictions, threshold=0.45):
197
+ # Filter out unwanted meta tags
198
+ filtered_meta = []
199
+ excluded_meta_patterns = ['id', 'commentary', 'request', 'mismatch']
200
+ for tag, prob in predictions["meta"]:
201
+ if not any(pattern in tag.lower() for pattern in excluded_meta_patterns):
202
+ filtered_meta.append((tag, prob))
203
+ predictions["meta"] = filtered_meta # Replace with filtered
204
+
205
+ # Create plot
206
+ fig = plt.figure(figsize=(20, 12), dpi=100)
207
+ gs = fig.add_gridspec(1, 2, width_ratios=[1.2, 1])
208
+ ax_img = fig.add_subplot(gs[0, 0])
209
+ ax_img.imshow(image)
210
+ ax_img.set_title("Original Image")
211
+ ax_img.axis('off')
212
+ ax_tags = fig.add_subplot(gs[0, 1])
213
+
214
+ all_tags = []
215
+ all_probs = []
216
+ all_colors = []
217
+ color_map = {'rating': 'red', 'character': 'blue', 'copyright': 'purple',
218
+ 'artist': 'orange', 'general': 'green', 'meta': 'gray', 'quality': 'yellow'}
219
+
220
+ for cat, prefix, color in [('rating', 'R', 'red'), ('character', 'C', 'blue'),
221
+ ('copyright', '©', 'purple'), ('artist', 'A', 'orange'),
222
+ ('general', 'G', 'green'), ('meta', 'M', 'gray'), ('quality', 'Q', 'yellow')]:
223
+ for tag, prob in predictions[cat]:
224
+ all_tags.append(f"[{prefix}] {tag}")
225
+ all_probs.append(prob)
226
+ all_colors.append(color)
227
+
228
+ if not all_tags:
229
+ ax_tags.text(0.5, 0.5, "No tags found above threshold", ha='center', va='center')
230
+ ax_tags.set_title(f"Tags (threshold={threshold})")
231
+ ax_tags.axis('off')
232
+ plt.tight_layout()
233
+ # Save figure to a BytesIO object
234
+ buf = io.BytesIO()
235
+ plt.savefig(buf, format='png', dpi=100)
236
+ plt.close(fig)
237
+ buf.seek(0)
238
+ return Image.open(buf)
239
+
240
+
241
+ sorted_indices = sorted(range(len(all_probs)), key=lambda i: all_probs[i], reverse=True)
242
+ all_tags = [all_tags[i] for i in sorted_indices]
243
+ all_probs = [all_probs[i] for i in sorted_indices]
244
+ all_colors = [all_colors[i] for i in sorted_indices]
245
+
246
+ all_tags.reverse()
247
+ all_probs.reverse()
248
+ all_colors.reverse()
249
+
250
+ num_tags = len(all_tags)
251
+ bar_height = 0.8
252
+ if num_tags > 30: bar_height = 0.8 * (30 / num_tags)
253
+ y_positions = np.arange(num_tags)
254
+
255
+ bars = ax_tags.barh(y_positions, all_probs, height=bar_height, color=all_colors)
256
+ ax_tags.set_yticks(y_positions)
257
+ ax_tags.set_yticklabels(all_tags)
258
+
259
+ fontsize = 10
260
+ if num_tags > 40: fontsize = 8
261
+ elif num_tags > 60: fontsize = 6
262
+ for label in ax_tags.get_yticklabels(): label.set_fontsize(fontsize)
263
+
264
+ for i, (bar, prob) in enumerate(zip(bars, all_probs)):
265
+ ax_tags.text(min(prob + 0.02, 0.98), y_positions[i], f"{prob:.3f}",
266
+ va='center', fontsize=fontsize)
267
+
268
+ ax_tags.set_xlim(0, 1)
269
+ ax_tags.set_title(f"Tags (threshold={threshold})")
270
+
271
+ from matplotlib.patches import Patch
272
+ legend_elements = [Patch(facecolor=color, label=cat.capitalize()) for cat, color in color_map.items()]
273
+ ax_tags.legend(handles=legend_elements, loc='lower right', fontsize=8)
274
+
275
+ plt.tight_layout()
276
+ plt.subplots_adjust(bottom=0.05)
277
+
278
+ # Save figure to a BytesIO object
279
+ buf = io.BytesIO()
280
+ plt.savefig(buf, format='png', dpi=100)
281
+ plt.close(fig)
282
+ buf.seek(0)
283
+ return Image.open(buf)
284
+
285
+ # --- Gradio App Logic ---
286
+
287
+ # 定数
288
+ REPO_ID = "celstk/wd-eva02-lora-onnx"
289
+ MODEL_FILENAME = "cl_eva02_tagger_v1_250426/model_optimized.onnx"
290
+ # MODEL_FILENAME = "cl_eva02_tagger_v1_250426/model.onnx" # Use non-optimized if needed
291
+ TAG_MAPPING_FILENAME = "cl_eva02_tagger_v1_250426/tag_mapping.json"
292
+ CACHE_DIR = "./model_cache"
293
+
294
+ # グローバル変数(モデルとラベルをキャッシュ)
295
+ onnx_session = None
296
+ labels_data = None
297
+ tag_to_category_map = None
298
+
299
+ def download_model_files():
300
+ """Hugging Face Hubからモデルとタグマッピングをダウンロード"""
301
+ print("Downloading model files...")
302
+ # 環境変数からHFトークンを取得 (プライベートリポジトリ用)
303
+ hf_token = os.environ.get("HF_TOKEN")
304
+ try:
305
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, cache_dir=CACHE_DIR, token=hf_token)
306
+ tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token)
307
+ print(f"Model downloaded to: {model_path}")
308
+ print(f"Tag mapping downloaded to: {tag_mapping_path}")
309
+ return model_path, tag_mapping_path
310
+ except Exception as e:
311
+ print(f"Error downloading files: {e}")
312
+ # トークンがない場合のエラーメッセージを改善
313
+ if "401 Client Error" in str(e) or "Repository not found" in str(e):
314
+ raise gr.Error(f"Could not download files from {REPO_ID}. "
315
+ f"If this is a private repository, make sure to set the HF_TOKEN secret in your Space settings.")
316
+ else:
317
+ raise gr.Error(f"Error downloading files: {e}")
318
+
319
+
320
+ def initialize_model():
321
+ """モデルとラベルデータを初期化(キャッシュ)"""
322
+ global onnx_session, labels_data, tag_to_category_map
323
+ if onnx_session is None:
324
+ model_path, tag_mapping_path = download_model_files()
325
+ print("Loading model and labels...")
326
+ # ONNXセッションの初期化 (GPU優先)
327
+ available_providers = ort.get_available_providers()
328
+ print(f"Available ONNX Runtime providers: {available_providers}")
329
+ providers = []
330
+ if 'CUDAExecutionProvider' in available_providers:
331
+ providers.append('CUDAExecutionProvider')
332
+ # elif 'DmlExecutionProvider' in available_providers: # DirectML (Windows)
333
+ # providers.append('DmlExecutionProvider')
334
+ providers.append('CPUExecutionProvider') # Always include CPU as fallback
335
+
336
+ try:
337
+ onnx_session = ort.InferenceSession(model_path, providers=providers)
338
+ print(f"Using ONNX Runtime provider: {onnx_session.get_providers()[0]}")
339
+ except Exception as e:
340
+ print(f"Error initializing ONNX session with providers {providers}: {e}")
341
+ print("Falling back to CPUExecutionProvider only.")
342
+ onnx_session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
343
+
344
+ labels_data, _, tag_to_category_map = load_tag_mapping(tag_mapping_path)
345
+ print("Model and labels loaded.")
346
+
347
+ def predict(image_input, gen_threshold, char_threshold, output_mode):
348
+ """Gradioインターフェース用の予測関数"""
349
+ initialize_model() # モデルがロードされていなければロード
350
+
351
+ if image_input is None:
352
+ return "Please upload an image.", None
353
+
354
+ print(f"Processing image with thresholds: gen={gen_threshold}, char={char_threshold}")
355
+
356
+ # PIL Imageオブジェクトであることを確認
357
+ if not isinstance(image_input, Image.Image):
358
+ try:
359
+ # URLの場合
360
+ if isinstance(image_input, str) and image_input.startswith("http"):
361
+ response = requests.get(image_input)
362
+ response.raise_for_status()
363
+ image = Image.open(io.BytesIO(response.content))
364
+ # ファイルパスの場合 (Gradioでは通常発生しないが念のため)
365
+ elif isinstance(image_input, str) and os.path.exists(image_input):
366
+ image = Image.open(image_input)
367
+ # Numpy配列の場合 (Gradio Imageコンポーネントからの入力)
368
+ elif isinstance(image_input, np.ndarray):
369
+ image = Image.fromarray(image_input)
370
+ else:
371
+ raise ValueError("Unsupported image input type")
372
+ except Exception as e:
373
+ print(f"Error loading image: {e}")
374
+ return f"Error loading image: {e}", None
375
+ else:
376
+ image = image_input
377
+
378
+
379
+ # 前処理
380
+ original_pil_image, input_data = preprocess_image(image)
381
+
382
+ # データ型をモデルの期待に合わせる (通常はfloat32)
383
+ input_name = onnx_session.get_inputs()[0].name
384
+ expected_type = onnx_session.get_inputs()[0].type
385
+ if expected_type == 'tensor(float16)':
386
+ input_data = input_data.astype(np.float16)
387
+ else:
388
+ input_data = input_data.astype(np.float32) # Default to float32
389
+
390
+ # 推論
391
+ start_time = time.time()
392
+ outputs = onnx_session.run(None, {input_name: input_data})[0]
393
+ inference_time = time.time() - start_time
394
+ print(f"Inference completed in {inference_time:.3f} seconds")
395
+
396
+ # シグモイド関数で確率に変換
397
+ probs = 1 / (1 + np.exp(-outputs[0])) # Apply sigmoid to the first batch item
398
+
399
+ # タグ取得
400
+ predictions = get_tags(probs, labels_data, gen_threshold, char_threshold)
401
+
402
+ # タグを整形
403
+ output_tags = []
404
+ # RatingとQualityを最初に追加
405
+ if predictions["rating"]:
406
+ output_tags.append(predictions["rating"][0][0].replace("_", " "))
407
+ if predictions["quality"]:
408
+ output_tags.append(predictions["quality"][0][0].replace("_", " "))
409
+
410
+ # 残りのカテゴリをアルファベット順に追加(オプション)
411
+ for category in ["artist", "character", "copyright", "general", "meta"]:
412
+ tags = [tag.replace("_", " ") for tag, prob in predictions[category]
413
+ if not (category == "meta" and any(p in tag.lower() for p in ['id', 'commentary','mismatch']))] # メタタグフィルタリング
414
+ output_tags.extend(tags)
415
+
416
+ output_text = ", ".join(output_tags)
417
+
418
+ if output_mode == "Tags Only":
419
+ return output_text, None
420
+ else: # Visualization
421
+ viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold)
422
+ return output_text, viz_image
423
+
424
+ # --- Gradio Interface Definition ---
425
+ import time
426
+
427
+ # CSS for styling
428
+ css = """
429
+ .gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
430
+ footer { display: none !important; }
431
+ .gr-prose { max-width: 100% !important; }
432
+ """
433
+ # Custom JS for image pasting and URL handling
434
+ js = """
435
+ async function paste_image(blob, gen_thresh, char_thresh, out_mode) {
436
+ const data = await fetch(blob)
437
+ const image_data = await data.blob()
438
+ const file = new File([image_data], "pasted_image.png",{ type: image_data.type })
439
+ const dt = new DataTransfer()
440
+ dt.items.add(file)
441
+ const element = document.querySelector('#input-image input[type="file"]')
442
+ element.files = dt.files
443
+ // Trigger the change event manually
444
+ const event = new Event('change', { bubbles: true })
445
+ element.dispatchEvent(event)
446
+ // Wait a bit for Gradio to process the change, then trigger predict if needed
447
+ // await new Promise(resolve => setTimeout(resolve, 100)); // Optional delay
448
+ // You might need to manually trigger the prediction or rely on Gradio's auto-triggering
449
+ return [file, gen_thresh, char_thresh, out_mode]; // Return input for Gradio function
450
+ }
451
+
452
+ async function paste_update(evt){
453
+ if (!evt.clipboardData || !evt.clipboardData.items) return;
454
+ var url = evt.clipboardData.getData('text');
455
+ if (url) {
456
+ // Basic check for image URL (you might want a more robust check)
457
+ if (/\.(jpg|jpeg|png|webp|bmp)$/i.test(url)) {
458
+ // Create a button or link to load the URL
459
+ const url_container = document.getElementById('url-input-container');
460
+ url_container.innerHTML = `<p>Detected URL: <button id="load-url-btn" class="gr-button gr-button-sm gr-button-secondary">${url}</button></p>`;
461
+
462
+ document.getElementById('load-url-btn').onclick = async () => {
463
+ // Simulate file upload from URL - Gradio's Image component handles URLs directly
464
+ const element = document.querySelector('#input-image input[type="file"]');
465
+ // Can't directly set URL to file input, so we pass it to Gradio fn
466
+ // Or maybe update the image display src directly if possible?
467
+
468
+ // Let Gradio handle the URL - user needs to click predict
469
+ // We can pre-fill the image component if Gradio supports it via JS,
470
+ // but it's simpler to just let the user click predict after pasting URL.
471
+ alert("URL detected. Please ensure the image input is cleared and then press 'Predict' or re-upload the image.");
472
+ // Clear current image preview if possible?
473
+
474
+ // A workaround: display the URL and let the user manually trigger prediction
475
+ // Or, try to use Gradio's JS API if available to update the Image component value
476
+ // For now, just inform the user.
477
+ };
478
+ return; // Don't process as image paste if URL is found
479
+ }
480
+ }
481
+
482
+ var items = evt.clipboardData.items;
483
+ for (var i = 0; i < items.length; i++) {
484
+ if (items[i].type.indexOf("image") === 0) {
485
+ var blob = items[i].getAsFile();
486
+ var reader = new FileReader();
487
+ reader.onload = function(event){
488
+ // Update the Gradio Image component source directly
489
+ const imgElement = document.querySelector('#input-image img'); // Find the img tag inside the component
490
+ if (imgElement) {
491
+ imgElement.src = event.target.result;
492
+ // We still need to pass the blob to the Gradio function
493
+ // Use Gradio's JS API or hidden components if possible
494
+ // For now, let's use a simple alert and rely on manual trigger
495
+ alert("Image pasted. The preview should update. Please press 'Predict'.");
496
+ // Trigger paste_image function - requires Gradio JS interaction
497
+ // This part is tricky without official Gradio JS API for updates
498
+ }
499
+ };
500
+ reader.readAsDataURL(blob);
501
+ // Prevent default paste handling
502
+ evt.preventDefault();
503
+ break;
504
+ }
505
+ }
506
+ }
507
+
508
+ document.addEventListener('paste', paste_update);
509
+ """
510
+
511
+ with gr.Blocks(css=css, js=js) as demo:
512
+ gr.Markdown("# WD EVA02 LoRA ONNX Tagger")
513
+ gr.Markdown("Upload an image or paste an image URL to predict tags using the fine-tuned WD EVA02 Tagger model (ONNX format).")
514
+ gr.Markdown(f"Model Repository: [{REPO_ID}](https://huggingface.co/{REPO_ID})")
515
+
516
+ with gr.Row():
517
+ with gr.Column(scale=1):
518
+ # Use elem_id for JS targeting
519
+ image_input = gr.Image(type="pil", label="Input Image", elem_id="input-image")
520
+ # Container for URL paste message
521
+ gr.HTML("<div id='url-input-container'></div>")
522
+
523
+ gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General Tag Threshold")
524
+ char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
525
+ output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
526
+ predict_button = gr.Button("Predict", variant="primary")
527
+
528
+ with gr.Column(scale=1):
529
+ output_tags = gr.Textbox(label="Predicted Tags", lines=10)
530
+ output_visualization = gr.Image(type="pil", label="Prediction Visualization")
531
+
532
+ # Examples
533
+ gr.Examples(
534
+ examples=[
535
+ ["https://pbs.twimg.com/media/GpiBUQZawAAetgr.jpg", 0.55, 0.5, "Tags + Visualization"],
536
+ ["https://pbs.twimg.com/media/GooBBQHWcAAJj2q.jpg", 0.5, 0.5, "Tags Only"],
537
+ ["https://m.media-amazon.com/images/I/61FwAqFu4PL.jpg", 0.55, 0.5, "Tags + Visualization"],
538
+ ["https://cdn.donmai.us/sample/5d/ad/__kanae_and_kanae_nijisanji_drawn_by_cococall__sample-5dadca17680ef18c18daaf75507c4b12.jpg", 0.45, 0.45, "Tags + Visualization"]
539
+ ],
540
+ inputs=[image_input, gen_threshold, char_threshold, output_mode],
541
+ outputs=[output_tags, output_visualization],
542
+ fn=predict,
543
+ cache_examples=False # Slows down startup if True and large examples
544
+ )
545
+
546
+ predict_button.click(
547
+ fn=predict,
548
+ inputs=[image_input, gen_threshold, char_threshold, output_mode],
549
+ outputs=[output_tags, output_visualization]
550
+ )
551
+
552
+ # Add listener for image input changes (e.g., from pasting)
553
+ # This might trigger prediction automatically or require the button click
554
+ # image_input.change(
555
+ # fn=predict,
556
+ # inputs=[image_input, gen_threshold, char_threshold, output_mode],
557
+ # outputs=[output_tags, output_visualization]
558
+ # )
559
+
560
+
561
+ if __name__ == "__main__":
562
+ # 環境変数HF_TOKENがない場合に警告(プライベートリポジトリ用)
563
+ if not os.environ.get("HF_TOKEN"):
564
+ print("Warning: HF_TOKEN environment variable not set. Downloads from private repositories may fail.")
565
+ # Initialize model on startup to avoid delay on first prediction
566
+ initialize_model()
567
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ onnxruntime-gpu>=1.16.0 # または onnxruntime>=1.16.0 (CPUのみの場合)
2
+ numpy
3
+ Pillow
4
+ matplotlib
5
+ requests
6
+ gradio
7
+ huggingface_hub