# Hugging Face Space — status: Running on Zero (ZeroGPU hardware)
import os, gc, random, re | |
import gradio as gr | |
import torch, spaces | |
from PIL import Image, ImageFilter | |
import numpy as np | |
import qrcode | |
from qrcode.constants import ERROR_CORRECT_H | |
from diffusers import ( | |
StableDiffusionPipeline, | |
StableDiffusionControlNetPipeline, | |
ControlNetModel, | |
DPMSolverMultistepScheduler, | |
) | |
# Optional: silence matplotlib's cache-directory warning on Spaces by pointing its
# config/cache dir at a writable location.
# NOTE(review): matplotlib resolves this directory when it first needs it, so the
# setting only helps if nothing has done that earlier in the import sequence.
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mpl")

# Base Stable Diffusion 1.5 checkpoint. The original "runwayml/stable-diffusion-v1-5"
# repository was removed from the Hugging Face Hub; this is the maintained mirror.
MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# QR Code Monster ControlNet repository (v2 weights are published in its "v2"
# subfolder according to the model card — verify if the repo layout changes).
CN_QRMON = "monster-labs/control_v1p_sd15_qrcode_monster"
# Half precision for every pipeline's weights.
DTYPE = torch.float16
# ---------- helpers ----------
def snap8(x: int) -> int:
    """Clamp ``x`` into 256..1024 and round it down to the nearest multiple of 8.

    Stable Diffusion pipelines require width/height divisible by 8, and this
    range keeps requests within sizes the UI offers.
    """
    value = int(x)
    if value < 256:
        value = 256
    elif value > 1024:
        value = 1024
    return (value // 8) * 8
# Matches the leading "rgb(r, g, b" / "rgba(r, g, b" part of a CSS-style colour string.
_RGB_FUNC_RE = re.compile(r"rgba?\(\s*([0-9.]+)\s*,\s*([0-9.]+)\s*,\s*([0-9.]+)", re.IGNORECASE)


def normalize_color(c):
    """Convert a colour value (e.g. from ``gr.ColorPicker``) into something PIL accepts.

    Args:
        c: ``None``, an ``(r, g, b[, ...])`` tuple/list of numbers, or a string such as
            ``"#808080"``, ``"rgba(128, 128, 128, 1)"`` or a colour name.

    Returns:
        ``"white"`` for ``None``, empty/blank strings and unsupported types; an
        ``(r, g, b)`` int tuple (each channel rounded and clamped to 0..255) for
        sequences and ``rgb()/rgba()`` strings; otherwise the stripped string unchanged
        (hex strings and colour names are passed through for PIL to parse).
    """
    def _channel(v) -> int:
        # Round, then clamp into the valid 8-bit channel range.
        return int(max(0, min(255, round(float(v)))))

    if c is None:
        return "white"
    if isinstance(c, (tuple, list)):
        r, g, b = (_channel(v) for v in c[:3])
        return (r, g, b)
    if isinstance(c, str):
        s = c.strip()
        if not s:
            # An empty string is not a colour PIL can parse; use the default.
            return "white"
        if s.startswith("#"):
            return s
        m = _RGB_FUNC_RE.match(s)
        if m:
            return tuple(_channel(m.group(i)) for i in (1, 2, 3))
        return s
    return "white"
def make_qr(url="http://www.mybirdfire.com", size=768, border=12, back_color="#808080", blur_radius=1.2):
    """Render ``url`` as a square QR image to use as ControlNet conditioning.

    Args:
        url: Text or URL to encode. ``None`` is treated as an empty string, and
            surrounding whitespace is stripped.
        size: Output width and height in pixels.
        border: Quiet-zone width, in QR modules.
        back_color: Background colour in any form accepted by ``normalize_color``.
            The mid-gray default is intended to blend better with QR-Monster v2.
        blur_radius: Gaussian blur radius in pixels; ``0``/``None`` disables blurring.

    Returns:
        An RGB ``PIL.Image`` of ``size`` x ``size`` pixels with black modules.
    """
    # Highest error-correction level, leaving the most redundancy for the
    # distortion that stylization introduces.
    qr = qrcode.QRCode(version=None, error_correction=ERROR_CORRECT_H, box_size=10, border=int(border))
    qr.add_data((url or "").strip())
    qr.make(fit=True)  # version=None + fit=True: pick the smallest version that fits the data
    img = qr.make_image(fill_color="black", back_color=normalize_color(back_color)).convert("RGB")
    side = int(size)
    # NEAREST keeps module edges hard; any softening is done explicitly below.
    img = img.resize((side, side), Image.NEAREST)
    if blur_radius and blur_radius > 0:
        img = img.filter(ImageFilter.GaussianBlur(radius=float(blur_radius)))
    return img
def enforce_qr_contrast(stylized: Image.Image, qr_img: Image.Image, strength: float = 0.6, feather: float = 1.0) -> Image.Image:
    """Push the stylized image toward the control QR's dark/light pattern to aid scanning.

    Pixels that are dark in ``qr_img`` are darkened in ``stylized``; the remaining
    pixels are lifted toward white. This is a simple post-hoc "repair"; it does not
    verify that the result actually decodes.

    Args:
        stylized: Generated image to repair.
        qr_img: Control QR image. It is resized (nearest-neighbour) to
            ``stylized``'s size if the sizes differ.
        strength: Blend strength. ``<= 0`` returns ``stylized`` unchanged. Darkening
            uses ``strength``; lightening uses ``0.85 * strength``.
        feather: Gaussian blur radius (px) for the dark-module mask. Negative values
            are treated as 0.

    Returns:
        A new RGB image the size of ``stylized``, or ``stylized`` itself when
        ``strength <= 0``.
    """
    strength = float(strength)
    if strength <= 0:
        return stylized
    base = stylized.convert("RGB")
    q = qr_img.convert("L")
    if q.size != base.size:
        q = q.resize(base.size, Image.NEAREST)
    # Threshold halfway between the control image's darkest and lightest values.
    # A fixed 128 would mark any background darker than mid-gray as "dark modules".
    lo, hi = q.getextrema()
    threshold = (lo + hi) / 2.0
    black_mask = q.point(lambda p: 255 if p < threshold else 0)
    radius = max(0.0, float(feather))
    if radius > 0:
        black_mask = black_mask.filter(ImageFilter.GaussianBlur(radius=radius))
    black = np.asarray(black_mask, dtype=np.float32)[..., None] / 255.0  # (H, W, 1) for broadcasting
    white = 1.0 - black
    s = np.asarray(base, dtype=np.float32) / 255.0
    s = s * (1.0 - strength * black)               # deepen dark modules
    s = s + (1.0 - s) * (strength * 0.85 * white)  # lift light areas toward white
    s = np.clip(s, 0.0, 1.0)
    # Round rather than truncate, so float error does not bias every channel down by one level.
    return Image.fromarray(np.rint(s * 255.0).astype(np.uint8))
# ---------- lazy pipelines (CPU-offloaded for ZeroGPU) ---------- | |
_SD = None | |
_CN = None | |
def get_sd_pipe():
    """Return the shared text-to-image pipeline, building it on the first request.

    The pipeline is loaded from ``MODEL_ID`` in ``DTYPE`` with no safety checker,
    switched to a Karras-sigma DPM-Solver++ scheduler, and configured with
    attention/VAE slicing plus model CPU offload. The instance is cached in ``_SD``.
    """
    global _SD
    if _SD is not None:
        return _SD

    sd_pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        safety_checker=None,
        torch_dtype=DTYPE,
    )
    # Replace the checkpoint's scheduler with DPM-Solver++ (Karras sigmas), keeping its config.
    sched_cfg = sd_pipe.scheduler.config
    sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        sched_cfg, algorithm_type="dpmsolver++", use_karras_sigmas=True
    )
    # Memory savers: sliced attention and VAE decode; each model is put on the GPU only while it runs.
    sd_pipe.enable_attention_slicing()
    sd_pipe.enable_vae_slicing()
    sd_pipe.enable_model_cpu_offload()

    _SD = sd_pipe
    return _SD
