# Page header captured together with the file (not part of the program):
# Tanut
# Fix bug2
# a2495ec
# raw
# history blame
# 7.69 kB
import os, gc, random
import gradio as gr
import numpy as np
from PIL import Image
import qrcode
from qrcode.constants import ERROR_CORRECT_H
import torch
import spaces # <- ZeroGPU decorator
from diffusers import (
StableDiffusionPipeline,
StableDiffusionControlNetPipeline,
ControlNetModel,
)
from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from controlnet_aux import CannyDetector
# -----------------------------
# Versions / env
# -----------------------------
# Every model in this module is loaded in half precision
# (original note: the Spaces GPU slice supports fp16 well).
TORCH_DTYPE = torch.float16

# Optional (private models): set HF_TOKEN in the Space secrets.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Extra keyword arguments forwarded to each ``from_pretrained`` call;
# stays empty when no token is configured.
AUTH = {}
if HF_TOKEN:
    AUTH["token"] = HF_TOKEN

# -----------------------------
# Global caches (lazy)
# -----------------------------
# Filled on first use by the loader functions below.
_sd_txt = dict(pipe=None)
_sd_cn = dict(pipe=None, canny=None)

# Model repository ids.
BASE_15 = "runwayml/stable-diffusion-v1-5"
CN_CANNY_15 = "lllyasviel/sd-controlnet-canny"
CN_TILE_15 = "lllyasviel/control_v11f1e_sd15_tile"

# Negative prompt used by the stylizer when the caller supplies none.
NEG_DEFAULT = ", ".join(
    (
        "lowres",
        "low contrast",
        "blurry",
        "jpeg artifacts",
        "worst quality",
        "bad anatomy",
        "extra digits",
    )
)
# -----------------------------
# QR maker (unchanged behavior)
# -----------------------------
def make_qr(url: str = "http://www.mybirdfire.com", size: int = 512, border: int = 4) -> Image.Image:
    """Render *url* as a square, black-on-white RGB QR code.

    The code is built with ``ERROR_CORRECT_H`` and ``version=None`` /
    ``fit=True``, leaving the symbol size to the ``qrcode`` library.

    Args:
        url: Text to encode; leading/trailing whitespace is stripped.
        size: Edge length of the returned image in pixels (coerced to int).
        border: Border value handed to ``qrcode.QRCode`` (coerced to int).

    Returns:
        A ``size`` x ``size`` RGB image, scaled with nearest-neighbour
        resampling so module edges stay sharp.
    """
    code = qrcode.QRCode(
        version=None,
        error_correction=ERROR_CORRECT_H,
        box_size=10,
        border=int(border),
    )
    code.add_data(url.strip())
    code.make(fit=True)

    rendered = code.make_image(fill_color="black", back_color="white")
    rgb = rendered.convert("RGB")
    target = (int(size), int(size))
    return rgb.resize(target, resample=Image.NEAREST)
# -----------------------------
# Lazy loaders (Spaces-safe)
# -----------------------------
def _get_sd15_txt2img():
    """Return the shared SD 1.5 text-to-image pipeline, creating it on first use.

    The pipeline is loaded from ``BASE_15`` without a safety checker, has
    attention slicing, VAE slicing and model CPU offload enabled, and is
    stored in ``_sd_txt["pipe"]`` so later calls reuse the same object.
    """
    cached = _sd_txt["pipe"]
    if cached is not None:
        return cached

    load_options = dict(
        torch_dtype=TORCH_DTYPE,
        safety_checker=None,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )
    pipeline = StableDiffusionPipeline.from_pretrained(BASE_15, **load_options, **AUTH)

    # Memory savers — original note: ok to call before the GPU is attached.
    pipeline.enable_attention_slicing()
    pipeline.enable_vae_slicing()
    pipeline.enable_model_cpu_offload()

    _sd_txt["pipe"] = pipeline
    return pipeline
