# Spaces: Running on Zero  (Hugging Face page status banner captured with the source; not code)
import gc
import os
import random
import re
from contextlib import nullcontext
from typing import Tuple

import gradio as gr
import numpy as np
import qrcode
import spaces  # ZeroGPU decorator
import torch
from diffusers import (
    ControlNetModel,
    DPMSolverMultistepScheduler,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionPipeline,
)
from PIL import Image, ImageFilter
from qrcode.constants import ERROR_CORRECT_H
# =========================================================
# Auth (optional for private models)
# =========================================================
# Hugging Face access token, read from the environment (None when unset).
hf_token = os.getenv("HF_TOKEN")
# Extra keyword arguments unpacked into the `from_pretrained` calls in this
# file; an empty dict when no token is configured.
AUTH_KW = {"token": hf_token} if hf_token else {}
# =========================================================
# Helpers (untouched logic)
# =========================================================
def normalize_color(c):
    """Normalize a colour value into a string or an ``(r, g, b)`` int tuple.

    Accepted inputs:
      * ``None``                      -> ``"white"``
      * tuple/list of >= 3 numbers    -> ``(r, g, b)`` rounded and clamped to 0..255
      * ``"#..."`` string             -> returned stripped, unchanged
      * ``"rgb(...)"``/``"rgba(...)"`` -> ``(r, g, b)`` rounded and clamped to 0..255
      * any other string              -> returned stripped (e.g. a colour name)
      * anything else (including sequences shorter than 3) -> ``"white"``
    """

    def _channel(value):
        # Round to the nearest integer, then clamp into the 8-bit range.
        return int(max(0, min(255, round(float(value)))))

    if c is None:
        return "white"

    if isinstance(c, (tuple, list)):
        if len(c) < 3:
            # Not enough channels to form RGB; use the same fallback as
            # every other unsupported input instead of failing on unpack.
            return "white"
        r, g, b = (_channel(x) for x in c[:3])
        return (r, g, b)

    if isinstance(c, str):
        s = c.strip()
        if s.startswith("#"):
            return s
        # Only the first three components are captured; an alpha component
        # after them is ignored.
        m = re.match(r"rgba?\(\s*([0-9.]+)\s*,\s*([0-9.]+)\s*,\s*([0-9.]+)", s, re.IGNORECASE)
        if m:
            return (_channel(m.group(1)), _channel(m.group(2)), _channel(m.group(3)))
        return s

    return "white"
def strengthen_qr_prompts(pos: str, neg: str) -> Tuple[str, str]:
    """Append fixed style / anti-artifact terms to the user's prompts.

    Args:
        pos: Positive prompt; ``None`` or blank is treated as empty.
        neg: Negative prompt; ``None`` or blank is treated as empty.

    Returns:
        ``(positive, negative)`` where each is the cleaned user text followed
        by the fixed terms, joined with ``", "``. Leading/trailing commas and
        whitespace in the user text are dropped so the join never produces
        doubled separators such as ``"castle,, high contrast..."``.
    """
    # Deliberately do NOT mention "QR code" here - let ControlNet impose it.
    add_pos = "high contrast lighting, clean details, cohesive composition"
    add_neg = "frame, border, ornate frame, watermark, text, numbers, checkerboard, mosaic, halftone, repeated pattern, glitch"

    trim_chars = ", \t\r\n\f\v"
    pos = (pos or "").strip(trim_chars)
    neg = (neg or "").strip(trim_chars)

    pos2 = ", ".join(part for part in (pos, add_pos) if part)
    neg2 = ", ".join(part for part in (neg, add_neg) if part)
    return pos2, neg2
def enforce_qr_contrast(stylized: Image.Image, qr_img: Image.Image, strength: float = 0.6, feather: float = 1.0) -> Image.Image:
    """Push a stylized image darker on QR modules and lighter elsewhere.

    Args:
        stylized: Image to adjust; converted to RGB for processing.
        qr_img: QR control image. Pixels below 128 (after grayscale
            conversion) are treated as dark modules.
        strength: Amount of correction. ``<= 0`` returns ``stylized``
            unchanged.
        feather: Gaussian blur radius applied to the dark-module mask to
            soften its edges.

    Returns:
        A new RGB image, or ``stylized`` itself when ``strength <= 0``.
    """
    strength = float(strength)
    if strength <= 0:
        return stylized

    q = qr_img.convert("L")
    if q.size != stylized.size:
        # The masks are combined with the stylized pixels element-wise, so
        # both must share the same width and height.
        q = q.resize(stylized.size, Image.NEAREST)

    # 255 where the QR is dark, 0 elsewhere, then feathered.
    black_mask = q.point(lambda p: 255 if p < 128 else 0).filter(
        ImageFilter.GaussianBlur(radius=float(feather))
    )
    black = np.asarray(black_mask, dtype=np.float32) / 255.0
    white = 1.0 - black

    s = np.asarray(stylized.convert("RGB"), dtype=np.float32) / 255.0
    # Darken under dark modules: scale toward 0 by strength * mask.
    s = s * (1.0 - strength * black[..., None])
    # Brighten under light areas: move toward 1.0 by 0.85 * strength * mask.
    s = s + (1.0 - s) * (strength * 0.85 * white[..., None])
    s = np.clip(s, 0.0, 1.0)

    # An (H, W, 3) uint8 array is interpreted as RGB without an explicit mode.
    return Image.fromarray((s * 255.0).astype(np.uint8))
