Tanut committed on
Commit
ee61c84
·
1 Parent(s): b593ec0

Testing img2img

Browse files
Files changed (1) hide show
  1. app.py +77 -50
app.py CHANGED
@@ -8,6 +8,7 @@ from qrcode.constants import ERROR_CORRECT_H
8
  from diffusers import (
9
  StableDiffusionPipeline,
10
  StableDiffusionControlNetPipeline,
 
11
  ControlNetModel,
12
  DPMSolverMultistepScheduler,
13
  )
@@ -16,7 +17,7 @@ from diffusers import (
16
  os.environ.setdefault("MPLCONFIGDIR", "/tmp/mpl")
17
 
18
  MODEL_ID = "runwayml/stable-diffusion-v1-5"
19
- CN_QRMON = "monster-labs/control_v1p_sd15_qrcode_monster" # v2 on the repo
20
  DTYPE = torch.float16
21
 
22
  # ---------- helpers ----------
@@ -41,7 +42,7 @@ def normalize_color(c):
41
  return "white"
42
 
43
  def make_qr(url="http://www.mybirdfire.com", size=768, border=12, back_color="#808080", blur_radius=1.2):
44
- # Mid-gray background improves blending & scan rate with QR-Monster v2.
45
  qr = qrcode.QRCode(version=None, error_correction=ERROR_CORRECT_H, box_size=10, border=int(border))
46
  qr.add_data(url.strip()); qr.make(fit=True)
47
  img = qr.make_image(fill_color="black", back_color=normalize_color(back_color)).convert("RGB")
@@ -51,7 +52,7 @@ def make_qr(url="http://www.mybirdfire.com", size=768, border=12, back_color="#8
51
  return img
52
 
53
  def enforce_qr_contrast(stylized: Image.Image, qr_img: Image.Image, strength: float = 0.6, feather: float = 1.0) -> Image.Image:
54
- """Gently push ControlNet-required blacks/whites for scannability (simple post 'repair')."""
55
  if strength <= 0: return stylized
56
  q = qr_img.convert("L")
57
  black_mask = q.point(lambda p: 255 if p < 128 else 0).filter(ImageFilter.GaussianBlur(radius=float(feather)))
@@ -65,7 +66,15 @@ def enforce_qr_contrast(stylized: Image.Image, qr_img: Image.Image, strength: fl
65
 
66
  # ---------- lazy pipelines (CPU-offloaded for ZeroGPU) ----------
67
  _SD = None
68
- _CN = None
 
 
 
 
 
 
 
 
69
 
70
  def get_sd_pipe():
71
  global _SD
@@ -77,16 +86,13 @@ def get_sd_pipe():
77
  use_safetensors=True,
78
  low_cpu_mem_usage=True,
79
  )
80
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(
81
- pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
82
- )
83
- pipe.enable_attention_slicing(); pipe.enable_vae_slicing(); pipe.enable_model_cpu_offload()
84
- _SD = pipe
85
  return _SD
86
 
87
- def get_qrmon_pipe():
88
- global _CN
89
- if _CN is None:
 
90
  cn = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True)
91
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
92
  MODEL_ID,
@@ -96,12 +102,24 @@ def get_qrmon_pipe():
96
  use_safetensors=True,
97
  low_cpu_mem_usage=True,
98
  )
99
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(
100
- pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  )
102
- pipe.enable_attention_slicing(); pipe.enable_vae_slicing(); pipe.enable_model_cpu_offload()
103
- _CN = pipe
104
- return _CN
105
 
106
  # ---------- ZeroGPU tasks ----------
107
  @spaces.GPU(duration=120)
