PICOF commited on
Commit
e92360d
·
1 Parent(s): 39efbfa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import html
9
+ import pathlib
10
+ import tarfile
11
+
12
+ import deepdanbooru as dd
13
+ import gradio as gr
14
+ import huggingface_hub
15
+ import numpy as np
16
+ import PIL.Image
17
+ import tensorflow as tf
18
+ import piexif
19
+ import piexif.helper
20
+
21
+ TITLE = 'Alchemy'
22
+
23
+ TOKEN = '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'
24
+ MODEL_REPO = 'PICOF/deepdanbooru'
25
+ MODEL_FILENAME = 'model-resnet_custom_v3.h5'
26
+ LABEL_FILENAME = 'tags.txt'
27
+
28
+
29
+ def parse_args() -> argparse.Namespace:
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument('--score-slider-step', type=float, default=0.05)
32
+ parser.add_argument('--score-threshold', type=float, default=0.5)
33
+ parser.add_argument('--theme', type=str, default='dark-grass')
34
+ # parser.add_argument('--live', action='store_true')
35
+ parser.add_argument('--share', action='store_true')
36
+ # parser.add_argument('--port', type=int)
37
+ # parser.add_argument('--disable-queue',
38
+ # dest='enable_queue',
39
+ # action='store_false')
40
+ parser.add_argument('--allow-flagging', type=str, default='never')
41
+ return parser.parse_args()
42
+
43
+ #def load_sample_image_paths() -> list[pathlib.Path]:
44
+ # image_dir = pathlib.Path('images')
45
+ # if not image_dir.exists():
46
+ # dataset_repo = 'hysts/sample-images-TADNE'
47
+ # path = huggingface_hub.hf_hub_download(dataset_repo,
48
+ # 'images.tar.gz',
49
+ # repo_type='dataset',
50
+ # use_auth_token=TOKEN)
51
+ # with tarfile.open(path) as f:
52
+ # f.extractall()
53
+ # return sorted(image_dir.glob('*'))
54
+
55
+
56
+ def load_model() -> tf.keras.Model:
57
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
58
+ MODEL_FILENAME,
59
+ use_auth_token=TOKEN)
60
+ model = tf.keras.models.load_model(path)
61
+ return model
62
+
63
+
64
+ def load_labels() -> list[str]:
65
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
66
+ LABEL_FILENAME,
67
+ use_auth_token=TOKEN)
68
+ with open(path) as f:
69
+ labels = [line.strip() for line in f.readlines()]
70
+ return labels
71
+
72
+ def plaintext_to_html(text):
73
+ text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
74
+ return text
75
+
76
+ def predict(image: PIL.Image.Image, score_threshold: float,
77
+ model: tf.keras.Model, labels: list[str]) -> dict[str, float]:
78
+ rawimage = image
79
+ _, height, width, _ = model.input_shape
80
+ image = np.asarray(image)
81
+ image = tf.image.resize(image,
82
+ size=(height, width),
83
+ method=tf.image.ResizeMethod.AREA,
84
+ preserve_aspect_ratio=True)
85
+ image = image.numpy()
86
+ image = dd.image.transform_and_pad_image(image, width, height)
87
+ image = image / 255.
88
+ probs = model.predict(image[None, ...])[0]
89
+ probs = probs.astype(float)
90
+ res = dict()
91
+ for prob, label in zip(probs.tolist(), labels):
92
+ if prob < score_threshold:
93
+ continue
94
+ res[label] = prob
95
+ b = dict(sorted(res.items(),key=lambda item:item[1], reverse=True))
96
+ a = ', '.join(list(b.keys())).replace('_',' ').replace('(','\(').replace(')','\)')
97
+
98
+ items = rawimage.info
99
+ geninfo = ''
100
+
101
+ if "exif" in rawimage.info:
102
+ exif = piexif.load(rawimage.info["exif"])
103
+ exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
104
+ try:
105
+ exif_comment = piexif.helper.UserComment.load(exif_comment)
106
+ except ValueError:
107
+ exif_comment = exif_comment.decode('utf8', errors="ignore")
108
+
109
+ items['exif comment'] = exif_comment
110
+ geninfo = exif_comment
111
+
112
+ for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
113
+ 'loop', 'background', 'timestamp', 'duration']:
114
+ items.pop(field, None)
115
+
116
+ geninfo = items.get('parameters', geninfo)
117
+
118
+ info = f"""
119
+ <p><h4>PNG Info</h4></p>
120
+ """
121
+ for key, text in items.items():
122
+ info += f"""
123
+ <div>
124
+ {str(key)}:[{str(text)}]
125
+ </div>
126
+ """.strip()+"\n"
127
+
128
+ if len(info) == 0:
129
+ message = "Nothing found in the image."
130
+ info = f"<div><p>{message}<p></div>"
131
+
132
+ return (info,a,res)
133
+
134
+
135
+ def main():
136
+ args = parse_args()
137
+ model = load_model()
138
+ labels = load_labels()
139
+
140
+ func = functools.partial(predict, model=model, labels=labels)
141
+ func = functools.update_wrapper(func, predict)
142
+
143
+ gr.Interface(
144
+ func,
145
+ [
146
+ gr.inputs.Image(type='pil', label='Input'),
147
+ gr.inputs.Slider(0,
148
+ 1,
149
+ step=args.score_slider_step,
150
+ default=args.score_threshold,
151
+ label='Score Threshold'),
152
+ ],
153
+ [
154
+ gr.outputs.HTML(),
155
+ gr.outputs.Textbox(label='DeepDanbooru Output (string)'),
156
+ gr.outputs.Label(label='DeepDanbooru Output (label)')
157
+ ],
158
+ examples=[
159
+ ['eula.png',0.5],
160
+ ['keqing.png',0.5]
161
+ ],
162
+ title=TITLE,
163
+ description='''
164
+ Just play for fun...
165
+ ''',
166
+ theme=args.theme,
167
+ allow_flagging=args.allow_flagging,
168
+ ).launch(
169
+ enable_queue=True,
170
+ share=args.share,
171
+ )
172
+
173
+
174
+ if __name__ == '__main__':
175
+ main()