Tanut committed on
Commit
3330f45
·
1 Parent(s): e9788d0

Test ZeroGPU

Browse files
Files changed (2) hide show
  1. app.py +147 -27
  2. requirements.txt +3 -2
app.py CHANGED
@@ -1,16 +1,75 @@
1
- import gc, random, os
2
  import gradio as gr
3
  import torch, spaces
4
- from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  MODEL_ID = "runwayml/stable-diffusion-v1-5"
 
7
  DTYPE = torch.float16
8
- _PIPE = None
9
 
10
- def get_pipe():
11
- global _PIPE
12
- if _PIPE is None:
13
- # Build on CPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  pipe = StableDiffusionPipeline.from_pretrained(
15
  MODEL_ID,
16
  torch_dtype=DTYPE,
@@ -21,29 +80,39 @@ def get_pipe():
21
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(
22
  pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
23
  )
24
- pipe.enable_attention_slicing()
25
- pipe.enable_vae_slicing()
26
- pipe.enable_model_cpu_offload()
27
- _PIPE = pipe
28
- return _PIPE
29
 
30
- def snap8(x: int) -> int:
31
- x = max(256, min(1024, int(x)))
32
- return x - (x % 8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
34
  @spaces.GPU(duration=120)
35
- def generate(prompt: str, negative: str, steps: int, cfg: float, width: int, height: int, seed: int):
36
- pipe = get_pipe() # stays CPU/offloaded until now
37
  w, h = snap8(width), snap8(height)
38
-
39
  if int(seed) < 0:
40
  seed = random.randint(0, 2**31 - 1)
41
  gen = torch.Generator(device="cuda").manual_seed(int(seed))
42
-
43
- if torch.cuda.is_available():
44
- torch.cuda.empty_cache()
45
  gc.collect()
46
-
47
  with torch.autocast(device_type="cuda", dtype=DTYPE):
48
  out = pipe(
49
  prompt=str(prompt),
@@ -55,8 +124,40 @@ def generate(prompt: str, negative: str, steps: int, cfg: float, width: int, hei
55
  )
56
  return out.images[0]
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  with gr.Blocks() as demo:
59
- gr.Markdown("# ZeroGPU + SD1.5 (minimal)")
60
 
61
  with gr.Tab("Text → Image"):
62
  prompt = gr.Textbox(label="Prompt", value="a cozy reading nook, warm sunlight, cinematic lighting, highly detailed")
@@ -67,10 +168,29 @@ with gr.Blocks() as demo:
67
  height = gr.Slider(256, 1024, value=640, step=16, label="Height")
68
  seed = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
69
  out_img = gr.Image(label="Image", interactive=False)
70
- gr.Button("Generate").click(generate, [prompt, negative, steps, cfg, width, height, seed], out_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  if __name__ == "__main__":
73
- # On Spaces: keep it simple; don’t pass odd kwargs.
74
- # If you see “localhost is not accessible”, add share=True.
75
  demo.queue(max_size=12).launch()
76
-
 
1
+ import os, gc, random, re
2
  import gradio as gr
3
  import torch, spaces
4
+ from PIL import Image, ImageFilter
5
+ import numpy as np
6
+ import qrcode
7
+ from qrcode.constants import ERROR_CORRECT_H
8
+ from diffusers import (
9
+ StableDiffusionPipeline,
10
+ StableDiffusionControlNetPipeline,
11
+ ControlNetModel,
12
+ DPMSolverMultistepScheduler,
13
+ )
14
+
15
+ # Optional: silence matplotlib cache warning in Spaces
16
+ os.environ.setdefault("MPLCONFIGDIR", "/tmp/mpl")
17
 
18
  MODEL_ID = "runwayml/stable-diffusion-v1-5"
19
+ CN_QRMON = "monster-labs/control_v1p_sd15_qrcode_monster" # v2 on the repo
20
  DTYPE = torch.float16
 
21
 
22
+ # ---------- helpers ----------
23
def snap8(x: int) -> int:
    """Clamp *x* to the range 256..1024 and round it down to a multiple of 8.

    The value is converted with ``int()`` first, so floats and numeric
    strings are accepted. Used by the generation tasks to sanitise the
    requested width, height and canvas size.
    """
    clamped = min(1024, max(256, int(x)))
    return (clamped // 8) * 8
26
+
27
def normalize_color(c):
    """Normalise a colour value into a form usable as a QR fill/background.

    Returns:
        * ``"white"`` for ``None``, for empty/whitespace-only strings, for
          sequences with fewer than three items and for unsupported types;
        * an ``(r, g, b)`` tuple of ints clamped to 0..255 for a tuple/list
          (only the first three items are used) and for ``rgb(...)`` /
          ``rgba(...)`` strings (the alpha component is ignored);
        * the stripped string itself for ``#...`` values and any other
          non-empty string (e.g. a named colour).
    """

    def _channel(value):
        # Round to the nearest integer, then clamp into the 0..255 byte range.
        return int(max(0, min(255, round(float(value)))))

    if c is None:
        return "white"

    if isinstance(c, (tuple, list)):
        # Too few components to form an RGB triple: fall back to the default
        # instead of failing on tuple unpacking.
        if len(c) < 3:
            return "white"
        r, g, b = (_channel(v) for v in c[:3])
        return (r, g, b)

    if isinstance(c, str):
        s = c.strip()
        if not s:
            # An empty string is not a usable colour specifier.
            return "white"
        if s.startswith("#"):
            return s
        m = re.match(r"rgba?\(\s*([0-9.]+)\s*,\s*([0-9.]+)\s*,\s*([0-9.]+)", s, re.IGNORECASE)
        if m:
            return (_channel(m.group(1)), _channel(m.group(2)), _channel(m.group(3)))
        return s

    return "white"
42
+
43
def make_qr(url="http://www.mybirdfire.com", size=768, border=12, back_color="#808080", blur_radius=1.2):
    """Render *url* as a square RGB QR image for use as a ControlNet control image.

    The code is built with error-correction level H and an automatically
    fitted version, drawn black on ``back_color`` (passed through
    ``normalize_color``), then scaled to ``size`` x ``size`` pixels with
    nearest-neighbour resampling. When ``blur_radius`` is positive the
    result is softened with a Gaussian blur of that radius.
    """
    # The default mid-gray background is meant to improve blending and scan
    # rate with QR-Monster v2.
    code = qrcode.QRCode(
        version=None,
        error_correction=ERROR_CORRECT_H,
        box_size=10,
        border=int(border),
    )
    code.add_data(url.strip())
    code.make(fit=True)

    background = normalize_color(back_color)
    rendered = code.make_image(fill_color="black", back_color=background).convert("RGB")

    side = int(size)
    rendered = rendered.resize((side, side), Image.NEAREST)

    # Skip the blur entirely for a zero, negative or missing radius.
    if not (blur_radius and blur_radius > 0):
        return rendered
    return rendered.filter(ImageFilter.GaussianBlur(radius=float(blur_radius)))
52
+
53
def enforce_qr_contrast(stylized: Image.Image, qr_img: Image.Image, strength: float = 0.6, feather: float = 1.0) -> Image.Image:
    """Gently push the stylized image toward the QR's blacks/whites for scannability.

    Args:
        stylized: Generated image to repair.
        qr_img: QR image whose dark/light layout drives the repair. It is
            resized to ``stylized``'s size when the two differ.
        strength: Repair amount; values <= 0 return ``stylized`` untouched.
        feather: Gaussian blur radius applied to the mask so the correction
            fades at module edges. Negative values are treated as 0.

    Returns:
        A new RGB image, or ``stylized`` itself when ``strength`` <= 0.
    """
    strength = float(strength)
    if strength <= 0:
        return stylized

    gray = qr_img.convert("L")
    if gray.size != stylized.size:
        # The mask is multiplied element-wise with the stylized pixels, so
        # both must share the same dimensions.
        gray = gray.resize(stylized.size, Image.NEAREST)

    # Pixels darker than mid-gray count as QR "black"; everything else is "white".
    black_mask = gray.point(lambda p: 255 if p < 128 else 0)
    black_mask = black_mask.filter(ImageFilter.GaussianBlur(radius=max(0.0, float(feather))))

    black = np.asarray(black_mask, dtype=np.float32)[..., None] / 255.0
    white = 1.0 - black

    pixels = np.asarray(stylized.convert("RGB"), dtype=np.float32) / 255.0
    pixels = pixels * (1.0 - strength * black)  # deepen blacks
    pixels = pixels + (1.0 - pixels) * (strength * 0.85 * white)  # lift whites
    pixels = np.clip(pixels, 0.0, 1.0)

    # An (H, W, 3) uint8 array is read as RGB without the deprecated ``mode`` argument.
    return Image.fromarray((pixels * 255.0).astype(np.uint8))
65
+
66
+ # ---------- lazy pipelines (CPU-offloaded for ZeroGPU) ----------
67
+ _SD = None
68
+ _CN = None
69
+
70
+ def get_sd_pipe():
71
+ global _SD
72
+ if _SD is None:
73
  pipe = StableDiffusionPipeline.from_pretrained(
74
  MODEL_ID,
75
  torch_dtype=DTYPE,
 
80
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(
81
  pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
82
  )
83
+ pipe.enable_attention_slicing(); pipe.enable_vae_slicing(); pipe.enable_model_cpu_offload()
84
+ _SD = pipe
85
+ return _SD
 
 
86
 
87
def get_qrmon_pipe():
    """Return the shared ControlNet (QR Monster) pipeline, building it on first use.

    The pipeline is cached in the module-level ``_CN`` so later calls reuse
    it. It is created without a safety checker, switched to the
    DPM-Solver++ multistep scheduler with Karras sigmas, and configured with
    attention slicing, VAE slicing and model CPU offload.
    """
    global _CN
    if _CN is not None:
        return _CN

    controlnet = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True)
    qr_pipe = StableDiffusionControlNetPipeline.from_pretrained(
        MODEL_ID,
        controlnet=controlnet,
        torch_dtype=DTYPE,
        safety_checker=None,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )
    qr_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        qr_pipe.scheduler.config,
        use_karras_sigmas=True,
        algorithm_type="dpmsolver++",
    )
    qr_pipe.enable_attention_slicing()
    qr_pipe.enable_vae_slicing()
    qr_pipe.enable_model_cpu_offload()

    _CN = qr_pipe
    return _CN
105
 
106
+ # ---------- ZeroGPU tasks ----------
107
  @spaces.GPU(duration=120)
108
+ def txt2img(prompt: str, negative: str, steps: int, cfg: float, width: int, height: int, seed: int):
109
+ pipe = get_sd_pipe()
110
  w, h = snap8(width), snap8(height)
 
111
  if int(seed) < 0:
112
  seed = random.randint(0, 2**31 - 1)
113
  gen = torch.Generator(device="cuda").manual_seed(int(seed))
114
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
 
 
115
  gc.collect()
 
116
  with torch.autocast(device_type="cuda", dtype=DTYPE):
117
  out = pipe(
118
  prompt=str(prompt),
 
124
  )
125
  return out.images[0]
126
 
127
@spaces.GPU(duration=120)
def qr_stylize(url: str, style_prompt: str, negative: str, steps: int, cfg: float,
               size: int, border: int, back_color: str, blur: float,
               qr_weight: float, repair_strength: float, feather: float, seed: int):
    """Generate a stylized, scannable QR image with the ControlNet pipeline.

    Builds a control QR for ``url`` (square, side length ``snap8(size)``),
    runs the QR-Monster ControlNet pipeline with ``style_prompt`` and then
    applies ``enforce_qr_contrast`` as a post-repair step.

    Args:
        url: Text or URL to encode.
        style_prompt: Prompt describing the desired look.
        negative: Negative prompt; ``None``/empty is sent as ``""``.
        steps: Number of inference steps.
        cfg: Guidance scale.
        size: Requested canvas size in pixels (snapped by ``snap8``).
        border: QR quiet-zone width in modules.
        back_color: QR background colour.
        blur: Blur radius applied to the control QR.
        qr_weight: ControlNet conditioning scale.
        repair_strength: Strength of the contrast post-repair.
        feather: Feather radius of the post-repair mask.
        seed: RNG seed; a negative value picks a random seed.

    Returns:
        ``(final_image, control_qr_image)``.
    """
    pipe = get_qrmon_pipe()
    s = snap8(size)
    qr_img = make_qr(url=url, size=s, border=int(border), back_color=back_color, blur_radius=float(blur))

    # A negative seed means "pick one at random".
    if int(seed) < 0:
        seed = random.randint(0, 2**31 - 1)
    gen = torch.Generator(device="cuda").manual_seed(int(seed))

    # Tip: don't stuff "QR code" into the prompt; let ControlNet shape it.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    with torch.autocast(device_type="cuda", dtype=DTYPE):
        out = pipe(
            prompt=str(style_prompt),
            negative_prompt=str(negative or ""),
            # StableDiffusionControlNetPipeline expects the conditioning image
            # as `image`; `control_image` is the keyword of the img2img/inpaint
            # ControlNet pipelines and would be ignored here.
            image=qr_img,
            controlnet_conditioning_scale=float(qr_weight),
            num_inference_steps=int(steps),
            guidance_scale=float(cfg),
            width=s,
            height=s,
            generator=gen,
        )

    img = out.images[0]
    img = enforce_qr_contrast(img, qr_img, strength=float(repair_strength), feather=float(feather))
    return img, qr_img
157
+
158
+ # ---------- UI ----------
159
  with gr.Blocks() as demo:
160
+ gr.Markdown("# ZeroGPU Stable Diffusion + AI QR Codes (Monster v2)")
161
 
162
  with gr.Tab("Text → Image"):
163
  prompt = gr.Textbox(label="Prompt", value="a cozy reading nook, warm sunlight, cinematic lighting, highly detailed")
 
168
  height = gr.Slider(256, 1024, value=640, step=16, label="Height")
169
  seed = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
170
  out_img = gr.Image(label="Image", interactive=False)
171
+ gr.Button("Generate").click(txt2img, [prompt, negative, steps, cfg, width, height, seed], out_img)
172
+
173
+ with gr.Tab("QR Code Stylizer (ControlNet Monster)"):
174
+ url = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
175
+ s_prompt = gr.Textbox(label="Style prompt (no 'QR code' needed)", value="baroque palace interior, intricate roots, dramatic lighting, ultra detailed")
176
+ s_negative= gr.Textbox(label="Negative prompt", value="lowres, low contrast, blurry, jpeg artifacts, worst quality, watermark, text")
177
+ size = gr.Slider(384, 1024, value=768, step=64, label="Canvas (px)")
178
+ steps2 = gr.Slider(10, 50, value=28, step=1, label="Steps")
179
+ cfg2 = gr.Slider(1.0, 12.0, value=6.5, step=0.1, label="CFG")
180
+ border = gr.Slider(4, 20, value=12, step=1, label="QR border (quiet zone)")
181
+ back_col = gr.ColorPicker(value="#808080", label="QR background")
182
+ blur = gr.Slider(0.0, 3.0, value=1.2, step=0.1, label="Soften control (blur)")
183
+ qr_w = gr.Slider(0.6, 1.6, value=1.2, step=0.05, label="QR control weight")
184
+ repair = gr.Slider(0.0, 1.0, value=0.6, step=0.05, label="Post repair strength")
185
+ feather = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
186
+ seed2 = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
187
+ final_img = gr.Image(label="Final stylized QR")
188
+ ctrl_img = gr.Image(label="Control QR used")
189
+ gr.Button("Stylize QR").click(
190
+ qr_stylize,
191
+ [url, s_prompt, s_negative, steps2, cfg2, size, border, back_col, blur, qr_w, repair, feather, seed2],
192
+ [final_img, ctrl_img]
193
+ )
194
 
195
  if __name__ == "__main__":
 
 
196
  demo.queue(max_size=12).launch()
 
requirements.txt CHANGED
@@ -8,5 +8,6 @@ gradio>=4.44.1
8
  pydantic==2.10.6
9
  huggingface_hub==0.29.3
10
  spaces
11
- qrcode[pil]
12
- opencv-python
 
 
8
  pydantic==2.10.6
9
  huggingface_hub==0.29.3
10
  spaces
11
+ qrcode[pil]
12
+ Pillow
13
+ numpy==1.26.4 # ensure pinned (yes, twice is harmless)