def _get_sd15_canny_tile():
    """Return the shared ControlNet pipeline and Canny detector, creating them on first use.

    The pipeline is SD 1.5 (``BASE_15``) with two ControlNets, in the order
    ``[canny, tile]``, an Euler discrete scheduler, and attention slicing,
    VAE slicing and model CPU offload enabled.

    Returns:
        A ``(pipeline, canny_detector)`` tuple; both objects are cached in
        ``_sd_cn`` and reused by later calls.
    """
    if _sd_cn["pipe"] is None:
        controlnet_options = dict(torch_dtype=TORCH_DTYPE, use_safetensors=True)
        canny_net = ControlNetModel.from_pretrained(CN_CANNY_15, **controlnet_options, **AUTH)
        tile_net = ControlNetModel.from_pretrained(CN_TILE_15, **controlnet_options, **AUTH)

        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            BASE_15,
            controlnet=[canny_net, tile_net],
            torch_dtype=TORCH_DTYPE,
            safety_checker=None,
            use_safetensors=True,
            low_cpu_mem_usage=True,
            **AUTH,
        )
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()
        pipe.enable_model_cpu_offload()

        detector = CannyDetector()

        # Publish both entries together, only after every step succeeded.
        # The guard above looks at "pipe" alone, so caching the pipeline
        # before the detector exists could leave a (pipe, None) pair behind
        # if the last step failed.
        _sd_cn.update(pipe=pipe, canny=detector)
    return _sd_cn["pipe"], _sd_cn["canny"]
# -----------------------------
# SD 1.5 (prompt-only)
# -----------------------------
@spaces.GPU(duration=120)
def sd_generate(prompt, negative, steps, guidance, seed):
    """Generate one image from a text prompt with the SD 1.5 pipeline.

    Args:
        prompt: Text prompt; ``None`` is treated as an empty prompt.
        negative: Negative prompt; a falsy value becomes ``""``.
        steps: Number of inference steps (coerced to int).
        guidance: Guidance scale (coerced to float).
        seed: Generator seed. ``0`` or ``None`` (an emptied number field)
            selects a random seed.

    Returns:
        The first image of the pipeline output.
    """
    pipe = _get_sd15_txt2img()

    # A cleared gr.Number hands over None; treat that like 0 (= random)
    # instead of failing in int(None).
    seed_value = int(seed) if seed is not None else 0
    if seed_value == 0:
        seed_value = random.randint(0, 2**31 - 1)
    # Reproducible generator on CUDA (original note: available during the
    # @spaces.GPU call).
    generator = torch.Generator(device="cuda").manual_seed(seed_value)

    # Collect Python garbage first so the memory it frees can be handed back
    # by empty_cache().
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    with torch.autocast(device_type="cuda", dtype=TORCH_DTYPE):
        out = pipe(
            prompt="" if prompt is None else str(prompt),
            negative_prompt=(negative or ""),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator,
        )
    return out.images[0]
# -----------------------------
# Stylizer (SD1.5 + ControlNet canny + tile)
# -----------------------------
@spaces.GPU(duration=180)
def stylize_qr_sd15(prompt: str, negative: str, steps: int, guidance: float, seed: int,
                    canny_low: int, canny_high: int, border: int):
    """Generate a stylized image of the mybirdfire QR code with SD 1.5 + ControlNet.

    A fresh 512 px QR code for ``http://www.mybirdfire.com`` is rendered,
    its Canny edges are extracted, and both images are passed as control
    images to the ``[canny, tile]`` ControlNet pipeline.

    Args:
        prompt: Style prompt; ``None`` is treated as an empty prompt.
        negative: Negative prompt; a falsy value falls back to ``NEG_DEFAULT``.
        steps: Number of inference steps (coerced to int).
        guidance: Guidance scale (coerced to float).
        seed: Generator seed. ``0`` or ``None`` (an emptied number field)
            selects a random seed.
        canny_low: Low threshold passed to the Canny detector.
        canny_high: High threshold passed to the Canny detector.
        border: QR border passed to ``make_qr``.

    Returns:
        The first image of the pipeline output.
    """
    pipe, canny = _get_sd15_canny_tile()

    # Fresh QR → edges
    qr_img = make_qr("http://www.mybirdfire.com", size=512, border=int(border))
    edges = canny(qr_img, low_threshold=int(canny_low), high_threshold=int(canny_high))

    # Control weights (canny, tile). Tune to taste.
    cn_scales = [1.2, 0.6]

    # A cleared gr.Number hands over None; treat that like 0 (= random)
    # instead of failing in int(None).
    seed_value = int(seed) if seed is not None else 0
    if seed_value == 0:
        seed_value = random.randint(0, 2**31 - 1)
    generator = torch.Generator(device="cuda").manual_seed(seed_value)

    # Collect Python garbage first so the memory it frees can be handed back
    # by empty_cache().
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    with torch.autocast(device_type="cuda", dtype=TORCH_DTYPE):
        out = pipe(
            prompt="" if prompt is None else str(prompt),
            negative_prompt=(negative or NEG_DEFAULT),
            image=[edges, qr_img],  # txt2img ControlNet: control images, in [canny, tile] order
            controlnet_conditioning_scale=cn_scales,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator,
        )
    return out.images[0]
