Tanut committed
Commit c3ae240 · Parent: 51276d0

Testing 2 Stable Diffusion

Files changed (1): app.py (+122 -108)
app.py CHANGED
@@ -6,7 +6,6 @@ import numpy as np
 import qrcode
 from qrcode.constants import ERROR_CORRECT_H
 from diffusers import (
-    StableDiffusionPipeline,
     StableDiffusionControlNetPipeline,
     StableDiffusionControlNetImg2ImgPipeline,  # for Hi-Res Fix
     ControlNetModel,
@@ -16,8 +15,13 @@ from diffusers import (
 # Quiet matplotlib cache warning on Spaces
 os.environ.setdefault("MPLCONFIGDIR", "/tmp/mpl")
 
-MODEL_ID = "runwayml/stable-diffusion-v1-5"
-# You can swap to a QR-Pattern-v2 repo if you know one on HF.
+# ---- base models for the two tabs ----
+BASE_MODELS = {
+    "anything": "andite/anything-v4.5",
+    "dream": "Lykon/dreamshaper-8",
+}
+
+# ControlNet (QR Monster v2 for SD15)
 CN_QRMON = "monster-labs/control_v1p_sd15_qrcode_monster"
 DTYPE = torch.float16
 
@@ -42,10 +46,9 @@ def normalize_color(c):
         return s
     return "white"
 
-def make_qr(url="http://www.mybirdfire.com", size=768, border=12, back_color="#FFFFFF", blur_radius=0.0):
+def make_qr(url="https://example.com", size=768, border=12, back_color="#FFFFFF", blur_radius=0.0):
     """
-    IMPORTANT for Method 1: give ControlNet a sharp, black-on-WHITE QR.
-    (No blur. Pixel-perfect.)
+    IMPORTANT for Method 1: give ControlNet a sharp, black-on-WHITE QR (no blur).
     """
     qr = qrcode.QRCode(version=None, error_correction=ERROR_CORRECT_H, box_size=10, border=int(border))
     qr.add_data(url.strip()); qr.make(fit=True)
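Note: make_qr is standard `qrcode`-package usage. A minimal standalone sketch of the same control-image recipe (values illustrative; only the qrcode and Pillow packages are assumed). ERROR_CORRECT_H tolerates roughly 30% module damage, which is exactly the headroom the stylization pass spends:

import qrcode
from qrcode.constants import ERROR_CORRECT_H
from PIL import Image

# Highest error-correction level: leaves headroom for the diffusion model
# to "damage" modules while the code stays scannable.
qr = qrcode.QRCode(version=None, error_correction=ERROR_CORRECT_H,
                   box_size=10, border=12)
qr.add_data("https://example.com")
qr.make(fit=True)  # pick the smallest QR version that fits the payload

# Black modules on white; resize with NEAREST so module edges stay pixel-sharp.
img = qr.make_image(fill_color="black", back_color="white").convert("RGB")
img = img.resize((768, 768), resample=Image.NEAREST)
img.save("control_qr.png")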
@@ -69,76 +72,58 @@ def enforce_qr_contrast(stylized: Image.Image, qr_img: Image.Image, strength: fl
     return Image.fromarray((s * 255.0).astype(np.uint8), mode="RGB")
 
 # ---------- lazy pipelines (CPU-offloaded for ZeroGPU) ----------
-_SD = None
-_CN_TXT2IMG = None
-_CN_IMG2IMG = None
+_CN = None          # shared ControlNet QR Monster
+_CN_TXT2IMG = {}    # per-base-model txt2img pipes
+_CN_IMG2IMG = {}    # per-base-model img2img pipes
 
 def _base_scheduler_for(pipe):
     pipe.scheduler = DPMSolverMultistepScheduler.from_config(
         pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
     )
-    pipe.enable_attention_slicing(); pipe.enable_vae_slicing(); pipe.enable_model_cpu_offload()
+    pipe.enable_attention_slicing()
+    pipe.enable_vae_slicing()
+    pipe.enable_model_cpu_offload()
     return pipe
 
-def get_sd_pipe():
-    global _SD
-    if _SD is None:
-        pipe = StableDiffusionPipeline.from_pretrained(
-            MODEL_ID, torch_dtype=DTYPE, safety_checker=None, use_safetensors=True, low_cpu_mem_usage=True
-        )
-        _SD = _base_scheduler_for(pipe)
-    return _SD
+def get_cn():
+    global _CN
+    if _CN is None:
+        _CN = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True)
+    return _CN
 
-def get_qrmon_txt2img_pipe():
-    global _CN_TXT2IMG
-    if _CN_TXT2IMG is None:
-        cn = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True)
+def get_qrmon_txt2img_pipe(model_id: str):
+    if model_id not in _CN_TXT2IMG:
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
-            MODEL_ID, controlnet=cn, torch_dtype=DTYPE, safety_checker=None,
-            use_safetensors=True, low_cpu_mem_usage=True
+            model_id,
+            controlnet=get_cn(),
+            torch_dtype=DTYPE,
+            safety_checker=None,
+            use_safetensors=True,
+            low_cpu_mem_usage=True,
         )
-        _CN_TXT2IMG = _base_scheduler_for(pipe)
-    return _CN_TXT2IMG
+        _CN_TXT2IMG[model_id] = _base_scheduler_for(pipe)
+    return _CN_TXT2IMG[model_id]
 
-def get_qrmon_img2img_pipe():
-    global _CN_IMG2IMG
-    if _CN_IMG2IMG is None:
-        cn = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True)
+def get_qrmon_img2img_pipe(model_id: str):
+    if model_id not in _CN_IMG2IMG:
         pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
-            MODEL_ID, controlnet=cn, torch_dtype=DTYPE, safety_checker=None,
-            use_safetensors=True, low_cpu_mem_usage=True
+            model_id,
+            controlnet=get_cn(),
+            torch_dtype=DTYPE,
+            safety_checker=None,
+            use_safetensors=True,
+            low_cpu_mem_usage=True,
        )
-        _CN_IMG2IMG = _base_scheduler_for(pipe)
-    return _CN_IMG2IMG
-
-# ---------- ZeroGPU tasks ----------
-@spaces.GPU(duration=120)
-def txt2img(prompt: str, negative: str, steps: int, cfg: float, width: int, height: int, seed: int):
-    pipe = get_sd_pipe()
-    w, h = snap8(width), snap8(height)
-    if int(seed) < 0:
-        seed = random.randint(0, 2**31 - 1)
-    gen = torch.Generator(device="cuda").manual_seed(int(seed))
-    if torch.cuda.is_available(): torch.cuda.empty_cache()
-    gc.collect()
-    with torch.autocast(device_type="cuda", dtype=DTYPE):
-        out = pipe(
-            prompt=str(prompt),
-            negative_prompt=str(negative or ""),
-            num_inference_steps=int(steps),
-            guidance_scale=float(cfg),
-            width=w, height=h,
-            generator=gen,
-        )
-    return out.images[0]
 
 # -------- Method 1: QR control model in text-to-image (+ optional Hi-Res Fix) --------
-@spaces.GPU(duration=120)
-def qr_txt2img(url: str, style_prompt: str, negative: str,
-               steps: int, cfg: float, size: int, border: int,
-               qr_weight: float, seed: int,
-               use_hires: bool, hires_upscale: float, hires_strength: float,
-               repair_strength: float, feather: float):
+def _qr_txt2img_core(model_id: str,
+                     url: str, style_prompt: str, negative: str,
+                     steps: int, cfg: float, size: int, border: int,
+                     qr_weight: float, seed: int,
+                     use_hires: bool, hires_upscale: float, hires_strength: float,
+                     repair_strength: float, feather: float):
 
     s = snap8(size)
 
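Note: the getters above memoize one pipeline per base model, while get_cn() keeps a single ControlNet instance that both pipelines (and both tabs) share, so switching tabs neither re-downloads nor duplicates the QR Monster weights. A sketch of the behaviour this buys (a hypothetical interactive check, assuming the functions and BASE_MODELS from this diff are in scope):

# First call per model id builds and caches the pipeline; later calls are free.
p_any = get_qrmon_txt2img_pipe(BASE_MODELS["anything"])
p_drm = get_qrmon_txt2img_pipe(BASE_MODELS["dream"])

assert p_any.controlnet is p_drm.controlnet                    # one shared ControlNet
assert get_qrmon_txt2img_pipe(BASE_MODELS["dream"]) is p_drm   # cache hit, no reload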
@@ -150,19 +135,17 @@ def qr_txt2img(url: str, style_prompt: str, negative: str,
         seed = random.randint(0, 2**31 - 1)
     gen = torch.Generator(device="cuda").manual_seed(int(seed))
 
-    # --- Stage A: txt2img with ControlNet (the actual "Method 1")
-    pipe = get_qrmon_txt2img_pipe()
+    # --- Stage A: txt2img with ControlNet
+    pipe = get_qrmon_txt2img_pipe(model_id)
     if torch.cuda.is_available(): torch.cuda.empty_cache()
     gc.collect()
-
     with torch.autocast(device_type="cuda", dtype=DTYPE):
-        # diffusers ≥ 0.30.x uses `image=` for control image
         out = pipe(
             prompt=str(style_prompt),
             negative_prompt=str(negative or ""),
-            image=qr_img,
-            controlnet_conditioning_scale=float(qr_weight),  # ~1.0–1.2 works well
-            control_guidance_start=0.0,  # "Balanced" feel
+            image=qr_img,  # control image for txt2img
+            controlnet_conditioning_scale=float(qr_weight),  # ~1.0–1.2 works well
+            control_guidance_start=0.0,
             control_guidance_end=1.0,
             num_inference_steps=int(steps),
             guidance_scale=float(cfg),
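Note: controlnet_conditioning_scale is the main scannability/style trade-off here (higher follows the QR more faithfully, lower looks better but risks an unreadable code). Not part of the commit, but a quick way to verify that a Stage A result actually decodes, assuming opencv-python is available; `decodes_ok` is a hypothetical helper:

import cv2
import numpy as np

def decodes_ok(pil_img, expected: str) -> bool:
    """True if OpenCV's QR detector reads `expected` back from the image."""
    bgr = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
    data, _points, _raw = cv2.QRCodeDetector().detectAndDecode(bgr)
    return data == expected

# e.g. regenerate with a higher qr_weight when decodes_ok(lowres, url) is False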
@@ -170,13 +153,14 @@ def qr_txt2img(url: str, style_prompt: str, negative: str,
             generator=gen,
         )
     lowres = out.images[0]
+    lowres = enforce_qr_contrast(lowres, qr_img, strength=float(repair_strength), feather=float(feather))
 
     # --- Optional Stage B: Hi-Res Fix (img2img with same QR)
     final = lowres
     if use_hires:
         up = max(1.0, min(2.0, float(hires_upscale)))
         W = snap8(int(s * up)); H = W
-        pipe2 = get_qrmon_img2img_pipe()
+        pipe2 = get_qrmon_img2img_pipe(model_id)
         if torch.cuda.is_available(): torch.cuda.empty_cache()
         gc.collect()
         with torch.autocast(device_type="cuda", dtype=DTYPE):
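Note: the Stage B call itself falls outside this hunk. From the surrounding context it presumably feeds the Stage A image as the img2img init and the same QR as the control image at the upscaled W×H. A sketch under that assumption (names pipe2, lowres, qr_img, W, H, hires_strength come from the diff; the keywords follow diffusers' StableDiffusionControlNetImg2ImgPipeline, but the author's exact call is not shown):

out2 = pipe2(
    prompt=str(style_prompt),
    negative_prompt=str(negative or ""),
    image=lowres.resize((W, H)),          # init image for img2img
    control_image=qr_img.resize((W, H)),  # same control QR as Stage A
    strength=float(hires_strength),       # how much of the image to re-denoise
    controlnet_conditioning_scale=float(qr_weight),
    num_inference_steps=int(steps),
    guidance_scale=float(cfg),
    generator=gen,
)
final = out2.images[0]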
@@ -199,47 +183,77 @@ def qr_txt2img(url: str, style_prompt: str, negative: str,
     final = enforce_qr_contrast(final, qr_img, strength=float(repair_strength), feather=float(feather))
     return final, lowres, qr_img
 
+# Wrappers for each tab (so Gradio can bind without passing the model id)
+@spaces.GPU(duration=120)
+def qr_txt2img_anything(*args):
+    return _qr_txt2img_core(BASE_MODELS["anything"], *args)
+
+@spaces.GPU(duration=120)
+def qr_txt2img_dream(*args):
+    return _qr_txt2img_core(BASE_MODELS["dream"], *args)
+
 # ---------- UI ----------
 with gr.Blocks() as demo:
-    gr.Markdown("# ZeroGPU • SD1.5 + AI QR (Method 1)")
-
-    with gr.Tab("Plain Text Image"):
-        prompt = gr.Textbox(label="Prompt", value="Japanese painting, mountains")
-        negative = gr.Textbox(label="Negative (optional)", value="ugly, disfigured, low quality, blurry, nsfw")
-        steps = gr.Slider(8, 40, value=20, step=1, label="Steps")
-        cfg = gr.Slider(1.0, 12.0, value=7.0, step=0.5, label="CFG")
-        width = gr.Slider(256, 1024, value=512, step=16, label="Width")
-        height = gr.Slider(256, 1024, value=512, step=16, label="Height")
-        seed = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
-        out_img = gr.Image(label="Image", interactive=False)
-        gr.Button("Generate").click(txt2img, [prompt, negative, steps, cfg, width, height, seed], out_img)
-
-    with gr.Tab("Method 1: QR control (txt2img)"):
-        url = gr.Textbox(label="URL/Text", value="https://example.com")
-        s_prompt = gr.Textbox(label="Style prompt", value="Japanese painting, mountains, 1girl")
-        s_negative= gr.Textbox(label="Negative prompt", value="ugly, disfigured, low quality, blurry, nsfw")
-        size = gr.Slider(384, 1024, value=512, step=64, label="Canvas (px)")
-        steps2 = gr.Slider(10, 50, value=20, step=1, label="Steps")
-        cfg2 = gr.Slider(1.0, 12.0, value=7.0, step=0.1, label="CFG")
-        border = gr.Slider(2, 16, value=4, step=1, label="QR border (quiet zone)")
-        qr_w = gr.Slider(0.6, 1.6, value=1.1, step=0.05, label="QR control weight")
-        seed2 = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
-
-        use_hires = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
-        hires_up = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
-        hires_str = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")
-
-        repair = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
-        feather = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
-
-        final_img = gr.Image(label="Final (or Hi-Res) image")
-        low_img = gr.Image(label="Low-res (Stage A) preview")
-        ctrl_img = gr.Image(label="Control QR used")
-
-        gr.Button("Generate QR Art").click(
-            qr_txt2img,
-            [url, s_prompt, s_negative, steps2, cfg2, size, border, qr_w, seed2, use_hires, hires_up, hires_str, repair, feather],
-            [final_img, low_img, ctrl_img]
+    gr.Markdown("# ZeroGPU • Method 1: QR Control (two base models)")
+
+    # ---- Tab 1: Anything v4.5 (anime/illustration) ----
+    with gr.Tab("Method 1 • Anything v4.5"):
+        url1 = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
+        s_prompt1 = gr.Textbox(label="Style prompt", value="japanese painting, elegant shrine and torii, distant mount fuji, autumn maple trees, warm sunlight, 1girl in kimono, highly detailed, intricate patterns, anime key visual, dramatic composition")
+        s_negative1 = gr.Textbox(label="Negative prompt", value="ugly, low quality, blurry, nsfw, watermark, text, low contrast, deformed, extra digits")
+        size1 = gr.Slider(384, 1024, value=512, step=64, label="Canvas (px)")
+        steps1 = gr.Slider(10, 50, value=20, step=1, label="Steps")
+        cfg1 = gr.Slider(1.0, 12.0, value=7.0, step=0.1, label="CFG")
+        border1 = gr.Slider(2, 16, value=4, step=1, label="QR border (quiet zone)")
+        qr_w1 = gr.Slider(0.6, 1.6, value=1.1, step=0.05, label="QR control weight")
+        seed1 = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
+
+        use_hires1 = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
+        hires_up1 = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
+        hires_str1 = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")
+
+        repair1 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
+        feather1 = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
+
+        final_img1 = gr.Image(label="Final (or Hi-Res) image")
+        low_img1 = gr.Image(label="Low-res (Stage A) preview")
+        ctrl_img1 = gr.Image(label="Control QR used")
+
+        gr.Button("Generate with Anything v4.5").click(
+            qr_txt2img_anything,
+            [url1, s_prompt1, s_negative1, steps1, cfg1, size1, border1, qr_w1, seed1,
+             use_hires1, hires_up1, hires_str1, repair1, feather1],
+            [final_img1, low_img1, ctrl_img1]
+        )
+
+    # ---- Tab 2: DreamShaper (general art/painterly) ----
+    with gr.Tab("Method 1 • DreamShaper 8"):
+        url2 = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
+        s_prompt2 = gr.Textbox(label="Style prompt", value="ornate baroque palace interior, gilded details, chandeliers, volumetric light, ultra detailed, cinematic")
+        s_negative2 = gr.Textbox(label="Negative prompt", value="lowres, low contrast, blurry, jpeg artifacts, watermark, text, bad anatomy")
+        size2 = gr.Slider(384, 1024, value=512, step=64, label="Canvas (px)")
+        steps2 = gr.Slider(10, 50, value=24, step=1, label="Steps")
+        cfg2 = gr.Slider(1.0, 12.0, value=6.8, step=0.1, label="CFG")
+        border2 = gr.Slider(2, 16, value=8, step=1, label="QR border (quiet zone)")
+        qr_w2 = gr.Slider(0.6, 1.6, value=1.2, step=0.05, label="QR control weight")
+        seed2 = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
+
+        use_hires2 = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
+        hires_up2 = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
+        hires_str2 = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")
+
+        repair2 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
+        feather2 = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
+
+        final_img2 = gr.Image(label="Final (or Hi-Res) image")
+        low_img2 = gr.Image(label="Low-res (Stage A) preview")
+        ctrl_img2 = gr.Image(label="Control QR used")
+
+        gr.Button("Generate with DreamShaper 8").click(
+            qr_txt2img_dream,
+            [url2, s_prompt2, s_negative2, steps2, cfg2, size2, border2, qr_w2, seed2,
+             use_hires2, hires_up2, hires_str2, repair2, feather2],
+            [final_img2, low_img2, ctrl_img2]
         )
 
 if __name__ == "__main__":
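Note: enforce_qr_contrast is now applied after Stage A and again on the final image, but its body lies outside every hunk; only the signature and the return line are visible. For orientation, a plausible reconstruction of such a repair pass, based solely on that signature and return line (an assumption, not the committed implementation; `enforce_qr_contrast_sketch` is a hypothetical name):

from PIL import Image, ImageFilter
import numpy as np

def enforce_qr_contrast_sketch(stylized: Image.Image, qr_img: Image.Image,
                               strength: float = 0.3, feather: float = 1.0) -> Image.Image:
    # Hypothetical: pull stylized pixels toward the QR's black/white modules,
    # with a Gaussian-feathered mask so the repair blends at module edges.
    s = np.asarray(stylized.convert("RGB"), dtype=np.float32) / 255.0
    q = qr_img.convert("L").resize(stylized.size)
    if feather > 0:
        q = q.filter(ImageFilter.GaussianBlur(float(feather)))
    target = np.asarray(q, dtype=np.float32)[..., None] / 255.0  # 0 = module, 1 = background
    s = (1.0 - strength) * s + strength * target                 # broadcast over RGB
    return Image.fromarray((s * 255.0).astype(np.uint8), mode="RGB")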
 