ehristoforu mkshing commited on
Commit
c781418
·
0 Parent(s):

Duplicate from rinna/japanese-stable-diffusion

Browse files

Co-authored-by: mkshing <[email protected]>

Files changed (5) hide show
  1. .gitattributes +33 -0
  2. README.md +14 -0
  3. app.py +286 -0
  4. nsfw.png +3 -0
  5. requirements.txt +8 -0
.gitattributes ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ nfsw.png filter=lfs diff=lfs merge=lfs -text
33
+ nsfw.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Japanese Stable Diffusion
3
+ emoji: 🎨
4
+ colorFrom: red
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 3.3.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: other
11
+ duplicated_from: rinna/japanese-stable-diffusion
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+ from torch import autocast
6
+ from diffusers import LMSDiscreteScheduler
7
+ from japanese_stable_diffusion import JapaneseStableDiffusionPipeline
8
+ from PIL import Image
9
+ from dotenv import load_dotenv
10
+
11
+
12
+ load_dotenv()
13
+ ACCESS_TOKEN = os.getenv("ACCESS_TOKEN")
14
+
15
+
16
+ model_id = "rinna/japanese-stable-diffusion"
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
20
+ pipe = JapaneseStableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, use_auth_token=ACCESS_TOKEN)
21
+ pipe.to(device)
22
+ pipe.unet.half()
23
+ pipe.text_encoder.half()
24
+ #torch.backends.cudnn.benchmark = True
25
+
26
+
27
+ def infer(
28
+ prompt,
29
+ n_samples=4,
30
+ guidance_scale=7.5,
31
+ steps=50,
32
+ seed="random",
33
+ ):
34
+ if seed == "random":
35
+ generator = torch.Generator(device=device).manual_seed(int(random.randint(0, 2 ** 32)))
36
+ else:
37
+ generator = torch.Generator(device=device).manual_seed(int(seed))
38
+
39
+ with autocast("cuda"):
40
+ images_list = pipe(
41
+ prompt=[prompt] * int(n_samples),
42
+ guidance_scale=guidance_scale,
43
+ num_inference_steps=int(steps),
44
+ generator=generator
45
+ )
46
+ images = []
47
+ safe_image = Image.open(r"nsfw.png")
48
+ for i, image in enumerate(images_list.images):
49
+ if (images_list["nsfw_content_detected"][i]):
50
+ images.append(safe_image)
51
+ else:
52
+ images.append(image)
53
+ return images
54
+
55
+
56
+ css = """
57
+ .gradio-container {
58
+ font-family: 'IBM Plex Sans', sans-serif;
59
+ }
60
+ .gr-button {
61
+ color: white;
62
+ border-color: black;
63
+ background: black;
64
+ }
65
+ input[type='range'] {
66
+ accent-color: black;
67
+ }
68
+ .dark input[type='range'] {
69
+ accent-color: #dfdfdf;
70
+ }
71
+ .container {
72
+ max-width: 730px;
73
+ margin: auto;
74
+ padding-top: 1.5rem;
75
+ }
76
+ #gallery {
77
+ min-height: 22rem;
78
+ margin-bottom: 15px;
79
+ margin-left: auto;
80
+ margin-right: auto;
81
+ border-bottom-right-radius: .5rem !important;
82
+ border-bottom-left-radius: .5rem !important;
83
+ }
84
+ #gallery>div>.h-full {
85
+ min-height: 20rem;
86
+ }
87
+ .details:hover {
88
+ text-decoration: underline;
89
+ }
90
+ .gr-button {
91
+ white-space: nowrap;
92
+ }
93
+ .gr-button:focus {
94
+ border-color: rgb(147 197 253 / var(--tw-border-opacity));
95
+ outline: none;
96
+ box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
97
+ --tw-border-opacity: 1;
98
+ --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
99
+ --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
100
+ --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
101
+ --tw-ring-opacity: .5;
102
+ }
103
+ #advanced-btn {
104
+ font-size: .7rem !important;
105
+ line-height: 19px;
106
+ margin-top: 12px;
107
+ margin-bottom: 12px;
108
+ padding: 2px 8px;
109
+ border-radius: 14px !important;
110
+ }
111
+ #advanced-options {
112
+ display: none;
113
+ margin-bottom: 20px;
114
+ }
115
+ .footer {
116
+ margin-bottom: 45px;
117
+ margin-top: 35px;
118
+ text-align: center;
119
+ border-bottom: 1px solid #e5e5e5;
120
+ }
121
+ .footer>p {
122
+ font-size: .8rem;
123
+ display: inline-block;
124
+ padding: 0 10px;
125
+ transform: translateY(10px);
126
+ background: white;
127
+ }
128
+ .dark .footer {
129
+ border-color: #303030;
130
+ }
131
+ .dark .footer>p {
132
+ background: #0b0f19;
133
+ }
134
+ .acknowledgments h4{
135
+ margin: 1.25em 0 .25em 0;
136
+ font-weight: bold;
137
+ font-size: 115%;
138
+ }
139
+ """
140
+
141
+ block = gr.Blocks(css=css)
142
+
143
+
144
+ examples = [
145
+ ["サラリーマン 油絵", 2, 7.5, 50, "random"],
146
+ ["キラキラ瞳の猫", 2, 7.5, 50, "random"],
147
+ ["夕暮れの神社の夏祭りを描いた水彩画", 2, 7.5, 50, "random"]
148
+ ]
149
+
150
+ with block:
151
+ gr.HTML(
152
+ """
153
+ <div style="text-align: center; max-width: 650px; margin: 0 auto;">
154
+ <div
155
+ style="
156
+ display: inline-flex;
157
+ align-items: center;
158
+ gap: 0.8rem;
159
+ font-size: 1.75rem;
160
+ "
161
+ >
162
+ <svg
163
+ width="0.65em"
164
+ height="0.65em"
165
+ viewBox="0 0 115 115"
166
+ fill="none"
167
+ xmlns="http://www.w3.org/2000/svg"
168
+ >
169
+ <rect width="23" height="23" fill="white"></rect>
170
+ <rect y="69" width="23" height="23" fill="white"></rect>
171
+ <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
172
+ <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
173
+ <rect x="46" width="23" height="23" fill="white"></rect>
174
+ <rect x="46" y="69" width="23" height="23" fill="white"></rect>
175
+ <rect x="69" width="23" height="23" fill="black"></rect>
176
+ <rect x="69" y="69" width="23" height="23" fill="black"></rect>
177
+ <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
178
+ <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
179
+ <rect x="115" y="46" width="23" height="23" fill="white"></rect>
180
+ <rect x="115" y="115" width="23" height="23" fill="white"></rect>
181
+ <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
182
+ <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
183
+ <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
184
+ <rect x="92" y="69" width="23" height="23" fill="white"></rect>
185
+ <rect x="69" y="46" width="23" height="23" fill="white"></rect>
186
+ <rect x="69" y="115" width="23" height="23" fill="white"></rect>
187
+ <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
188
+ <rect x="46" y="46" width="23" height="23" fill="black"></rect>
189
+ <rect x="46" y="115" width="23" height="23" fill="black"></rect>
190
+ <rect x="46" y="69" width="23" height="23" fill="black"></rect>
191
+ <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
192
+ <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
193
+ <rect x="23" y="69" width="23" height="23" fill="black"></rect>
194
+ </svg>
195
+ <h1 style="font-weight: 900; margin-bottom: 7px;">
196
+ Japanese Stable Diffusion Demo
197
+ </h1>
198
+ </div>
199
+ <p style="margin-bottom: 10px; font-size: 94%">
200
+ <a
201
+ href="https://github.com/rinnakk/japanese-stable-diffusion/"
202
+ style="text-decoration: underline;"
203
+ target="_blank"
204
+ >Japanese Stable Diffusion</a
205
+ >
206
+ is a Japanese-language specific latent text-to-image diffusion model.
207
+ </p>
208
+ </div>
209
+ """
210
+ )
211
+ with gr.Group():
212
+ with gr.Box():
213
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
214
+ text = gr.Textbox(
215
+ label="Enter your prompt in Japanese",
216
+ show_label=False,
217
+ max_lines=1,
218
+ placeholder="Enter your prompt in Japanese",
219
+ ).style(
220
+ border=(True, False, True, True),
221
+ rounded=(True, False, False, True),
222
+ container=False,
223
+ )
224
+ btn = gr.Button("Generate image").style(
225
+ margin=False,
226
+ rounded=(False, True, True, False),
227
+ )
228
+
229
+ gallery = gr.Gallery(
230
+ label="Generated images", show_label=False, elem_id="gallery"
231
+ ).style(grid=[2], height="auto")
232
+
233
+ advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
234
+
235
+ with gr.Row(elem_id="advanced-options"):
236
+ samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
237
+ steps = gr.Slider(label="Steps", minimum=1, maximum=200, value=50, step=1)
238
+ scale = gr.Slider(
239
+ label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
240
+ )
241
+ seed = gr.Textbox(value='random',
242
+ placeholder="If you fix seed, you get same outputs all the time. You can set as integer like 42.",
243
+ label="seed")
244
+
245
+ ex = gr.Examples(
246
+ examples=examples, fn=infer, inputs=[text, samples, scale, steps, seed], outputs=gallery, cache_examples=True
247
+ )
248
+ ex.dataset.headers = [""]
249
+
250
+ text.submit(infer, inputs=[text, samples, scale, steps, seed], outputs=gallery)
251
+ btn.click(infer, inputs=[text, samples, scale, steps, seed], outputs=gallery)
252
+ advanced_button.click(
253
+ None,
254
+ [],
255
+ text,
256
+ _js="""
257
+ () => {
258
+ const options = document.querySelector("body > gradio-app").querySelector("#advanced-options");
259
+ options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
260
+ }""",
261
+ )
262
+ gr.HTML(
263
+ """
264
+ <div class="footer">
265
+ <p>Model by <a href="https://huggingface.co/rinna" style="text-decoration: underline;" target="_blank">rinna</a> - Gradio Demo by 🤗 Hugging Face
266
+ </p>
267
+ </div>
268
+ <div class="acknowledgments">
269
+ <p><h4>LICENSE</h4>
270
+ The model is licensed with a <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" style="text-decoration: underline;" target="_blank">CreativeML Open RAIL-M</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;" target="_blank">read the license</a>.</p>
271
+ <p><h4>Limitations and Bias</h4>
272
+ While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases. Japanese Stable Diffusion was trained on Japanese datasets including LAION-5B with Japanese captions, which consists of images that are primarily limited to Japanese descriptions. Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for. This affects the overall output of the model. Further, the ability of the model to generate content with non-Japanese prompts is significantly worse than with Japanese-language prompts. You can read more in the <a href="https://huggingface.co/rinna/japanese-stable-diffusion#limitations-and-bias" style="text-decoration: underline;" target="_blank">model card</a>.</p>
273
+ </div>
274
+ <br> 
275
+ <br>
276
+ <i>This demo is based on the
277
+ <a
278
+ href="https://huggingface.co/spaces/stabilityai/stable-diffusion/"
279
+ style="text-decoration: underline;"
280
+ target="_blank"
281
+ >Stable Diffusion Demo</a
282
+ >.</i>
283
+ """
284
+ )
285
+
286
+ block.queue(max_size=25).launch()
nsfw.png ADDED

Git LFS Details

  • SHA256: aa83b7895912c507e3e48a2fc8aac0c4d691ccf39eacba5173ff1c6d6de91d72
  • Pointer size: 132 Bytes
  • Size of remote file: 2.09 MB
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ -e git+https://github.com/rinnakk/japanese-stable-diffusion.git@master#egg=japanese_stable_diffusion
2
+ accelerate
3
+ transformers
4
+ nvidia-ml-py3
5
+ ftfy
6
+ --extra-index-url https://download.pytorch.org/whl/cu113
7
+ torch
8
+ python-dotenv