# -----------------------------
# UI (same layout as yours)
# -----------------------------
# Three-tab UI: prompt-only generation, plain QR maker, and QR stylizer.
# analytics_enabled is an option of the gr.Blocks() constructor, so it is set
# here rather than on launch().
with gr.Blocks(analytics_enabled=False) as demo:
    gr.Markdown("## Stable Diffusion + QR Code + ControlNet (SD1.5) — ZeroGPU")

    with gr.Tab("Stable Diffusion (prompt → image)"):
        prompt = gr.Textbox(label="Prompt", value="Sky, Moon, Bird, Blue, In the dark, Goddess, Sweet, Beautiful, Fantasy, Art, Anime")
        negative = gr.Textbox(label="Negative Prompt", value="lowres, bad anatomy, worst quality")
        steps = gr.Slider(10, 50, value=30, label="Steps", step=1)
        cfg = gr.Slider(1, 12, value=7.0, label="Guidance Scale", step=0.1)
        seed = gr.Number(value=0, label="Seed (0 = random)", precision=0)
        out_sd = gr.Image(label="Generated Image")
        gr.Button("Generate").click(sd_generate, [prompt, negative, steps, cfg, seed], out_sd)

    with gr.Tab("QR Maker (mybirdfire)"):
        url = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
        size = gr.Slider(256, 1024, value=512, step=64, label="Size (px)")
        quiet = gr.Slider(0, 8, value=4, step=1, label="Border (quiet zone)")
        out_qr = gr.Image(label="QR Code", type="pil")
        gr.Button("Generate QR").click(make_qr, [url, size, quiet], out_qr)

    with gr.Tab("QR Stylizer (SD1.5 canny + tile, Euler)"):
        s_prompt = gr.Textbox(label="Style Prompt", value="Sky, Moon, Bird, Blue, In the dark, Goddess, Sweet, Beautiful, Fantasy, Art, Anime")
        s_negative = gr.Textbox(label="Negative Prompt", value=NEG_DEFAULT)
        s_steps = gr.Slider(10, 50, value=28, label="Steps", step=1)
        s_cfg = gr.Slider(1, 12, value=7.0, label="CFG", step=0.1)
        s_seed = gr.Number(value=1470713301, label="Seed", precision=0)
        canny_l = gr.Slider(0, 255, value=80, step=1, label="Canny low")
        canny_h = gr.Slider(0, 255, value=160, step=1, label="Canny high")
        s_border = gr.Slider(2, 10, value=6, step=1, label="QR border")
        out_styl = gr.Image(label="Stylized QR")
        gr.Button("Stylize").click(
            stylize_qr_sd15,
            [s_prompt, s_negative, s_steps, s_cfg, s_seed, canny_l, canny_h, s_border],
            out_styl,
        )

if __name__ == "__main__":
    # NOTE(review): Gradio reports share=True as unsupported on Hugging Face
    # Spaces; confirm whether the public share link is still wanted for
    # local runs.
    demo.queue(max_size=12).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_api=True,
    )