def get_qrmon_pipe():
    """Return the shared SD 1.5 + QR-Code-Monster ControlNet pipeline, building it on first call.

    The ControlNet comes from the ``v2`` subfolder of ``CN_QRMON`` and the base model
    from ``MODEL_ID``, both in ``DTYPE``. The pipeline uses no safety checker, a
    Karras-sigma DPM-Solver++ scheduler, attention/VAE slicing and model CPU offload.
    The instance is cached in the module global ``_CN``.
    """
    global _CN
    if _CN is None:
        # The app advertises "Monster v2". The repo root holds the original (v1) weights,
        # while v2 is published in the "v2" subfolder (per the model card — verify).
        cn = ControlNetModel.from_pretrained(CN_QRMON, subfolder="v2", torch_dtype=DTYPE, use_safetensors=True)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            MODEL_ID,
            controlnet=cn,
            torch_dtype=DTYPE,
            safety_checker=None,
            use_safetensors=True,
            low_cpu_mem_usage=True,
        )
        # Same sampler setup as the text-to-image pipeline.
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
        )
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()
        pipe.enable_model_cpu_offload()
        _CN = pipe
    return _CN
# ---------- ZeroGPU tasks ----------
# On ZeroGPU hardware a GPU is attached only while a @spaces.GPU-decorated function
# runs; without the decorator CUDA is unavailable inside these handlers. The extra
# duration leaves headroom for the first call, which also builds the pipeline here.
@spaces.GPU(duration=120)
def txt2img(prompt: str, negative: str, steps: int, cfg: float, width: int, height: int, seed: int):
    """Generate one image from ``prompt`` with the base SD 1.5 pipeline.

    Args:
        prompt: Positive prompt.
        negative: Negative prompt; ``None``/empty means none.
        steps: Number of inference steps.
        cfg: Classifier-free guidance scale.
        width, height: Requested size; clamped and snapped to multiples of 8 by ``snap8``.
        seed: RNG seed. Negative values, or ``None`` (a cleared number box), pick a random seed.

    Returns:
        The generated ``PIL.Image``.
    """
    pipe = get_sd_pipe()
    w, h = snap8(width), snap8(height)
    seed = -1 if seed is None else int(seed)
    if seed < 0:
        seed = random.randint(0, 2**31 - 1)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen = torch.Generator(device=device).manual_seed(seed)
    # Free cached allocations before a fresh generation.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    with torch.autocast(device_type="cuda", dtype=DTYPE):
        out = pipe(
            prompt=str(prompt),
            negative_prompt=str(negative or ""),
            num_inference_steps=int(steps),
            guidance_scale=float(cfg),
            width=w, height=h,
            generator=gen,
        )
    return out.images[0]