@@ -127,48 +145,55 @@ def txt2img(prompt: str, negative: str, steps: int, cfg: float, width: int, heig
127
  @spaces.GPU(duration=120)
128
  def qr_stylize(url: str, style_prompt: str, negative: str, steps: int, cfg: float,
129
  size: int, border: int, back_color: str, blur: float,
130
- qr_weight: float, repair_strength: float, feather: float, seed: int):
131
- pipe = get_qrmon_pipe()
132
  s = snap8(size)
133
- qr_img = make_qr(url=url, size=s, border=int(border), back_color=back_color, blur_radius=float(blur))
134
 
 
 
135
  if int(seed) < 0:
136
  seed = random.randint(0, 2**31 - 1)
137
  gen = torch.Generator(device="cuda").manual_seed(int(seed))
138
 
139
- # Tip from the article: don't stuff "QR code" into the prompt; let ControlNet shape it.
140
  if torch.cuda.is_available(): torch.cuda.empty_cache()
141
  gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
 
 
 
 
143
  with torch.autocast(device_type="cuda", dtype=DTYPE):
144
- try:
145
- # diffusers ≥ 0.30.x uses `image=` for ControlNet’s conditioning input
146
- out = pipe(
147
- prompt=str(style_prompt),
148
- negative_prompt=str(negative or ""),
149
- image=qr_img, # <-- use `image`
150
- controlnet_conditioning_scale=float(qr_weight),
151
- num_inference_steps=int(steps),
152
- guidance_scale=float(cfg),
153
- width=s, height=s,
154
- generator=gen,
155
- )
156
- except TypeError:
157
- # fallback for older versions that still expect `control_image=`
158
- out = pipe(
159
- prompt=str(style_prompt),
160
- negative_prompt=str(negative or ""),
161
- control_image=qr_img,
162
- controlnet_conditioning_scale=float(qr_weight),
163
- num_inference_steps=int(steps),
164
- guidance_scale=float(cfg),
165
- width=s, height=s,
166
- generator=gen,
167
- )
168
 
169
  img = out.images[0]
170
  img = enforce_qr_contrast(img, qr_img, strength=float(repair_strength), feather=float(feather))
171
- return img, qr_img
172
 
173
  # ---------- UI ----------
174
  with gr.Blocks() as demo:
@@ -185,26 +210,28 @@ with gr.Blocks() as demo:
185
  out_img = gr.Image(label="Image", interactive=False)
186
  gr.Button("Generate").click(txt2img, [prompt, negative, steps, cfg, width, height, seed], out_img)
187
 
188
- with gr.Tab("QR Code Stylizer (ControlNet Monster)"):
189
  url = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
190
  s_prompt = gr.Textbox(label="Style prompt (no 'QR code' needed)", value="baroque palace interior, intricate roots, dramatic lighting, ultra detailed")
191
  s_negative= gr.Textbox(label="Negative prompt", value="lowres, low contrast, blurry, jpeg artifacts, worst quality, watermark, text")
192
  size = gr.Slider(384, 1024, value=768, step=64, label="Canvas (px)")
193
- steps2 = gr.Slider(10, 50, value=28, step=1, label="Steps")
194
  cfg2 = gr.Slider(1.0, 12.0, value=6.5, step=0.1, label="CFG")
195
  border = gr.Slider(4, 20, value=12, step=1, label="QR border (quiet zone)")
196
  back_col = gr.ColorPicker(value="#808080", label="QR background")
197
  blur = gr.Slider(0.0, 3.0, value=1.2, step=0.1, label="Soften control (blur)")
198
  qr_w = gr.Slider(0.6, 1.6, value=1.2, step=0.05, label="QR control weight")
 
199
  repair = gr.Slider(0.0, 1.0, value=0.6, step=0.05, label="Post repair strength")
200
  feather = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
201
  seed2 = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
202
  final_img = gr.Image(label="Final stylized QR")
203
  ctrl_img = gr.Image(label="Control QR used")
 
204
  gr.Button("Stylize QR").click(
205
  qr_stylize,
206
- [url, s_prompt, s_negative, steps2, cfg2, size, border, back_col, blur, qr_w, repair, feather, seed2],
207
- [final_img, ctrl_img]
208
  )
209
 
210
  if __name__ == "__main__":
 
8
  from diffusers import (
9
  StableDiffusionPipeline,
10
  StableDiffusionControlNetPipeline,
11
+ StableDiffusionControlNetImg2ImgPipeline, # NEW: img2img pipeline
12
  ControlNetModel,
13
  DPMSolverMultistepScheduler,
14
  )
 
17
  os.environ.setdefault("MPLCONFIGDIR", "/tmp/mpl")
18
 
19
  MODEL_ID = "runwayml/stable-diffusion-v1-5"
20
+ CN_QRMON = "monster-labs/control_v1p_sd15_qrcode_monster"
21
  DTYPE = torch.float16
22
 
23
  # ---------- helpers ----------
 
42
  return "white"
43
 
44
  def make_qr(url="http://www.mybirdfire.com", size=768, border=12, back_color="#808080", blur_radius=1.2):
45
+ # Mid-gray background improves blending & scan rate with QR-Monster.
46
  qr = qrcode.QRCode(version=None, error_correction=ERROR_CORRECT_H, box_size=10, border=int(border))
47
  qr.add_data(url.strip()); qr.make(fit=True)
48
  img = qr.make_image(fill_color="black", back_color=normalize_color(back_color)).convert("RGB")
 
52
  return img
53
 
54
  def enforce_qr_contrast(stylized: Image.Image, qr_img: Image.Image, strength: float = 0.6, feather: float = 1.0) -> Image.Image:
55
+ """Gently push ControlNet-required blacks/whites for scannability."""
56
  if strength <= 0: return stylized
57
  q = qr_img.convert("L")
58
  black_mask = q.point(lambda p: 255 if p < 128 else 0).filter(ImageFilter.GaussianBlur(radius=float(feather)))
 
66
 
67
  # ---------- lazy pipelines (CPU-offloaded for ZeroGPU) ----------
68
  _SD = None
69
+ _CN_TXT2IMG = None
70
+ _CN_IMG2IMG = None
71
+
72
def _base_scheduler_for(pipe):
    """Attach the shared DPM++ (Karras) scheduler and memory-saving options to *pipe*.

    Applied to every pipeline so all of them sample identically and fit the
    ZeroGPU memory budget (attention/VAE slicing + CPU offload).
    Returns the same pipeline object for chaining.
    """
    scheduler_config = pipe.scheduler.config
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        scheduler_config,
        use_karras_sigmas=True,
        algorithm_type="dpmsolver++",
    )
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()
    pipe.enable_model_cpu_offload()
    return pipe