# =========================================================
# Models & loading (ZeroGPU-friendly lazy load)
# =========================================================
# Hugging Face Hub repo ids for the base model and the two ControlNets.
# NOTE(review): confirm the base repo id below still resolves on the Hub.
BASE_15 = "runwayml/stable-diffusion-v1-5"
QR_MONSTER_15 = "monster-labs/control_v1p_sd15_qrcode_monster"  # v2 subfolder is handled by authors; base path is fine
BRIGHTNESS_15 = "latentcat/control_v1p_sd15_brightness"  # optional helper

# One-slot caches: "pipe" stays None until the matching loader
# (_load_sd_txt2img / _load_cn_img2img) builds the pipeline on first use.
_sd = {"pipe": None}
_cn = {"pipe": None}
def _setup_scheduler(pipe):
    """Swap ``pipe.scheduler`` for a DPM-Solver++ multistep scheduler.

    The new scheduler is built from the pipeline's existing scheduler config
    with Karras sigmas enabled. Mutates ``pipe`` in place; returns ``None``.
    """
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        use_karras_sigmas=True,
        algorithm_type="dpmsolver++",
    )
def _enable_memory_savers(pipe):
    """Turn on the pipeline's memory-saving switches in place.

    Attention slicing, VAE slicing and VAE tiling are always enabled. Model
    CPU offload is only requested when CUDA is available, because it needs
    an accelerator to offload from; this mirrors the cuda-or-cpu fallback
    the generation code in this file uses. Returns ``None``.
    """
    # Good defaults for Spaces/ZeroGPU
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()
    pipe.enable_vae_tiling()
    if torch.cuda.is_available():
        pipe.enable_model_cpu_offload()
def _load_sd_txt2img():
    """Return the cached txt2img pipeline, building it on the first call.

    The pipeline is created from ``BASE_15`` with the safety checker
    disabled, then given the scheduler and memory settings from
    ``_setup_scheduler`` / ``_enable_memory_savers`` and stored in
    ``_sd["pipe"]`` so later calls reuse it.
    """
    if _sd["pipe"] is not None:
        return _sd["pipe"]

    # Half precision only when CUDA is present; fall back to float32 on CPU,
    # matching the cuda-or-cpu fallback used for the generators.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(
        BASE_15,
        torch_dtype=dtype,
        safety_checker=None,
        use_safetensors=True,
        low_cpu_mem_usage=True,
        **AUTH_KW
    )
    _setup_scheduler(pipe)
    _enable_memory_savers(pipe)
    _sd["pipe"] = pipe
    return pipe
def _load_cn_img2img():
    """Return the cached ControlNet img2img pipeline, building it on first call.

    Two ControlNets are loaded and passed in the order
    ``[QR_MONSTER_15, BRIGHTNESS_15]``; per-ControlNet lists given to the
    pipeline (control images, weights, start/end) must follow that order.
    The result is stored in ``_cn["pipe"]`` so later calls reuse it.
    """
    if _cn["pipe"] is not None:
        return _cn["pipe"]

    # Half precision only when CUDA is present; fall back to float32 on CPU,
    # matching the cuda-or-cpu fallback used for the generators.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    qrnet = ControlNetModel.from_pretrained(
        QR_MONSTER_15, torch_dtype=dtype, use_safetensors=True, **AUTH_KW
    )
    bright = ControlNetModel.from_pretrained(
        BRIGHTNESS_15, torch_dtype=dtype, use_safetensors=True, **AUTH_KW
    )
    pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        BASE_15,
        controlnet=[qrnet, bright],
        torch_dtype=dtype,
        safety_checker=None,
        use_safetensors=True,
        low_cpu_mem_usage=True,
        **AUTH_KW
    )
    _setup_scheduler(pipe)
    _enable_memory_savers(pipe)
    _cn["pipe"] = pipe
    return pipe
