Tanut committed
Commit 56a99b7 · 1 Parent(s): c119da0
Files changed (1)
  1. app.py +79 -55
app.py CHANGED
@@ -1,21 +1,46 @@
import gradio as gr
import torch
+ from diffusers import StableDiffusionPipeline
from PIL import Image
import qrcode
from qrcode.constants import ERROR_CORRECT_H

- # --- shared device/dtype ---
+ # ========= Stable Diffusion (prompt-only) =========
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
dtype = torch.float16 if device != "cpu" else torch.float32

- # --- prompt-only SD pipe you already loaded as sd_pipe ---
- # from diffusers import StableDiffusionPipeline
- # sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype).to(device)
+ sd_pipe = StableDiffusionPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5",
+     torch_dtype=dtype
+ ).to(device)

- # --- lazy ControlNet loader (canny) ---
- _cn = {"pipe": None, "canny": None}
- def _load_cn():
-     if _cn["pipe"] is None:
+ def sd_generate(prompt, steps, guidance, seed):
+     gen = torch.Generator(device=device).manual_seed(int(seed)) if int(seed) != 0 else None
+     def run():
+         return sd_pipe(prompt, num_inference_steps=int(steps), guidance_scale=float(guidance), generator=gen).images[0]
+     if device in ("cuda", "mps"):
+         with torch.autocast(device):
+             return run()
+     return run()
+
+ # ========= QR Maker =========
+ def make_qr(url: str = "http://www.mybirdfire.com", size: int = 512, border: int = 2) -> Image.Image:
+     qr = qrcode.QRCode(
+         version=None,
+         error_correction=ERROR_CORRECT_H,
+         box_size=10,
+         border=border
+     )
+     qr.add_data(url.strip())
+     qr.make(fit=True)
+     img = qr.make_image(fill_color="black", back_color="white").convert("RGB")
+     return img.resize((size, size), resample=Image.NEAREST)
+
+ # ========= ControlNet Stylizer (prompt + QR) =========
+ # lazy-load to speed initial startup
+ _cn_loaded = {"pipe": None}
+ def _load_controlnet():
+     if _cn_loaded["pipe"] is None:
        from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
        from controlnet_aux import CannyDetector
        controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=dtype)
@@ -27,60 +52,59 @@ def _load_cn():
        ).to(device)
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()
-         _cn["pipe"], _cn["canny"] = pipe, CannyDetector()
-     return _cn["pipe"], _cn["canny"]
-
- def _make_qr(url: str, size: int, border: int) -> Image.Image:
-     qr = qrcode.QRCode(version=None, error_correction=ERROR_CORRECT_H, box_size=10, border=border)
-     qr.add_data(url.strip()); qr.make(fit=True)
-     img = qr.make_image(fill_color="black", back_color="white").convert("RGB")
-     return img.resize((size, size), resample=Image.NEAREST)
+         _cn_loaded["pipe"] = pipe
+         _cn_loaded["canny"] = CannyDetector()
+     return _cn_loaded["pipe"], _cn_loaded["canny"]

- def stylize_from_url(url: str, size: int, border: int,
-                      prompt: str, steps: int, guidance: float, seed: int,
-                      canny_low: int, canny_high: int):
-     # 1) Make QR
-     qr_img = _make_qr(url, size, border).convert("RGB")
-     qr_512 = qr_img.resize((512, 512), Image.NEAREST)
+ def stylize_qr(prompt: str, steps: int, guidance: float, seed: int, canny_low: int, canny_high: int):
+     # Always make QR from URL
+     qr_image = make_qr("http://www.mybirdfire.com", size=512, border=2)

-     # 2) Canny edges
-     pipe, canny = _load_cn()
-     edges = canny(qr_512, low_threshold=int(canny_low), high_threshold=int(canny_high))
+     pipe, canny = _load_controlnet()
+     edges = canny(qr_image, low_threshold=int(canny_low), high_threshold=int(canny_high))

