Alptekinege committed on
Commit
b55b15c
·
verified ·
1 Parent(s): 5538a88

Update app.py

Files changed (1)
  1. app.py +123 -313
app.py CHANGED
@@ -1,319 +1,129 @@
-import argparse
-import os
-import gradio as gr
-import huggingface_hub
 import numpy as np
-import onnxruntime as rt
-import pandas as pd
 from PIL import Image
-import json # Added for loading metadata.json from the inference file

-TITLE = "WaifuDiffusion Tagger"
-DESCRIPTION = """
-Demo for the WaifuDiffusion tagger models
-
-Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
-"""
-
-HF_TOKEN = os.environ.get("HF_TOKEN", "")
-
-# Dataset v3 series of models:
-SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
-CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
-VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
-VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
-EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
-
-# Dataset v2 series of models:
-MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
-SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
-CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
-CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
-VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
-
-# IdolSankaku series of models:
-EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
-SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
-
-# Files to download from the repos
-MODEL_FILENAME = "model.onnx"
-LABEL_FILENAME = "selected_tags.csv"
-
-kaomojis = [
-    "0_0",
-    "(o)_(o)",
-    "+_+",
-    "+_-",
-    "._.",
-    "<o>_<o>",
-    "<|>_<|>",
-    "=_=",
-    ">_<",
-    "3_3",
-    "6_9",
-    ">_o",
-    "@_@",
-    "^_^",
-    "o_o",
-    "u_u",
-    "x_x",
-    "|_|",
-    "||_||",
-]
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--score-slider-step", type=float, default=0.05)
-    parser.add_argument("--score-general-threshold", type=float, default=0.35)
-    parser.add_argument("--score-character-threshold", type=float, default=0.85)
-    return parser.parse_args()
-
-def load_labels(dataframe) -> list[str]:
-    name_series = dataframe["name"]
-    name_series = name_series.map(
-        lambda x: x.replace("_", " ") if x not in kaomojis else x
-    )
-    tag_names = name_series.tolist()
-
-    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
-    general_indexes = list(np.where(dataframe["category"] == 0)[0])
-    character_indexes = list(np.where(dataframe["category"] == 4)[0])
-    return tag_names, rating_indexes, general_indexes, character_indexes
-
-def mcut_threshold(probs):
-    sorted_probs = probs[probs.argsort()[::-1]]
-    difs = sorted_probs[:-1] - sorted_probs[1:]
-    t = difs.argmax()
-    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
-    return thresh
-
-class Predictor:
-    def __init__(self):
-        self.model_target_size = None
-        self.last_loaded_repo = None
-        # Added flag to distinguish between custom and Hugging Face models
-        self.is_custom_model = False
-
-    def download_model(self, model_repo):
-        csv_path = huggingface_hub.hf_hub_download(
-            model_repo,
-            LABEL_FILENAME,
-            use_auth_token=HF_TOKEN,
-        )
-        model_path = huggingface_hub.hf_hub_download(
-            model_repo,
-            MODEL_FILENAME,
-            use_auth_token=HF_TOKEN,
-        )
-        return csv_path, model_path
-
-    def load_model(self, model_repo, onnx_path=None, metadata_path=None):
-        # Modified to accept onnx_path and metadata_path for custom model support
-        if model_repo == "Custom Model" and onnx_path and metadata_path:
-            # Check if the custom model files have already been loaded
-            if self.last_loaded_repo == (onnx_path, metadata_path):
-                return
-            self.is_custom_model = True
-            # Load the ONNX model from the provided path (from inference file)
-            self.model = rt.InferenceSession(onnx_path)
-            # Load metadata from metadata.json (from inference file)
-            with open(metadata_path, "r", encoding="utf-8") as f:
-                metadata = json.load(f)
-            self.idx_to_tag = metadata["idx_to_tag"]
-            # Create tag_names list from idx_to_tag dictionary
-            self.tag_names = [self.idx_to_tag[str(i)] for i in range(len(self.idx_to_tag))]
-            # Set target size to 512 for custom model, as per inference file
-            self.model_target_size = 512
-            self.last_loaded_repo = (onnx_path, metadata_path)
         else:
-            # Existing logic for Hugging Face models
-            self.is_custom_model = False
-            if self.last_loaded_repo == model_repo:
-                return
-            csv_path, model_path = self.download_model(model_repo)
-            tags_df = pd.read_csv(csv_path)
-            sep_tags = load_labels(tags_df)
-            self.tag_names = sep_tags[0]
-            self.rating_indexes = sep_tags[1]
-            self.general_indexes = sep_tags[2]
-            self.character_indexes = sep_tags[3]
-            self.model = rt.InferenceSession(model_path)
-            _, height, width, _ = self.model.get_inputs()[0].shape
-            self.model_target_size = height
-            self.last_loaded_repo = model_repo
-
-    def prepare_image(self, image):
-        if self.is_custom_model:
-            # Added preprocessing logic from inference file's preprocess_image function
-            # Adapted to take a PIL image instead of a file path
-            target_size = self.model_target_size
-            img = image.convert("RGB")
-            w, h = img.size
-            aspect = w / h
-            if aspect > 1:
-                new_w = target_size
-                new_h = int(new_w / aspect)
-            else:
-                new_h = target_size
-                new_w = int(new_h * aspect)
-            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
-            background = Image.new("RGB", (target_size, target_size), (0, 0, 0))
-            paste_x = (target_size - new_w) // 2
-            paste_y = (target_size - new_h) // 2
-            background.paste(img, (paste_x, paste_y))
-            arr = np.array(background).astype("float32") / 255.0
-            arr = np.transpose(arr, (2, 0, 1)) # HWC to CHW as per inference file
-            arr = np.expand_dims(arr, axis=0)
-            return arr
-        else:
-            # Existing preprocessing logic for Hugging Face models
-            target_size = self.model_target_size
-            canvas = Image.new("RGBA", image.size, (255, 255, 255))
-            canvas.alpha_composite(image)
-            image = canvas.convert("RGB")
-            image_shape = image.size
-            max_dim = max(image_shape)
-            pad_left = (max_dim - image_shape[0]) // 2
-            pad_top = (max_dim - image_shape[1]) // 2
-            padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
-            padded_image.paste(image, (pad_left, pad_top))
-            if max_dim != target_size:
-                padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
-            image_array = np.asarray(padded_image, dtype=np.float32)
-            image_array = image_array[:, :, ::-1] # RGB to BGR
-            return np.expand_dims(image_array, axis=0)
-
-    def predict(
-        self,
-        image,
-        model_repo,
-        general_thresh,
-        general_mcut_enabled,
-        character_thresh,
-        character_mcut_enabled,
-        onnx_path=None,
-        metadata_path=None,
-    ):
-        # Modified to accept onnx_path and metadata_path for custom model
-        self.load_model(model_repo, onnx_path, metadata_path)
-        # Added check to ensure custom model files are provided
-        if self.is_custom_model and (onnx_path is None or metadata_path is None):
-            return "Please upload ONNX model and metadata JSON files.", {}, {}, {}
-        image_tensor = self.prepare_image(image)
-        input_name = self.model.get_inputs()[0].name
-        # Changed to use None for output names to get all outputs, supporting custom model
-        outputs = self.model.run(None, {input_name: image_tensor})
-        if self.is_custom_model:
-            # Added inference logic from inference file for custom model
-            # Handle case where model might output initial and refined predictions
-            refined_preds = outputs[1] if len(outputs) == 2 else outputs[0]
-            ref_logit = refined_preds[0] # Shape (N_tags,)
-            # Apply sigmoid to convert logits to probabilities (from inference file)
-            ref_prob = 1.0 / (1.0 + np.exp(-ref_logit))
-            pred_indices = np.where(ref_prob >= general_thresh)[0]
-            predicted_tags = [self.tag_names[idx] for idx in pred_indices]
-            sorted_general_strings = ", ".join(predicted_tags)
-            # Custom model doesn't use category separation, so return empty for rating and character
-            rating = {}
-            character_res = {}
-            general_res = {self.tag_names[idx]: ref_prob[idx] for idx in pred_indices}
-        else:
-            # Existing inference logic for Hugging Face models
-            preds = outputs[0] # Assumes single output tensor
-            labels = list(zip(self.tag_names, preds[0].astype(float)))
-            ratings_names = [labels[i] for i in self.rating_indexes]
-            rating = dict(ratings_names)
-            general_names = [labels[i] for i in self.general_indexes]
-            if general_mcut_enabled:
-                general_probs = np.array([x[1] for x in general_names])
-                general_thresh = mcut_threshold(general_probs)
-            general_res = [x for x in general_names if x[1] > general_thresh]
-            general_res = dict(general_res)
-            character_names = [labels[i] for i in self.character_indexes]
-            if character_mcut_enabled:
-                character_probs = np.array([x[1] for x in character_names])
-                character_thresh = mcut_threshold(character_probs)
-                character_thresh = max(0.15, character_thresh)
-            character_res = [x for x in character_names if x[1] > character_thresh]
-            character_res = dict(character_res)
-            sorted_general_strings = sorted(
-                general_res.items(),
-                key=lambda x: x[1],
-                reverse=True,
-            )
-            sorted_general_strings = [x[0] for x in sorted_general_strings]
-            sorted_general_strings = ", ".join(sorted_general_strings).replace("(", r"\(").replace(")", r"\)")
-        return sorted_general_strings, rating, character_res, general_res
-
-def main():
-    args = parse_args()
-    predictor = Predictor()
-    # Added "Custom Model" to the dropdown list to support local ONNX model
-    dropdown_list = [
-        SWINV2_MODEL_DSV3_REPO,
-        CONV_MODEL_DSV3_REPO,
-        VIT_MODEL_DSV3_REPO,
-        VIT_LARGE_MODEL_DSV3_REPO,
-        EVA02_LARGE_MODEL_DSV3_REPO,
-        MOAT_MODEL_DSV2_REPO,
-        SWIN_MODEL_DSV2_REPO,
-        CONV_MODEL_DSV2_REPO,
-        CONV2_MODEL_DSV2_REPO,
-        VIT_MODEL_DSV2_REPO,
-        SWINV2_MODEL_IS_DSV1_REPO,
-        EVA02_LARGE_MODEL_IS_DSV1_REPO,
-        "Custom Model",
-    ]
-    with gr.Blocks(title=TITLE) as demo:
-        with gr.Column():
-            gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
-            gr.Markdown(value=DESCRIPTION)
-            with gr.Row():
-                with gr.Column(variant="panel"):
-                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
-                    model_repo = gr.Dropdown(dropdown_list, value=SWINV2_MODEL_DSV3_REPO, label="Model")
-                    # Added file inputs for ONNX model and metadata, hidden by default
-                    with gr.Row(visible=False) as custom_model_inputs:
-                        onnx_file = gr.File(label="ONNX Model File", file_types=[".onnx"])
-                        metadata_file = gr.File(label="Metadata JSON File", file_types=[".json"])
-                    with gr.Row():
-                        general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold", scale=3)
-                        general_mcut_enabled = gr.Checkbox(value=False, label="Use MCut threshold", scale=1)
-                    with gr.Row():
-                        character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold", scale=3)
-                        character_mcut_enabled = gr.Checkbox(value=False, label="Use MCut threshold", scale=1)
-                    with gr.Row():
-                        # Updated clear button to include new file inputs
-                        clear = gr.ClearButton(
-                            components=[image, model_repo, general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled, onnx_file, metadata_file],
-                            variant="secondary",
-                            size="lg"
-                        )
-                        submit = gr.Button(value="Submit", variant="primary", size="lg")
-                with gr.Column(variant="panel"):
-                    sorted_general_strings = gr.Textbox(label="Output (string)")
-                    rating = gr.Label(label="Rating")
-                    character_res = gr.Label(label="Output (characters)")
-                    general_res = gr.Label(label="Output (tags)")
-                    clear.add([sorted_general_strings, rating, character_res, general_res])
-            # Added event listener to show/hide custom model inputs based on model selection
-            model_repo.change(
-                lambda x: gr.update(visible=(x == "Custom Model")),
-                inputs=model_repo,
-                outputs=custom_model_inputs,
-            )
-            # Updated submit event to pass onnx_file and metadata_file to predict
-            submit.click(
-                predictor.predict,
-                inputs=[image, model_repo, general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled, onnx_file, metadata_file],
-                outputs=[sorted_general_strings, rating, character_res, general_res],
-            )
-            gr.Examples(
-                [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
-                inputs=[image, model_repo, general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled],
-            )
-    demo.queue(max_size=10)
-    demo.launch()

 if __name__ == "__main__":
-    main()
+import onnxruntime as ort
 import numpy as np
+import json
 from PIL import Image

+def preprocess_image(img_path, target_size=512, keep_aspect=True):
+    """
+    Load an image from img_path, convert to RGB,
+    and resize/pad to (target_size, target_size).
+    Scales pixel values to [0,1] and returns a (1,3,target_size,target_size) float32 array.
+    """
+    img = Image.open(img_path).convert("RGB")
+
+    if keep_aspect:
+        # Preserve aspect ratio, pad black
+        w, h = img.size
+        aspect = w / h
+        if aspect > 1:
+            new_w = target_size
+            new_h = int(new_w / aspect)
         else:
+            new_h = target_size
+            new_w = int(new_h * aspect)
+
+        # Resize with Lanczos
+        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
+        # Pad to a square
+        background = Image.new("RGB", (target_size, target_size), (0, 0, 0))
+        paste_x = (target_size - new_w) // 2
+        paste_y = (target_size - new_h) // 2
+        background.paste(img, (paste_x, paste_y))
+        img = background
+    else:
+        # simple direct resize to 512x512
+        img = img.resize((target_size, target_size), Image.Resampling.LANCZOS)
+
+    # Convert to numpy array
+    arr = np.array(img).astype("float32") / 255.0 # scale to [0,1]
+    # Transpose from HWC -> CHW
+    arr = np.transpose(arr, (2, 0, 1))
+    # Add batch dimension: (1,3,512,512)
+    arr = np.expand_dims(arr, axis=0)
+    return arr
+
+def onnx_inference(img_paths,
+                   onnx_path="camie_refined_no_flash.onnx",
+                   threshold=0.325,
+                   metadata_file="metadata.json"):
+    """
+    Loads the ONNX model, runs inference on a list of image paths,
+    and applies an optional threshold to produce final predictions.
+
+    Args:
+        img_paths: List of paths to images.
+        onnx_path: Path to the exported ONNX model file.
+        threshold: Probability threshold for deciding if a tag is predicted.
+        metadata_file: Path to metadata.json that contains idx_to_tag etc.
+
+    Returns:
+        A list of dicts, each containing:
+            {
+                "initial_logits": np.ndarray of shape (N_tags,),
+                "refined_logits": np.ndarray of shape (N_tags,),
+                "predicted_tags": list of tag indices that exceeded threshold,
+                ...
+            }
+        one dict per input image.
+    """
+    # 1) Initialize ONNX runtime session
+    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
+    # Optional: for GPU usage, see if "CUDAExecutionProvider" is available
+    # session = ort.InferenceSession(onnx_path, providers=["CUDAExecutionProvider"])
+
+    # 2) Pre-load metadata
+    with open(metadata_file, "r", encoding="utf-8") as f:
+        metadata = json.load(f)
+    idx_to_tag = metadata["idx_to_tag"] # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
+
+    # 3) Preprocess each image into a batch
+    batch_tensors = []
+    for img_path in img_paths:
+        x = preprocess_image(img_path, target_size=512, keep_aspect=True)
+        batch_tensors.append(x)
+    # Concatenate along the batch dimension => shape (batch_size, 3, 512, 512)
+    batch_input = np.concatenate(batch_tensors, axis=0)
+
+    # 4) Run inference
+    input_name = session.get_inputs()[0].name # typically "image"
+    outputs = session.run(None, {input_name: batch_input})
+    # Typically we get [initial_tags, refined_tags] as output
+    initial_preds, refined_preds = outputs # shapes => (batch_size, 70527)
+
+    # 5) For each image in batch, convert logits to predictions if desired
+    batch_results = []
+    for i in range(initial_preds.shape[0]):
+        # Extract one sample's logits
+        init_logit = initial_preds[i, :] # shape (N_tags,)
+        ref_logit = refined_preds[i, :] # shape (N_tags,)
+
+        # Convert to probabilities with sigmoid
+        ref_prob = 1.0 / (1.0 + np.exp(-ref_logit))
+
+        # Threshold
+        pred_indices = np.where(ref_prob >= threshold)[0]
+
+        # Build result for this image
+        result_dict = {
+            "initial_logits": init_logit,
+            "refined_logits": ref_logit,
+            "predicted_indices": pred_indices,
+            "predicted_tags": [idx_to_tag[str(idx)] for idx in pred_indices] # map index->tag name
+        }
+        batch_results.append(result_dict)
+
+    return batch_results

 if __name__ == "__main__":
+    # Example usage
+    images = ["image1.jpg", "image2.jpg", "image3.jpg"]
+    results = onnx_inference(images,
+                             onnx_path="camie_refined_no_flash.onnx",
+                             threshold=0.325,
+                             metadata_file="metadata.json")
+
+    for i, res in enumerate(results):
+        print(f"Image: {images[i]}")
+        print(f"  # of predicted tags above threshold: {len(res['predicted_indices'])}")
+        print(f"  Some predicted tags: {res['predicted_tags'][:10]} (Show up to 10)")
+        print()
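
For reference, the MCut adaptive threshold that this commit removes along with the Gradio app only needs per-tag probabilities, so it can still be layered on top of onnx_inference. Below is a minimal sketch assuming the refined_logits arrays and idx_to_tag mapping shown above; the pick_tags_mcut helper is hypothetical and is not part of either version of app.py:

import numpy as np

def mcut_threshold(probs):
    # Maximum Cut Thresholding: sort probabilities in descending order and
    # cut at the midpoint of the largest gap between consecutive values.
    sorted_probs = probs[probs.argsort()[::-1]]
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    return (sorted_probs[t] + sorted_probs[t + 1]) / 2

def pick_tags_mcut(refined_logits, idx_to_tag):
    # Hypothetical helper: sigmoid the refined logits, derive a per-image
    # threshold with MCut, and return the surviving tag names.
    probs = 1.0 / (1.0 + np.exp(-refined_logits))
    thresh = mcut_threshold(probs)
    keep = np.where(probs >= thresh)[0]
    return [idx_to_tag[str(i)] for i in keep]

Called as pick_tags_mcut(res["refined_logits"], idx_to_tag) on each result from onnx_inference, this trades the fixed threshold=0.325 for a per-image cut; as in the removed app's character branch, a floor such as max(0.15, thresh) can guard against degenerate cuts on flat probability distributions.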