78
 
79
  def get_sd_pipe():
80
  global _SD
 
86
  use_safetensors=True,
87
  low_cpu_mem_usage=True,
88
  )
89
+ _SD = _base_scheduler_for(pipe)
 
 
 
 
90
  return _SD
91
 
92
+ def get_qrmon_txt2img_pipe():
93
+ """(kept for completeness; not used in the two-stage flow)"""
94
+ global _CN_TXT2IMG
95
+ if _CN_TXT2IMG is None:
96
  cn = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True)
97
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
98
  MODEL_ID,
 
102
  use_safetensors=True,
103
  low_cpu_mem_usage=True,
104
  )
105
+ _CN_TXT2IMG = _base_scheduler_for(pipe)
106
+ return _CN_TXT2IMG
107
+
108
def get_qrmon_img2img_pipe():
    """Lazily build and cache the QR-Monster ControlNet img2img pipeline (stage B)."""
    global _CN_IMG2IMG
    if _CN_IMG2IMG is not None:
        return _CN_IMG2IMG

    controlnet = ControlNetModel.from_pretrained(
        CN_QRMON, torch_dtype=DTYPE, use_safetensors=True
    )
    pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        MODEL_ID,
        controlnet=controlnet,
        torch_dtype=DTYPE,
        safety_checker=None,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )
    # Shared scheduler + memory tweaks, then memoize for later calls.
    _CN_IMG2IMG = _base_scheduler_for(pipe)
    return _CN_IMG2IMG
 
123
 
124
  # ---------- ZeroGPU tasks ----------
125
  @spaces.GPU(duration=120)
 