-     # 3) Generate with ControlNet
    gen = torch.Generator(device=device).manual_seed(int(seed)) if int(seed) != 0 else None
    def run():
-         return pipe(prompt=str(prompt),
-                     image=edges,
-                     num_inference_steps=int(steps),
-                     guidance_scale=float(guidance),
-                     generator=gen).images[0]
+         return pipe(
+             prompt=str(prompt),
+             image=edges,
+             num_inference_steps=int(steps),
+             guidance_scale=float(guidance),
+             generator=gen
+         ).images[0]
    if device in ("cuda", "mps"):
        with torch.autocast(device):
-             out = run()
-     else:
-         out = run()
-     return out, qr_512, edges
+             return run()
+     return run()

- # ====== UI: new tab ======
- with gr.Tab("QR Stylizer (ControlNet, auto‑QR)"):
-     url = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
-     size = gr.Slider(256, 1024, value=512, step=64, label="QR size (px)")
-     border = gr.Slider(0, 8, value=2, step=1, label="QR border (quiet zone)")
-     prompt = gr.Textbox(label="Style Prompt",
-                         value="floral papercut style, high contrast, preserve sharp squares")
-     steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
-     guidance = gr.Slider(1, 12, value=7.5, step=0.1, label="Guidance Scale")
-     seed = gr.Number(value=0, label="Seed (0=random)", precision=0)
-     canny_l = gr.Slider(0, 255, value=100, step=1, label="Canny low")
-     canny_h = gr.Slider(0, 255, value=200, step=1, label="Canny high")
+ # ========= UI =========
+ with gr.Blocks() as demo:
+     gr.Markdown("## Stable Diffusion + QR Code + ControlNet (step by step)")

-     out_img = gr.Image(label="Stylized QR")
-     out_qr = gr.Image(label="Generated QR (input)")
-     out_edge = gr.Image(label="Canny edges (debug)")
+     with gr.Tab("Stable Diffusion (prompt → image)"):
+         prompt = gr.Textbox(label="Prompt", value="A fantasy castle at sunset")
+         steps = gr.Slider(10, 50, value=30, label="Steps", step=1)
+         cfg = gr.Slider(1, 12, value=7.5, label="Guidance Scale", step=0.1)
+         seed = gr.Number(value=0, label="Seed (0 = random)", precision=0)
+         out_sd = gr.Image(label="Generated Image")
+         gr.Button("Generate").click(sd_generate, [prompt, steps, cfg, seed], out_sd)

-     gr.Button("Stylize").click(
-         stylize_from_url,
-         inputs=[url, size, border, prompt, steps, guidance, seed, canny_l, canny_h],
-         outputs=[out_img, out_qr, out_edge]
-     )
+     with gr.Tab("QR Maker (mybirdfire)"):
+         url = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
+         size = gr.Slider(256, 1024, value=512, step=64, label="Size (px)")
+         quiet = gr.Slider(0, 8, value=2, step=1, label="Border (quiet zone)")
+         out_qr = gr.Image(label="QR Code", type="pil")
+         gr.Button("Generate QR").click(make_qr, [url, size, quiet], out_qr)
+
+     with gr.Tab("QR Stylizer (ControlNet)"):
+         s_prompt = gr.Textbox(label="Style Prompt", value="floral papercut style, high contrast, preserve sharp squares")
+         s_steps = gr.Slider(10, 50, value=30, label="Steps", step=1)
+         s_cfg = gr.Slider(1, 12, value=7.5, label="Guidance Scale", step=0.1)
+         s_seed = gr.Number(value=0, label="Seed (0 = random)", precision=0)
+         canny_l = gr.Slider(0, 255, value=100, step=1, label="Canny low")
+         canny_h = gr.Slider(0, 255, value=200, step=1, label="Canny high")
+         out_styl = gr.Image(label="Stylized QR")
+         gr.Button("Stylize").click(stylize_qr, [s_prompt, s_steps, s_cfg, s_seed, canny_l, canny_h], out_styl)
+
+ if __name__ == "__main__":
+     demo.launch()