# =========================================================
# Generation utilities (use inside @spaces.GPU)
# =========================================================
def sd_generate(prompt, negative, steps, guidance, seed, size=512):
    """Generate one square image with the txt2img pipeline.

    Args:
        prompt: Positive prompt.
        negative: Negative prompt; a falsy value is sent as ``""``.
        steps: Number of inference steps (cast to int).
        guidance: Guidance scale (cast to float).
        seed: Integer seed. ``0`` or ``None`` (e.g. a cleared number field)
            selects a random seed.
        size: Width and height in pixels (cast to int).

    Returns:
        The first image of the pipeline output.
    """
    pipe = _load_sd_txt2img()

    # Reproducible generator - on GPU if available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed_value = int(seed or 0)
    if seed_value == 0:
        seed_value = random.randint(0, 2**31 - 1)
    gen = torch.Generator(device=device).manual_seed(seed_value)

    # Release cached memory before the run.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    out = pipe(
        prompt=prompt,
        negative_prompt=negative or "",
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        width=int(size),
        height=int(size),
        generator=gen,
    )
    return out.images[0]
def make_qr(url="http://www.mybirdfire.com", size=512, border=10, back_color="#808080", blur_radius=0.0):
    """Render ``url`` as a square RGB QR image used as the control image.

    Args:
        url: Text to encode; surrounding whitespace is stripped and ``None``
            is treated as an empty string.
        size: Output width and height in pixels (cast to int).
        border: Border width in modules (cast to int).
        back_color: Background colour, passed through ``normalize_color``.
        blur_radius: Gaussian blur radius applied when greater than 0.

    Returns:
        A ``size`` x ``size`` RGB image with black modules.
    """
    qr = qrcode.QRCode(version=None, error_correction=ERROR_CORRECT_H, box_size=10, border=int(border))
    qr.add_data((url or "").strip())
    qr.make(fit=True)

    bg = normalize_color(back_color)
    side = int(size)  # resize needs integer dimensions; sliders may hand over floats
    img = qr.make_image(fill_color="black", back_color=bg).convert("RGB").resize((side, side), Image.NEAREST)

    if blur_radius and blur_radius > 0:
        img = img.filter(ImageFilter.GaussianBlur(radius=float(blur_radius)))
    return img
