Upload 3 files
Browse files
- README.md +31 -35
- app.py +5 -2
- requirements.txt +1 -1
README.md
CHANGED
@@ -1,35 +1,31 @@
|
|
1 |
-
---
|
2 |
-
title: WD EVA02 LoRA ONNX Tagger
|
3 |
-
emoji: 🖼️
|
4 |
-
colorFrom: blue
|
5 |
-
colorTo: green
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 4.43.0 #
|
8 |
-
app_file: app.py
|
9 |
-
license: apache-2.0 #
|
10 |
-
# Hardware
|
11 |
-
#
|
12 |
-
# hardware: cpu-upgrade
|
13 |
-
#
|
14 |
-
|
15 |
-
|
16 |
-
#
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
**Note:**
|
33 |
-
- This Space uses a model from a **private** repository (`celstk/wd-eva02-lora-onnx`). You might need to duplicate this space and add your Hugging Face token (`HF_TOKEN`) to the Space secrets to allow downloading the model files.
|
34 |
-
- Image pasting behavior might vary across browsers.
|
35 |
-
- If you require GPU acceleration, uncomment the `hardware: cuda-t4-small` line above and ensure the environment has the necessary CUDA libraries compatible with `onnxruntime-gpu`. The current setup defaults to CPU due to potential CUDA library mismatches in the standard Spaces environment.
|
|
|
1 |
+
---
|
2 |
+
title: WD EVA02 LoRA ONNX Tagger
|
3 |
+
emoji: 🖼️
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.43.0 # requirements.txt と合わせるか確認
|
8 |
+
app_file: app.py
|
9 |
+
license: apache-2.0 # または適切なライセンス
|
10 |
+
# Pinned Hardware: T4 small (GPU) or CPU upgrade (CPU)
|
11 |
+
# pinned: false # 必要に応じてTrueに
|
12 |
+
# hardware: cpu-upgrade # or cuda-t4-small
|
13 |
+
# hf_token: YOUR_HF_TOKEN # Use secrets instead!
|
14 |
+
---
|
15 |
+
|
16 |
+
# WD EVA02 LoRA ONNX Tagger
|
17 |
+
|
18 |
+
This Space demonstrates image tagging using a fine-tuned WD EVA02 model (converted to ONNX format).
|
19 |
+
|
20 |
+
Model Repository: [celstk/wd-eva02-lora-onnx](https://huggingface.co/celstk/wd-eva02-lora-onnx)
|
21 |
+
|
22 |
+
**How to Use:**
|
23 |
+
1. Upload an image using the upload button.
|
24 |
+
2. Alternatively, paste an image URL into the browser (experimental paste handling).
|
25 |
+
3. Adjust the tag thresholds if needed.
|
26 |
+
4. Choose the output mode (Tags only or include visualization).
|
27 |
+
5. Click the "Predict" button.
|
28 |
+
|
29 |
+
**Note:**
|
30 |
+
- This Space uses a model from a **private** repository (`celstk/wd-eva02-lora-onnx`). You might need to duplicate this space and add your Hugging Face token (`HF_TOKEN`) to the Space secrets to allow downloading the model files.
|
31 |
+
- Image pasting behavior might vary across browsers.
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import gradio as gr
|
2 |
-
import onnxruntime as ort
|
3 |
import numpy as np
|
4 |
from PIL import Image, ImageDraw, ImageFont
|
5 |
import json
|
@@ -12,6 +12,7 @@ from huggingface_hub import hf_hub_download
|
|
12 |
from dataclasses import dataclass
|
13 |
from typing import List, Dict, Optional, Tuple
|
14 |
import time
|
|
|
15 |
|
16 |
import torch
|
17 |
import timm
|
@@ -347,9 +348,11 @@ def initialize_labels_and_paths():
|
|
347 |
print(f"Tag mapping file not found after download attempt: {tag_mapping_path_global}")
|
348 |
raise gr.Error("Tag mapping file could not be downloaded or found.")
|
349 |
|
350 |
-
|
|
|
351 |
def predict(image_input, gen_threshold, char_threshold, output_mode):
|
352 |
print("--- predict function started (GPU worker) ---")
|
|
|
353 |
initialize_labels_and_paths()
|
354 |
print("Loading PyTorch model...")
|
355 |
global safetensors_path_global, labels_data
|
|
|
1 |
import gradio as gr
|
2 |
+
# import onnxruntime as ort # Removed
|
3 |
import numpy as np
|
4 |
from PIL import Image, ImageDraw, ImageFont
|
5 |
import json
|
|
|
12 |
from dataclasses import dataclass
|
13 |
from typing import List, Dict, Optional, Tuple
|
14 |
import time
|
15 |
+
# import spaces # Keep for @spaces.GPU
|
16 |
|
17 |
import torch
|
18 |
import timm
|
|
|
348 |
print(f"Tag mapping file not found after download attempt: {tag_mapping_path_global}")
|
349 |
raise gr.Error("Tag mapping file could not be downloaded or found.")
|
350 |
|
351 |
+
# --- Prediction Function (PyTorch based) ---
|
352 |
+
# @spaces.GPU() # Removed decorator
|
353 |
def predict(image_input, gen_threshold, char_threshold, output_mode):
|
354 |
print("--- predict function started (GPU worker) ---")
|
355 |
+
"""Gradioインターフェース用の予測関数 (PyTorch GPUワーカー内)"""
|
356 |
initialize_labels_and_paths()
|
357 |
print("Loading PyTorch model...")
|
358 |
global safetensors_path_global, labels_data
|
requirements.txt
CHANGED
@@ -6,7 +6,7 @@ torchaudio
|
|
6 |
safetensors
|
7 |
transformers
|
8 |
timm # Needed for EVA02 base model
|
9 |
-
numpy #
|
10 |
Pillow
|
11 |
matplotlib
|
12 |
requests
|
|
|
6 |
safetensors
|
7 |
transformers
|
8 |
timm # Needed for EVA02 base model
|
9 |
+
numpy # Let pip resolve NumPy version
|
10 |
Pillow
|
11 |
matplotlib
|
12 |
requests
|