145
@spaces.GPU(duration=120)
def qr_stylize(url: str, style_prompt: str, negative: str, steps: int, cfg: float,
               size: int, border: int, back_color: str, blur: float,
               qr_weight: float, repair_strength: float, feather: float, seed: int,
               denoise: float = 0.45):
    """Two-stage QR stylization: plain txt2img base art, then ControlNet img2img.

    Stage A renders art from the style prompt alone; Stage B partially re-noises
    that art (`denoise` strength) under the QR-Monster ControlNet so the QR
    pattern emerges while the composition survives. A final contrast "repair"
    pass (enforce_qr_contrast) nudges black/white module zones toward
    scannability.

    Args mirror the Gradio controls; seed < 0 means "pick a random seed".
    Returns (final stylized image, control QR image, stage-A base image).
    """
    s = snap8(size)  # snap canvas to a multiple the pipelines accept

    # --- Stage A: base art (txt2img) ---
    sd = get_sd_pipe()
    if int(seed) < 0:
        seed = random.randint(0, 2**31 - 1)
    gen = torch.Generator(device="cuda").manual_seed(int(seed))

    # Free as much VRAM as possible before each heavy call (ZeroGPU budget).
    if torch.cuda.is_available(): torch.cuda.empty_cache()
    gc.collect()
    with torch.autocast(device_type="cuda", dtype=DTYPE):
        base = sd(
            prompt=str(style_prompt),  # don't include "QR code" here
            negative_prompt=str(negative or ""),
            # Stage A only gets ~half the step budget (never fewer than 12).
            num_inference_steps=max(int(steps)//2, 12),
            guidance_scale=float(cfg),
            width=s, height=s,
            generator=gen,
        ).images[0]

    # Control image (QR)
    qr_img = make_qr(url=url, size=s, border=int(border),
                     back_color=back_color, blur_radius=float(blur))

    # --- Stage B: ControlNet img2img (QR Monster) ---
    # NOTE: `gen` is reused (not re-seeded), so stage B continues the stream.
    pipe = get_qrmon_img2img_pipe()
    if torch.cuda.is_available(): torch.cuda.empty_cache()
    gc.collect()
    with torch.autocast(device_type="cuda", dtype=DTYPE):
        out = pipe(
            prompt=str(style_prompt),
            negative_prompt=str(negative or ""),
            image=base,               # init image (img2img)
            control_image=qr_img,     # control image (QR)
            strength=float(denoise),  # 0.3–0.6 keeps composition
            controlnet_conditioning_scale=float(qr_weight),
            # Leave the very first/last 5% of steps unconstrained by ControlNet.
            control_guidance_start=0.05,
            control_guidance_end=0.95,
            num_inference_steps=int(steps),
            guidance_scale=float(cfg),
            width=s, height=s,
            generator=gen,
        )

    img = out.images[0]
    # Post-pass: push QR blacks/whites back toward scannable contrast.
    img = enforce_qr_contrast(img, qr_img, strength=float(repair_strength), feather=float(feather))
    return img, qr_img, base
197
 
198
  # ---------- UI ----------
199
  with gr.Blocks() as demo:
 
210
  out_img = gr.Image(label="Image", interactive=False)
211
  gr.Button("Generate").click(txt2img, [prompt, negative, steps, cfg, width, height, seed], out_img)
212
 
213
+ with gr.Tab("QR Code Stylizer (ControlNet Monster — two-stage)"):
214
  url = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
215
  s_prompt = gr.Textbox(label="Style prompt (no 'QR code' needed)", value="baroque palace interior, intricate roots, dramatic lighting, ultra detailed")
216
  s_negative= gr.Textbox(label="Negative prompt", value="lowres, low contrast, blurry, jpeg artifacts, worst quality, watermark, text")
217
  size = gr.Slider(384, 1024, value=768, step=64, label="Canvas (px)")
218
+ steps2 = gr.Slider(10, 60, value=28, step=1, label="Total steps")
219
  cfg2 = gr.Slider(1.0, 12.0, value=6.5, step=0.1, label="CFG")
220
  border = gr.Slider(4, 20, value=12, step=1, label="QR border (quiet zone)")
221
  back_col = gr.ColorPicker(value="#808080", label="QR background")
222
  blur = gr.Slider(0.0, 3.0, value=1.2, step=0.1, label="Soften control (blur)")
223
  qr_w = gr.Slider(0.6, 1.6, value=1.2, step=0.05, label="QR control weight")
224
+ denoise = gr.Slider(0.2, 0.8, value=0.45, step=0.01, label="Denoising strength (Stage B)")
225
  repair = gr.Slider(0.0, 1.0, value=0.6, step=0.05, label="Post repair strength")
226
  feather = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
227
  seed2 = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
228
  final_img = gr.Image(label="Final stylized QR")
229
  ctrl_img = gr.Image(label="Control QR used")
230
+ base_img = gr.Image(label="Base art (Stage A)")
231
  gr.Button("Stylize QR").click(
232
  qr_stylize,
233
+ [url, s_prompt, s_negative, steps2, cfg2, size, border, back_col, blur, qr_w, repair, feather, seed2, denoise],
234
+ [final_img, ctrl_img, base_img]
235
  )
236
 
237
  if __name__ == "__main__":