@spaces.GPU(duration=120)  # ZeroGPU attaches a GPU only inside decorated functions
def qr_stylize(url: str, style_prompt: str, negative: str, steps: int, cfg: float,
               size: int, border: int, back_color: str, blur: float,
               qr_weight: float, repair_strength: float, feather: float, seed: int):
    """Generate a stylized QR image with the QR-Code-Monster ControlNet.

    Args:
        url: Text/URL to encode into the control QR.
        style_prompt: Style prompt (the ControlNet supplies the QR structure).
        negative: Negative prompt; ``None``/empty means none.
        steps: Number of inference steps.
        cfg: Classifier-free guidance scale.
        size: Canvas size; clamped and snapped to a multiple of 8 by ``snap8``.
        border: QR quiet zone, in modules.
        back_color: Control-QR background colour.
        blur: Blur radius (px) applied to the control QR.
        qr_weight: ``controlnet_conditioning_scale``.
        repair_strength: Strength passed to ``enforce_qr_contrast``.
        feather: Mask feather radius passed to ``enforce_qr_contrast``.
        seed: RNG seed. Negative values, or ``None`` (a cleared number box), pick a random seed.

    Returns:
        ``(stylized_image, control_qr_image)``, both square PIL images of the snapped size.
    """
    pipe = get_qrmon_pipe()
    s = snap8(size)
    qr_img = make_qr(url=url, size=s, border=int(border), back_color=back_color, blur_radius=float(blur))
    seed = -1 if seed is None else int(seed)
    if seed < 0:
        seed = random.randint(0, 2**31 - 1)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen = torch.Generator(device=device).manual_seed(seed)
    # Free cached allocations before a fresh generation.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    # Keep "QR code" out of the style prompt; let the ControlNet shape the pattern.
    with torch.autocast(device_type="cuda", dtype=DTYPE):
        out = pipe(
            prompt=str(style_prompt),
            negative_prompt=str(negative or ""),
            # StableDiffusionControlNetPipeline takes its conditioning image as `image`;
            # `control_image` is not a parameter of this pipeline.
            image=qr_img,
            controlnet_conditioning_scale=float(qr_weight),
            num_inference_steps=int(steps),
            guidance_scale=float(cfg),
            width=s, height=s,
            generator=gen,
        )
    img = out.images[0]
    img = enforce_qr_contrast(img, qr_img, strength=float(repair_strength), feather=float(feather))
    return img, qr_img
