Spaces: Running on Zero
Tanut committed
Commit · c3ae240 · Parent(s): 51276d0
Testing 2 Stable Diffusion
app.py
CHANGED
@@ -6,7 +6,6 @@ import numpy as np
 import qrcode
 from qrcode.constants import ERROR_CORRECT_H
 from diffusers import (
-    StableDiffusionPipeline,
     StableDiffusionControlNetPipeline,
     StableDiffusionControlNetImg2ImgPipeline,  # for Hi-Res Fix
     ControlNetModel,
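The plain StableDiffusionPipeline import goes away because every generation path now runs through ControlNet. As orientation, here is a minimal sketch (not code from the app) of how the two remaining pipeline classes pair with one shared ControlNet; the model ids are the ones this commit uses:

    import torch
    from diffusers import (
        ControlNetModel,
        StableDiffusionControlNetPipeline,
        StableDiffusionControlNetImg2ImgPipeline,
    )

    cn = ControlNetModel.from_pretrained(
        "monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16
    )
    # txt2img: generates from noise; the QR is passed as `image=` (control image)
    txt2img = StableDiffusionControlNetPipeline.from_pretrained(
        "Lykon/dreamshaper-8", controlnet=cn, torch_dtype=torch.float16
    )
    # img2img: refines an init image; the QR is passed as `control_image=`
    img2img = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        "Lykon/dreamshaper-8", controlnet=cn, torch_dtype=torch.float16
    )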
@@ -16,8 +15,13 @@ from diffusers import (
 # Quiet matplotlib cache warning on Spaces
 os.environ.setdefault("MPLCONFIGDIR", "/tmp/mpl")
 
-
-
+# ---- base models for the two tabs ----
+BASE_MODELS = {
+    "anything": "andite/anything-v4.5",
+    "dream": "Lykon/dreamshaper-8",
+}
+
+# ControlNet (QR Monster v2 for SD15)
 CN_QRMON = "monster-labs/control_v1p_sd15_qrcode_monster"
 DTYPE = torch.float16
 
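Both registry entries are SD 1.5-family checkpoints, which matters because the QR Monster ControlNet is trained against SD15. The dict also makes adding a third checkpoint a two-line change: one entry here plus one wrapper bound to a tab (the wrappers appear near the end of this diff). A sketch with a deliberately fake model id:

    BASE_MODELS["third"] = "someuser/some-sd15-checkpoint"  # placeholder id, not a real repo

    @spaces.GPU(duration=120)
    def qr_txt2img_third(*args):
        return _qr_txt2img_core(BASE_MODELS["third"], *args)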
@@ -42,10 +46,9 @@ def normalize_color(c):
         return s
     return "white"
 
-def make_qr(url="…
+def make_qr(url="https://example.com", size=768, border=12, back_color="#FFFFFF", blur_radius=0.0):
     """
-    IMPORTANT for Method 1: give ControlNet a sharp, black-on-WHITE QR.
-    (No blur. Pixel-perfect.)
+    IMPORTANT for Method 1: give ControlNet a sharp, black-on-WHITE QR (no blur).
     """
     qr = qrcode.QRCode(version=None, error_correction=ERROR_CORRECT_H, box_size=10, border=int(border))
     qr.add_data(url.strip()); qr.make(fit=True)
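The recipe matters more than it looks: ERROR_CORRECT_H buys roughly 30% recoverable damage, the border preserves the quiet zone that scanners lock onto, and nearest-neighbor resizing keeps module edges pixel-sharp instead of antialiased. A standalone sketch of the same idea (the helper name is ours, not the app's):

    import qrcode
    from qrcode.constants import ERROR_CORRECT_H
    from PIL import Image

    def control_qr(url: str, size: int = 512, border: int = 4) -> Image.Image:
        qr = qrcode.QRCode(version=None, error_correction=ERROR_CORRECT_H,
                           box_size=10, border=border)
        qr.add_data(url.strip())
        qr.make(fit=True)  # pick the smallest QR version that fits the payload
        img = qr.make_image(fill_color="black", back_color="white").convert("RGB")
        return img.resize((size, size), Image.NEAREST)  # hard edges, no blur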
@@ -69,76 +72,58 @@ def enforce_qr_contrast(stylized: Image.Image, qr_img: Image.Image, strength: fl
     return Image.fromarray((s * 255.0).astype(np.uint8), mode="RGB")
 
 # ---------- lazy pipelines (CPU-offloaded for ZeroGPU) ----------
-_SD = None
-_CN_TXT2IMG = None
-_CN_IMG2IMG = None
+_CN = None          # shared ControlNet QR Monster
+_CN_TXT2IMG = {}    # per-base-model txt2img pipes
+_CN_IMG2IMG = {}    # per-base-model img2img pipes
 
 def _base_scheduler_for(pipe):
     pipe.scheduler = DPMSolverMultistepScheduler.from_config(
         pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
     )
-    pipe.enable_attention_slicing()
+    pipe.enable_attention_slicing()
+    pipe.enable_vae_slicing()
+    pipe.enable_model_cpu_offload()
     return pipe
 
-def get_sd_pipe():
-    global _SD
-    if _SD is None:
-        pipe = StableDiffusionPipeline.from_pretrained(
-            …
-        )
-        _SD = _base_scheduler_for(pipe)
-    return _SD
+def get_cn():
+    global _CN
+    if _CN is None:
+        _CN = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True)
+    return _CN
 
-def get_qrmon_txt2img_pipe():
-    global _CN_TXT2IMG
-    if _CN_TXT2IMG is None:
-        cn = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True)
+def get_qrmon_txt2img_pipe(model_id: str):
+    if model_id not in _CN_TXT2IMG:
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
-            …
-            …
+            model_id,
+            controlnet=get_cn(),
+            torch_dtype=DTYPE,
+            safety_checker=None,
+            use_safetensors=True,
+            low_cpu_mem_usage=True,
         )
-        _CN_TXT2IMG = _base_scheduler_for(pipe)
-    return _CN_TXT2IMG
+        _CN_TXT2IMG[model_id] = _base_scheduler_for(pipe)
+    return _CN_TXT2IMG[model_id]
 
-def get_qrmon_img2img_pipe():
-    global _CN_IMG2IMG
-    if _CN_IMG2IMG is None:
-        cn = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True)
+def get_qrmon_img2img_pipe(model_id: str):
+    if model_id not in _CN_IMG2IMG:
         pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
-            …
-            …
+            model_id,
+            controlnet=get_cn(),
+            torch_dtype=DTYPE,
+            safety_checker=None,
+            use_safetensors=True,
+            low_cpu_mem_usage=True,
         )
-        _CN_IMG2IMG = _base_scheduler_for(pipe)
-    return _CN_IMG2IMG
-
-# ---------- ZeroGPU tasks ----------
-@spaces.GPU(duration=120)
-def txt2img(prompt: str, negative: str, steps: int, cfg: float, width: int, height: int, seed: int):
-    pipe = get_sd_pipe()
-    w, h = snap8(width), snap8(height)
-    if int(seed) < 0:
-        seed = random.randint(0, 2**31 - 1)
-    gen = torch.Generator(device="cuda").manual_seed(int(seed))
-    if torch.cuda.is_available(): torch.cuda.empty_cache()
-    gc.collect()
-    with torch.autocast(device_type="cuda", dtype=DTYPE):
-        out = pipe(
-            prompt=str(prompt),
-            negative_prompt=str(negative or ""),
-            num_inference_steps=int(steps),
-            guidance_scale=float(cfg),
-            width=w, height=h,
-            generator=gen,
-        )
-    return out.images[0]
 
 # -------- Method 1: QR control model in text-to-image (+ optional Hi-Res Fix) --------
-def qr_txt2img(url: str, style_prompt: str, negative: str,
-               …
-               …
-               …
-               …
-               …
+def _qr_txt2img_core(model_id: str,
+                     url: str, style_prompt: str, negative: str,
+                     steps: int, cfg: float, size: int, border: int,
+                     qr_weight: float, seed: int,
+                     use_hires: bool, hires_upscale: float, hires_strength: float,
+                     repair_strength: float, feather: float):
 
     s = snap8(size)
 
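Two things happen in this hunk: the singleton pipes become per-model caches keyed by checkpoint id, and _base_scheduler_for picks up more aggressive memory settings. The scheduler and offload tuning, shown standalone (it applies to any diffusers SD pipeline; the tight VRAM budget of a short-lived ZeroGPU slice is the assumption):

    from diffusers import DPMSolverMultistepScheduler

    def tune_for_small_gpu(pipe):
        # DPM-Solver++ with Karras sigmas converges well in ~20-30 steps
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
        )
        pipe.enable_attention_slicing()   # compute attention in chunks, lower peak VRAM
        pipe.enable_vae_slicing()         # decode latents slice by slice
        pipe.enable_model_cpu_offload()   # move each submodule to GPU only while in use
        return pipe

Note that enable_model_cpu_offload() replaces any manual .to("cuda"): accelerate hooks page submodules onto the GPU on demand, which is why the pipes can be built lazily on CPU and still run inside the @spaces.GPU window.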
@@ -150,19 +135,17 @@ def qr_txt2img(url: str, style_prompt: str, negative: str,
         seed = random.randint(0, 2**31 - 1)
     gen = torch.Generator(device="cuda").manual_seed(int(seed))
 
-    # --- Stage A: txt2img with ControlNet
-    pipe = get_qrmon_txt2img_pipe()
+    # --- Stage A: txt2img with ControlNet
+    pipe = get_qrmon_txt2img_pipe(model_id)
     if torch.cuda.is_available(): torch.cuda.empty_cache()
     gc.collect()
-
     with torch.autocast(device_type="cuda", dtype=DTYPE):
-        # diffusers ≥ 0.30.x uses `image=` for control image
         out = pipe(
             prompt=str(style_prompt),
             negative_prompt=str(negative or ""),
-            image=qr_img,
-            controlnet_conditioning_scale=float(qr_weight),
-            control_guidance_start=0.0,
+            image=qr_img,                                    # control image for txt2img
+            controlnet_conditioning_scale=float(qr_weight),  # ~1.0–1.2 works well
+            control_guidance_start=0.0,
             control_guidance_end=1.0,
             num_inference_steps=int(steps),
             guidance_scale=float(cfg),
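Stage A in miniature, using this commit's loaders (prompt, seed, and scale values are illustrative): the QR goes in as the ControlNet conditioning image via image=, the output size is inferred from it, and controlnet_conditioning_scale is the main knob trading scannability against style.

    import torch

    qr_img = make_qr("https://example.com", size=512, border=12)
    pipe = get_qrmon_txt2img_pipe(BASE_MODELS["dream"])
    out = pipe(
        prompt="ornate baroque palace interior, cinematic",
        negative_prompt="lowres, blurry, watermark",
        image=qr_img,                           # control image
        controlnet_conditioning_scale=1.2,      # higher = more scannable, less stylized
        num_inference_steps=24,
        guidance_scale=6.8,
        generator=torch.Generator(device="cuda").manual_seed(42),
    )
    lowres = out.images[0]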
@@ -170,13 +153,14 @@ def qr_txt2img(url: str, style_prompt: str, negative: str,
             generator=gen,
         )
     lowres = out.images[0]
+    lowres = enforce_qr_contrast(lowres, qr_img, strength=float(repair_strength), feather=float(feather))
 
     # --- Optional Stage B: Hi-Res Fix (img2img with same QR)
     final = lowres
     if use_hires:
         up = max(1.0, min(2.0, float(hires_upscale)))
         W = snap8(int(s * up)); H = W
-        pipe2 = get_qrmon_img2img_pipe()
+        pipe2 = get_qrmon_img2img_pipe(model_id)
         if torch.cuda.is_available(): torch.cuda.empty_cache()
         gc.collect()
         with torch.autocast(device_type="cuda", dtype=DTYPE):
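Stage B is a classic Hi-Res Fix: upscale Stage A's output, then re-denoise it with img2img while the same QR keeps steering through control_image=. A sketch under the same assumptions as above, with strength near the 0.7 UI default (the fraction of steps that get re-noised):

    from PIL import Image

    W = 1024
    up = lowres.resize((W, W), Image.LANCZOS)             # init image for img2img
    pipe2 = get_qrmon_img2img_pipe(BASE_MODELS["dream"])
    final = pipe2(
        prompt="ornate baroque palace interior, cinematic",
        image=up,                                         # image to refine
        control_image=qr_img.resize((W, W), Image.NEAREST),
        strength=0.7,                                     # how much detail is re-synthesized
        controlnet_conditioning_scale=1.2,
        num_inference_steps=24,
        guidance_scale=6.8,
    ).images[0]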
@@ -199,47 +183,77 @@ def qr_txt2img(url: str, style_prompt: str, negative: str,
     final = enforce_qr_contrast(final, qr_img, strength=float(repair_strength), feather=float(feather))
     return final, lowres, qr_img
 
+# Wrappers for each tab (so Gradio can bind without passing the model id)
+@spaces.GPU(duration=120)
+def qr_txt2img_anything(*args):
+    return _qr_txt2img_core(BASE_MODELS["anything"], *args)
+
+@spaces.GPU(duration=120)
+def qr_txt2img_dream(*args):
+    return _qr_txt2img_core(BASE_MODELS["dream"], *args)
+
 # ---------- UI ----------
 with gr.Blocks() as demo:
-    gr.Markdown("# ZeroGPU • …
-    …
-    gr.…
-    …
-    gr.…
-    …
+    gr.Markdown("# ZeroGPU • Method 1: QR Control (two base models)")
+
+    # ---- Tab 1: Anything v4.5 (anime/illustration) ----
+    with gr.Tab("Method 1 • Anything v4.5"):
+        url1       = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
+        s_prompt1  = gr.Textbox(label="Style prompt", value="japanese painting, elegant shrine and torii, distant mount fuji, autumn maple trees, warm sunlight, 1girl in kimono, highly detailed, intricate patterns, anime key visual, dramatic composition")
+        s_negative1= gr.Textbox(label="Negative prompt", value="ugly, low quality, blurry, nsfw, watermark, text, low contrast, deformed, extra digits")
+        size1      = gr.Slider(384, 1024, value=512, step=64, label="Canvas (px)")
+        steps1     = gr.Slider(10, 50, value=20, step=1, label="Steps")
+        cfg1       = gr.Slider(1.0, 12.0, value=7.0, step=0.1, label="CFG")
+        border1    = gr.Slider(2, 16, value=4, step=1, label="QR border (quiet zone)")
+        qr_w1      = gr.Slider(0.6, 1.6, value=1.1, step=0.05, label="QR control weight")
+        seed1      = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
+
+        use_hires1 = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
+        hires_up1  = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
+        hires_str1 = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")
+
+        repair1    = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
+        feather1   = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
+
+        final_img1 = gr.Image(label="Final (or Hi-Res) image")
+        low_img1   = gr.Image(label="Low-res (Stage A) preview")
+        ctrl_img1  = gr.Image(label="Control QR used")
+
+        gr.Button("Generate with Anything v4.5").click(
+            qr_txt2img_anything,
+            [url1, s_prompt1, s_negative1, steps1, cfg1, size1, border1, qr_w1, seed1,
+             use_hires1, hires_up1, hires_str1, repair1, feather1],
+            [final_img1, low_img1, ctrl_img1]
+        )
+
+    # ---- Tab 2: DreamShaper (general art/painterly) ----
+    with gr.Tab("Method 1 • DreamShaper 8"):
+        url2       = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
+        s_prompt2  = gr.Textbox(label="Style prompt", value="ornate baroque palace interior, gilded details, chandeliers, volumetric light, ultra detailed, cinematic")
+        s_negative2= gr.Textbox(label="Negative prompt", value="lowres, low contrast, blurry, jpeg artifacts, watermark, text, bad anatomy")
+        size2      = gr.Slider(384, 1024, value=512, step=64, label="Canvas (px)")
+        steps2     = gr.Slider(10, 50, value=24, step=1, label="Steps")
+        cfg2       = gr.Slider(1.0, 12.0, value=6.8, step=0.1, label="CFG")
+        border2    = gr.Slider(2, 16, value=8, step=1, label="QR border (quiet zone)")
+        qr_w2      = gr.Slider(0.6, 1.6, value=1.2, step=0.05, label="QR control weight")
+        seed2      = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
+
+        use_hires2 = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
+        hires_up2  = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
+        hires_str2 = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")
+
+        repair2    = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
+        feather2   = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
+
+        final_img2 = gr.Image(label="Final (or Hi-Res) image")
+        low_img2   = gr.Image(label="Low-res (Stage A) preview")
+        ctrl_img2  = gr.Image(label="Control QR used")
+
+        gr.Button("Generate with DreamShaper 8").click(
+            qr_txt2img_dream,
+            [url2, s_prompt2, s_negative2, steps2, cfg2, size2, border2, qr_w2, seed2,
+             use_hires2, hires_up2, hires_str2, repair2, feather2],
+            [final_img2, low_img2, ctrl_img2]
         )
 
 if __name__ == "__main__":
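One function this diff only brushes against is enforce_qr_contrast, now applied to the Stage A output as well as the final image. Its body is not shown here; from its signature and its surviving return line, it plausibly blends the stylized image back toward the binary QR with a feathered mask, roughly like this sketch (ours, not the app's actual code):

    import numpy as np
    from PIL import Image, ImageFilter

    def enforce_qr_contrast_sketch(stylized: Image.Image, qr_img: Image.Image,
                                   strength: float = 0.4, feather: float = 1.0) -> Image.Image:
        s = np.asarray(stylized.convert("RGB"), dtype=np.float32) / 255.0
        q = qr_img.convert("L").resize(stylized.size, Image.NEAREST)
        if feather > 0:
            q = q.filter(ImageFilter.GaussianBlur(feather))  # soften module borders
        target = (np.asarray(q, dtype=np.float32) / 255.0)[..., None]
        s = (1.0 - strength) * s + strength * target  # pull pixels toward black/white
        return Image.fromarray((s * 255.0).astype(np.uint8), mode="RGB")

With strength=0.0, the default on both tabs' "Post repair strength (optional)" sliders, this pass is a no-op, which matches how the UI labels it.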