# Fallback negative prompt; qr_art_two_stage uses it only when its
# strengthened negative prompt is empty.
NEG_DEFAULT = "lowres, low contrast, blurry, jpeg artifacts, worst quality, bad anatomy, extra digits"
# =========================================================
# Main two-stage generator (ZeroGPU-guarded)
# =========================================================
# allocate GPU only while generating
# NOTE(review): duration (seconds) is a per-call budget chosen to cover both
# stages; tune it to the Space's real timings.
@spaces.GPU(duration=120)
def qr_art_two_stage(
    prompt, negative,
    base_steps, base_cfg, base_seed,
    stylize_steps, stylize_cfg, stylize_seed,
    size, url, border, back_color,
    denoise, qr_weight, bright_weight,
    qr_start, qr_end, bright_start, bright_end,
    control_blur, repair_strength, feather_px
):
    """Generate QR art in two stages and return all intermediate images.

    Stage A renders base art with txt2img. Stage B re-renders it with
    img2img under two ControlNets (QR Monster, then Brightness), both fed
    the same QR control image. A final contrast pass nudges pixels toward
    the QR's dark/light layout.

    Args:
        prompt, negative: User prompts; extended by ``strengthen_qr_prompts``.
        base_steps, base_cfg, base_seed: Stage A steps, guidance and seed.
        stylize_steps, stylize_cfg, stylize_seed: Stage B steps, guidance
            and seed. A seed of ``0`` or ``None`` selects a random seed.
        size: Canvas size; rounded down to a multiple of 8, minimum 384.
        url, border, back_color: QR content, border width and background.
        denoise: img2img strength.
        qr_weight, bright_weight: Conditioning scale per ControlNet.
        qr_start, qr_end, bright_start, bright_end: Guidance window per
            ControlNet.
        control_blur: Gaussian blur radius applied to the QR control image.
        repair_strength, feather_px: Strength and mask feather of the final
            contrast pass.

    Returns:
        ``(final_image, base_image, qr_control_image)``.
    """
    size = max(384, int(size) // 8 * 8)

    # Stage A: base art (txt2img)
    p_pos, p_neg = strengthen_qr_prompts(prompt, negative)
    base_img = sd_generate(p_pos, p_neg, base_steps, base_cfg, base_seed, size=size)

    # Stage B: img2img + ControlNet
    qr_img = make_qr(url=url, size=size, border=border, back_color=back_color, blur_radius=control_blur)
    pipe = _load_cn_img2img()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed_value = int(stylize_seed or 0)
    if seed_value == 0:
        seed_value = random.randint(0, 2**31 - 1)
    gen = torch.Generator(device=device).manual_seed(seed_value)

    kwargs = dict(
        prompt=p_pos,
        negative_prompt=p_neg or NEG_DEFAULT,
        image=base_img,                  # init image for img2img
        control_image=[qr_img, qr_img],  # Monster + Brightness
        strength=float(denoise),         # how much we allow change
        num_inference_steps=int(stylize_steps),
        guidance_scale=float(stylize_cfg),
        generator=gen,
        controlnet_conditioning_scale=[float(qr_weight), float(bright_weight)],
        width=size, height=size,         # (diffusers uses init image size; harmless here)
    )
    window_start = [float(qr_start), float(bright_start)]
    window_end = [float(qr_end), float(bright_end)]

    try:
        out = pipe(
            **kwargs,
            control_guidance_start=window_start,
            control_guidance_end=window_end,
        )
    except TypeError:
        # Retry with alternate keyword names when the call above is rejected.
        # NOTE(review): this also retries on a TypeError raised inside the
        # pipeline itself; confirm these fallback names exist in the
        # diffusers version being targeted.
        out = pipe(
            **kwargs,
            controlnet_start=window_start,
            controlnet_end=window_end,
        )

    img = out.images[0]

    # Optional post repair to push blacks/whites where modules demand
    img = enforce_qr_contrast(img, qr_img, strength=float(repair_strength), feather=float(feather_px))
    return img, base_img, qr_img
# =========================================================
# UI (Gradio Space)
# =========================================================
# Single-tab layout: the left column holds every input plus the button, the
# right column shows the three images returned by qr_art_two_stage.
with gr.Blocks() as demo:
    gr.Markdown("## 🧩 QR-Code Monster — Two-Stage (txt2img → img2img + ControlNet) — ZeroGPU")
    with gr.Tab("Two-Stage QR Art"):
        with gr.Row():
            with gr.Column():
                url = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
                prompt = gr.Textbox(
                    label="Style prompt (no 'QR code' here)",
                    value="baroque palace interior with intricate roots, cinematic, dramatic lighting, ultra detailed"
                )
                negative = gr.Textbox(label="Negative", value="")
                size = gr.Slider(512, 1024, value=768, step=64, label="Canvas (px)")

                gr.Markdown("**Stage A — Base art (txt2img)**")
                base_steps = gr.Slider(10, 60, value=26, step=1, label="Base steps")
                base_cfg = gr.Slider(1.0, 12.0, value=6.0, step=0.1, label="Base CFG")
                base_seed = gr.Number(value=0, precision=0, label="Base seed (0=random)")

                gr.Markdown("**Stage B — ControlNet img2img**")
                stylize_steps = gr.Slider(10, 60, value=28, step=1, label="Stylize steps")
                stylize_cfg = gr.Slider(1.0, 12.0, value=6.0, step=0.1, label="Stylize CFG")
                stylize_seed = gr.Number(value=0, precision=0, label="Stylize seed (0=random)")
                denoise = gr.Slider(0.1, 0.8, value=0.48, step=0.01, label="Denoising strength (keep composition lower)")
                qr_weight = gr.Slider(0.5, 1.7, value=1.2, step=0.05, label="QR Monster weight")
                bright_weight = gr.Slider(0.0, 1.0, value=0.20, step=0.05, label="Brightness weight")
                qr_start = gr.Slider(0.0, 1.0, value=0.05, step=0.01, label="QR start")
                qr_end = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="QR end")
                bright_start = gr.Slider(0.0, 1.0, value=0.40, step=0.01, label="Brightness start")
                bright_end = gr.Slider(0.0, 1.0, value=0.85, step=0.01, label="Brightness end")
                border = gr.Slider(4, 20, value=12, step=1, label="QR border (quiet zone)")
                back_color = gr.ColorPicker(value="#808080", label="QR background (mid-gray blends better)")
                control_blur = gr.Slider(0.0, 3.0, value=1.2, step=0.1, label="Soften control (Gaussian blur radius)")
                repair_strength = gr.Slider(0.0, 1.0, value=0.65, step=0.05, label="Post repair strength")
                feather_px = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
                go = gr.Button("Generate QR Art", variant="primary")
            with gr.Column():
                final_img = gr.Image(label="Final stylized QR")
                base_img = gr.Image(label="Base art (Stage A)")
                ctrl_img = gr.Image(label="Control image (QR used)")

        # Input order must match the positional parameters of qr_art_two_stage.
        go.click(
            qr_art_two_stage,
            inputs=[prompt, negative,
                    base_steps, base_cfg, base_seed,
                    stylize_steps, stylize_cfg, stylize_seed,
                    size, url, border, back_color,
                    denoise, qr_weight, bright_weight,
                    qr_start, qr_end, bright_start, bright_end,
                    control_blur, repair_strength, feather_px],
            outputs=[final_img, base_img, ctrl_img]
        )
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()