# ---------- UI ----------
# Two tabs: plain text-to-image, and the ControlNet QR-code stylizer.
demo = gr.Blocks()
with demo:
    gr.Markdown("# ZeroGPU Stable Diffusion + AI QR Codes (Monster v2)")

    with gr.Tab("Text → Image"):
        prompt = gr.Textbox(label="Prompt", value="a cozy reading nook, warm sunlight, cinematic lighting, highly detailed")
        negative = gr.Textbox(label="Negative (optional)", value="lowres, blurry, watermark, text")
        steps = gr.Slider(8, 40, value=28, step=1, label="Steps")
        cfg = gr.Slider(1.0, 12.0, value=7.0, step=0.5, label="CFG")
        width = gr.Slider(256, 1024, value=640, step=16, label="Width")
        height = gr.Slider(256, 1024, value=640, step=16, label="Height")
        seed = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
        out_img = gr.Image(label="Image", interactive=False)
        txt2img_inputs = [prompt, negative, steps, cfg, width, height, seed]
        generate_btn = gr.Button("Generate")
        generate_btn.click(fn=txt2img, inputs=txt2img_inputs, outputs=out_img)

    with gr.Tab("QR Code Stylizer (ControlNet Monster)"):
        url = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
        s_prompt = gr.Textbox(label="Style prompt (no 'QR code' needed)", value="baroque palace interior, intricate roots, dramatic lighting, ultra detailed")
        s_negative = gr.Textbox(label="Negative prompt", value="lowres, low contrast, blurry, jpeg artifacts, worst quality, watermark, text")
        size = gr.Slider(384, 1024, value=768, step=64, label="Canvas (px)")
        steps2 = gr.Slider(10, 50, value=28, step=1, label="Steps")
        cfg2 = gr.Slider(1.0, 12.0, value=6.5, step=0.1, label="CFG")
        border = gr.Slider(4, 20, value=12, step=1, label="QR border (quiet zone)")
        back_col = gr.ColorPicker(value="#808080", label="QR background")
        blur = gr.Slider(0.0, 3.0, value=1.2, step=0.1, label="Soften control (blur)")
        qr_w = gr.Slider(0.6, 1.6, value=1.2, step=0.05, label="QR control weight")
        repair = gr.Slider(0.0, 1.0, value=0.6, step=0.05, label="Post repair strength")
        feather = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
        seed2 = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
        final_img = gr.Image(label="Final stylized QR")
        ctrl_img = gr.Image(label="Control QR used")
        # Order must match qr_stylize's positional parameters.
        qr_inputs = [url, s_prompt, s_negative, steps2, cfg2, size, border, back_col, blur, qr_w, repair, feather, seed2]
        stylize_btn = gr.Button("Stylize QR")
        stylize_btn.click(fn=qr_stylize, inputs=qr_inputs, outputs=[final_img, ctrl_img])

if __name__ == "__main__":
    # Bounded request queue, then serve.
    queued_demo = demo.queue(max_size=12)
    queued_demo.launch()