Spaces:
Running
Running
hysts
commited on
Commit
·
72d9403
1
Parent(s):
7be8f37
Return a list of predicted labels
Browse files
app.py
CHANGED
|
@@ -7,6 +7,7 @@ import functools
|
|
| 7 |
import os
|
| 8 |
import pathlib
|
| 9 |
import tarfile
|
|
|
|
| 10 |
|
| 11 |
import deepdanbooru as dd
|
| 12 |
import gradio as gr
|
|
@@ -64,7 +65,8 @@ def load_labels() -> list[str]:
|
|
| 64 |
|
| 65 |
|
| 66 |
def predict(image: PIL.Image.Image, score_threshold: float,
|
| 67 |
-
model: tf.keras.Model,
|
|
|
|
| 68 |
_, height, width, _ = model.input_shape
|
| 69 |
image = np.asarray(image)
|
| 70 |
image = tf.image.resize(image,
|
|
@@ -81,7 +83,14 @@ def predict(image: PIL.Image.Image, score_threshold: float,
|
|
| 81 |
if prob < score_threshold:
|
| 82 |
continue
|
| 83 |
res[label] = prob
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
def main():
|
|
@@ -106,7 +115,10 @@ def main():
|
|
| 106 |
value=args.score_threshold,
|
| 107 |
label='Score Threshold'),
|
| 108 |
],
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
| 110 |
examples=examples,
|
| 111 |
title=TITLE,
|
| 112 |
description=DESCRIPTION,
|
|
|
|
| 7 |
import os
|
| 8 |
import pathlib
|
| 9 |
import tarfile
|
| 10 |
+
import tempfile
|
| 11 |
|
| 12 |
import deepdanbooru as dd
|
| 13 |
import gradio as gr
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
def predict(image: PIL.Image.Image, score_threshold: float,
|
| 68 |
+
model: tf.keras.Model,
|
| 69 |
+
labels: list[str]) -> tuple[dict[str, float], str]:
|
| 70 |
_, height, width, _ = model.input_shape
|
| 71 |
image = np.asarray(image)
|
| 72 |
image = tf.image.resize(image,
|
|
|
|
| 83 |
if prob < score_threshold:
|
| 84 |
continue
|
| 85 |
res[label] = prob
|
| 86 |
+
|
| 87 |
+
sorted_preds = sorted(res.items(), key=lambda x: -x[1])
|
| 88 |
+
out_path = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
|
| 89 |
+
with open(out_path.name, 'w') as f:
|
| 90 |
+
for key, _ in sorted_preds:
|
| 91 |
+
f.write(f'{key}\n')
|
| 92 |
+
|
| 93 |
+
return res, out_path.name
|
| 94 |
|
| 95 |
|
| 96 |
def main():
|
|
|
|
| 115 |
value=args.score_threshold,
|
| 116 |
label='Score Threshold'),
|
| 117 |
],
|
| 118 |
+
[
|
| 119 |
+
gr.Label(label='Output'),
|
| 120 |
+
gr.File(label='Tag List'),
|
| 121 |
+
],
|
| 122 |
examples=examples,
|
| 123 |
title=TITLE,
|
| 124 |
description=DESCRIPTION,
|