Tanut committed
Commit 3a24bb3 · 1 Parent(s): 493d820

ControlNet: blend the generated image into the QR code

Files changed (2):
  1. app.py +61 -10
  2. requirements.txt +2 -1
app.py CHANGED
@@ -5,11 +5,11 @@ from PIL import Image
 import qrcode
 from qrcode.constants import ERROR_CORRECT_H
 
-# -------- Stable Diffusion setup --------
+# ========= Stable Diffusion (prompt-only) =========
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 dtype = torch.float16 if device != "cpu" else torch.float32
 
-pipe = StableDiffusionPipeline.from_pretrained(
+sd_pipe = StableDiffusionPipeline.from_pretrained(
     "runwayml/stable-diffusion-v1-5",
     torch_dtype=dtype
 ).to(device)
@@ -17,19 +17,17 @@ pipe = StableDiffusionPipeline.from_pretrained(
 def sd_generate(prompt, steps, guidance, seed):
     gen = torch.Generator(device=device).manual_seed(int(seed)) if int(seed) != 0 else None
     def run():
-        return pipe(prompt, num_inference_steps=int(steps), guidance_scale=float(guidance), generator=gen).images[0]
-    # autocast only where supported
+        return sd_pipe(prompt, num_inference_steps=int(steps), guidance_scale=float(guidance), generator=gen).images[0]
     if device in ("cuda", "mps"):
         with torch.autocast(device):
             return run()
-    else:
-        return run()
+    return run()
 
-# -------- QR code helper --------
+# ========= QR Maker =========
 def make_qr(url: str = "http://www.mybirdfire.com", size: int = 512, border: int = 2) -> Image.Image:
     qr = qrcode.QRCode(
         version=None,
-        error_correction=ERROR_CORRECT_H,  # high EC to survive stylization later
+        error_correction=ERROR_CORRECT_H,
         box_size=10,
         border=border
     )
@@ -38,9 +36,51 @@ def make_qr(url: str = "http://www.mybirdfire.com", size: int = 512, border: int
     img = qr.make_image(fill_color="black", back_color="white").convert("RGB")
     return img.resize((size, size), resample=Image.NEAREST)
 
-# -------- UI --------
+# ========= ControlNet Stylizer (prompt + QR) =========
+# lazy-load to speed initial startup
+_cn_loaded = {"pipe": None}
+def _load_controlnet():
+    if _cn_loaded["pipe"] is None:
+        from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+        from controlnet_aux import CannyDetector
+        controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=dtype)
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            controlnet=controlnet,
+            torch_dtype=dtype,
+            safety_checker=None
+        ).to(device)
+        pipe.enable_attention_slicing()
+        pipe.enable_vae_slicing()
+        _cn_loaded["pipe"] = pipe
+        _cn_loaded["canny"] = CannyDetector()
+    return _cn_loaded["pipe"], _cn_loaded["canny"]
+
+def stylize_qr(prompt: str, qr_image: Image.Image, steps: int, guidance: float, seed: int, canny_low: int, canny_high: int):
+    if qr_image is None:
+        raise gr.Error("Please provide a QR image (use the QR Maker tab first or upload one).")
+    pipe, canny = _load_controlnet()
+    # ensure 512x512 for speed/quality
+    qr_img = qr_image.convert("RGB").resize((512, 512), Image.NEAREST)
+    edges = canny(qr_img, low_threshold=int(canny_low), high_threshold=int(canny_high))
+
+    gen = torch.Generator(device=device).manual_seed(int(seed)) if int(seed) != 0 else None
+    def run():
+        return pipe(
+            prompt=str(prompt),
+            image=edges,
+            num_inference_steps=int(steps),
+            guidance_scale=float(guidance),
+            generator=gen
+        ).images[0]
+    if device in ("cuda", "mps"):
+        with torch.autocast(device):
+            return run()
+    return run()
+
+# ========= UI =========
 with gr.Blocks() as demo:
-    gr.Markdown("## Stable Diffusion + QR (step by step)")
+    gr.Markdown("## Stable Diffusion + QR Code + ControlNet (step by step)")
 
     with gr.Tab("Stable Diffusion (prompt → image)"):
         prompt = gr.Textbox(label="Prompt", value="A fantasy castle at sunset")
@@ -57,5 +97,16 @@ with gr.Blocks() as demo:
         out_qr = gr.Image(label="QR Code", type="pil")
         gr.Button("Generate QR").click(make_qr, [url, size, quiet], out_qr)
 
+    with gr.Tab("QR Stylizer (ControlNet)"):
+        s_prompt = gr.Textbox(label="Style Prompt", value="floral papercut style, high contrast, preserve sharp squares")
+        s_steps = gr.Slider(10, 50, value=30, label="Steps", step=1)
+        s_cfg = gr.Slider(1, 12, value=7.5, label="Guidance Scale", step=0.1)
+        s_seed = gr.Number(value=0, label="Seed (0 = random)", precision=0)
+        canny_l = gr.Slider(0, 255, value=100, step=1, label="Canny low")
+        canny_h = gr.Slider(0, 255, value=200, step=1, label="Canny high")
+        qr_in = gr.Image(label="QR Input (use output from QR Maker or upload)", type="pil")
+        out_styl = gr.Image(label="Stylized QR")
+        gr.Button("Stylize").click(stylize_qr, [s_prompt, qr_in, s_steps, s_cfg, s_seed, canny_l, canny_h], out_styl)
+
 if __name__ == "__main__":
     demo.launch()
requirements.txt CHANGED
@@ -4,4 +4,5 @@ transformers
 accelerate
 safetensors
 gradio
-qrcode[pil]
+qrcode[pil]
+controlnet-aux
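
For reference, a minimal sketch of driving the two new helpers outside the Gradio UI, assuming app.py is importable from the repo root and the Stable Diffusion / ControlNet weights can be downloaded on first use; the prompt, seed, output filename, and Canny thresholds below are illustrative values, not part of the commit.

# Minimal sketch (assumption: run next to app.py; importing app loads the SD 1.5
# pipeline onto the detected device and builds the Blocks UI, but does not launch it).
from app import make_qr, stylize_qr

qr = make_qr("http://www.mybirdfire.com", size=512, border=2)  # plain black-and-white QR
styled = stylize_qr(
    prompt="floral papercut style, high contrast, preserve sharp squares",
    qr_image=qr,
    steps=30,        # illustrative values; same as the UI defaults
    guidance=7.5,
    seed=42,         # 0 means "random" in this app
    canny_low=100,
    canny_high=200,
)
styled.save("stylized_qr.png")
# Heavy stylization can still break scanning even with ERROR_CORRECT_H,
# so verify the saved image with a phone camera.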