import os, gc, random, re
import gradio as gr
import torch, spaces
from PIL import Image, ImageFilter
import numpy as np
import qrcode
from qrcode.constants import ERROR_CORRECT_H
from diffusers import (
    StableDiffusionControlNetPipeline,
    StableDiffusionControlNetImg2ImgPipeline,  # for Hi-Res Fix
    ControlNetModel,
    DPMSolverMultistepScheduler,
)

# Quiet matplotlib cache warning on Spaces
os.environ.setdefault("MPLCONFIGDIR", "/tmp/mpl")

# ---- base models for the two tabs ----
BASE_MODELS = {
    "stable-diffusion-v1-5": "runwayml/stable-diffusion-v1-5",
    "dream": "Lykon/dreamshaper-8",
}

# ControlNet (QR Monster v2 for SD15)
CN_QRMON = "monster-labs/control_v1p_sd15_qrcode_monster"
DTYPE = torch.float16

# ---------- helpers ----------
def snap8(x: int) -> int:
    x = max(256, min(1024, int(x)))
    return x - (x % 8)
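
# Illustrative examples (not called by the app): snap8 clamps to [256, 1024]
# and rounds down to a multiple of 8 to match the SD latent grid, e.g.
#   snap8(515)  -> 512
#   snap8(100)  -> 256
#   snap8(2000) -> 1024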
def normalize_color(c):
    if c is None:
        return "white"
    if isinstance(c, (tuple, list)):
        r, g, b = (int(max(0, min(255, round(float(x))))) for x in c[:3])
        return (r, g, b)
    if isinstance(c, str):
        s = c.strip()
        if s.startswith("#"):
            return s
        m = re.match(r"rgba?\(\s*([0-9.]+)\s*,\s*([0-9.]+)\s*,\s*([0-9.]+)", s, re.IGNORECASE)
        if m:
            r = int(max(0, min(255, round(float(m.group(1))))))
            g = int(max(0, min(255, round(float(m.group(2))))))
            b = int(max(0, min(255, round(float(m.group(3))))))
            return (r, g, b)
        return s
    return "white"
def make_qr(url="https://example.com", size=768, border=12, back_color="#FFFFFF", blur_radius=0.0):
    """
    IMPORTANT for Method 1: give ControlNet a sharp, black-on-WHITE QR (no blur).
    """
    qr = qrcode.QRCode(version=None, error_correction=ERROR_CORRECT_H, box_size=10, border=int(border))
    qr.add_data(url.strip())
    qr.make(fit=True)
    img = qr.make_image(fill_color="black", back_color=normalize_color(back_color)).convert("RGB")
    img = img.resize((int(size), int(size)), Image.NEAREST)
    if blur_radius and blur_radius > 0:
        img = img.filter(ImageFilter.GaussianBlur(radius=float(blur_radius)))
    return img
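
# Illustrative sketch (not called by the app): a Method-1 control image. The
# size/border values here are assumptions for the sketch; the app passes its own.
def _demo_make_qr():
    ctrl = make_qr("https://example.com", size=512, border=12)
    assert ctrl.size == (512, 512) and ctrl.mode == "RGB"
    return ctrl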
def enforce_qr_contrast(stylized: Image.Image, qr_img: Image.Image, strength: float = 0.0, feather: float = 1.0) -> Image.Image:
    """Optional gentle repair. Default OFF for Method 1."""
    if strength <= 0:
        return stylized
    q = qr_img.convert("L")
    black_mask = q.point(lambda p: 255 if p < 128 else 0).filter(ImageFilter.GaussianBlur(radius=float(feather)))
    black = np.asarray(black_mask, dtype=np.float32) / 255.0
    white = 1.0 - black
    s = np.asarray(stylized.convert("RGB"), dtype=np.float32) / 255.0
    s = s * (1.0 - float(strength) * black[..., None])
    s = s + (1.0 - s) * (float(strength) * 0.85 * white[..., None])
    s = np.clip(s, 0.0, 1.0)
    return Image.fromarray((s * 255.0).astype(np.uint8), mode="RGB")
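
# Illustrative sketch (not called by the app): applying the optional repair.
# strength=0.25 / feather=1.5 are assumed starting values, not tuned numbers;
# higher strength improves scanability at the cost of style.
def _demo_repair(stylized: Image.Image) -> Image.Image:
    ctrl = make_qr("https://example.com", size=stylized.width, border=12)
    return enforce_qr_contrast(stylized, ctrl, strength=0.25, feather=1.5)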
# ---------- lazy pipelines (CPU-offloaded for ZeroGPU) ----------
_CN = None          # shared ControlNet QR Monster
_CN_TXT2IMG = {}    # per-base-model txt2img pipes
_CN_IMG2IMG = {}    # per-base-model img2img pipes

def _base_scheduler_for(pipe):
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
    )
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()
    pipe.enable_model_cpu_offload()
    return pipe
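
# Illustrative sketch (not called by the app): what _base_scheduler_for leaves
# behind — a DPM++ multistep scheduler with Karras sigmas, plus slicing and
# CPU offload to keep peak VRAM low on ZeroGPU's time-sliced GPUs.
def _demo_scheduler_config():
    pipe = get_qrmon_txt2img_pipe(BASE_MODELS["dream"])
    assert isinstance(pipe.scheduler, DPMSolverMultistepScheduler)
    assert pipe.scheduler.config.use_karras_sigmas
    return pipe.scheduler.config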
def get_cn():
    global _CN
    if _CN is None:
        _CN = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True)
    return _CN

def get_qrmon_txt2img_pipe(model_id: str):
    if model_id not in _CN_TXT2IMG:
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_id,
            controlnet=get_cn(),
            torch_dtype=DTYPE,
            safety_checker=None,
            use_safetensors=True,
            low_cpu_mem_usage=True,
        )
        _CN_TXT2IMG[model_id] = _base_scheduler_for(pipe)
    return _CN_TXT2IMG[model_id]

def get_qrmon_img2img_pipe(model_id: str):
    if model_id not in _CN_IMG2IMG:
        pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            model_id,
            controlnet=get_cn(),
            torch_dtype=DTYPE,
            safety_checker=None,
            use_safetensors=True,
            low_cpu_mem_usage=True,
        )
        _CN_IMG2IMG[model_id] = _base_scheduler_for(pipe)
    return _CN_IMG2IMG[model_id]
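
# Illustrative sketch (not called by the app): the getters memoize one pipeline
# per base model, so repeated Gradio clicks reuse the loaded weights instead of
# re-downloading and re-initializing them.
def _demo_pipe_cache():
    mid = BASE_MODELS["dream"]
    assert get_qrmon_txt2img_pipe(mid) is get_qrmon_txt2img_pipe(mid)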
# -------- Method 1: QR control model in text-to-image (+ optional Hi-Res Fix) --------
def _qr_txt2img_core(model_id: str,
                     url: str, style_prompt: str, negative: str,
                     steps: int, cfg: float, size: int, border: int,
                     qr_weight: float, seed: int,
                     use_hires: bool, hires_upscale: float, hires_strength: float,
                     repair_strength: float, feather: float):
    s = snap8(size)

    # Control image: crisp black-on-white QR
    qr_img = make_qr(url=url, size=s, border=int(border), back_color="#FFFFFF", blur_radius=0.0)

    # Seed / generator
    if int(seed) < 0:
        seed = random.randint(0, 2**31 - 1)
    gen = torch.Generator(device="cuda").manual_seed(int(seed))

    # --- Stage A: txt2img with ControlNet
    pipe = get_qrmon_txt2img_pipe(model_id)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    with torch.autocast(device_type="cuda", dtype=DTYPE):
        out = pipe(
            prompt=str(style_prompt),
            negative_prompt=str(negative or ""),
            image=qr_img,                                    # control image for txt2img
            controlnet_conditioning_scale=float(qr_weight),  # ~1.0–1.2 works well
            control_guidance_start=0.0,
            control_guidance_end=1.0,
            num_inference_steps=int(steps),
            guidance_scale=float(cfg),
            width=s, height=s,
            generator=gen,
        )
    lowres = out.images[0]
    lowres = enforce_qr_contrast(lowres, qr_img, strength=float(repair_strength), feather=float(feather))

    # --- Optional Stage B: Hi-Res Fix (img2img with same QR)
    final = lowres
    if use_hires:
        up = max(1.0, min(2.0, float(hires_upscale)))
        W = snap8(int(s * up))
        H = W
        pipe2 = get_qrmon_img2img_pipe(model_id)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        with torch.autocast(device_type="cuda", dtype=DTYPE):
            out2 = pipe2(
                prompt=str(style_prompt),
                negative_prompt=str(negative or ""),
                image=lowres,                    # init image
                control_image=qr_img,            # same QR
                strength=float(hires_strength),  # ~0.7 like "Hires Fix"
                controlnet_conditioning_scale=float(qr_weight),
                control_guidance_start=0.0,
                control_guidance_end=1.0,
                num_inference_steps=int(steps),
                guidance_scale=float(cfg),
                width=W, height=H,
                generator=gen,
            )
        final = out2.images[0]
        final = enforce_qr_contrast(final, qr_img, strength=float(repair_strength), feather=float(feather))

    return final, lowres, qr_img
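
# Illustrative sketch (not called by the app): one full Stage A + Hi-Res run.
# The values mirror the DreamShaper tab defaults; they are starting points,
# not tuned settings, and the call needs a GPU (e.g. inside a @spaces.GPU fn).
def _demo_generate():
    final, lowres, ctrl = _qr_txt2img_core(
        BASE_MODELS["dream"],
        "https://example.com",                        # url
        "ornate baroque palace interior, cinematic",  # style_prompt
        "lowres, blurry, watermark",                  # negative
        24, 6.8, 512, 8,                              # steps, cfg, size, border
        1.5, -1,                                      # qr_weight, seed (-1 = random)
        True, 2.0, 0.7,                               # use_hires, hires_upscale, hires_strength
        0.0, 1.0,                                     # repair_strength, feather
    )
    return final, lowres, ctrl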
# Wrappers for each tab (so Gradio can bind without passing the model id).
# On ZeroGPU, @spaces.GPU marks these as the GPU entry points: each call
# requests a GPU slice, and the CUDA work inside _qr_txt2img_core runs there.
@spaces.GPU
def qr_txt2img_sd15(*args):
    return _qr_txt2img_core(BASE_MODELS["stable-diffusion-v1-5"], *args)

@spaces.GPU
def qr_txt2img_dream(*args):
    return _qr_txt2img_core(BASE_MODELS["dream"], *args)
# ---------- UI ----------
with gr.Blocks() as demo:
    gr.Markdown("# ZeroGPU • Method 1: QR Control (two base models)")

    # ---- Tab 1: stable-diffusion-v1-5 (anime/illustration) ----
    with gr.Tab("stable-diffusion-v1-5"):
        url1        = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
        s_prompt1   = gr.Textbox(label="Style prompt", value="japanese painting, elegant shrine and torii, distant mount fuji, autumn maple trees, warm sunlight, 1girl in kimono, highly detailed, intricate patterns, anime key visual, dramatic composition")
        s_negative1 = gr.Textbox(label="Negative prompt", value="ugly, low quality, blurry, nsfw, watermark, text, low contrast, deformed, extra digits")
        size1       = gr.Slider(384, 1024, value=512, step=64, label="Canvas (px)")
        steps1      = gr.Slider(10, 50, value=20, step=1, label="Steps")
        cfg1        = gr.Slider(1.0, 12.0, value=7.0, step=0.1, label="CFG")
        border1     = gr.Slider(2, 16, value=4, step=1, label="QR border (quiet zone)")
        qr_w1       = gr.Slider(0.6, 1.6, value=1.5, step=0.05, label="QR control weight")
        seed1       = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
        use_hires1  = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
        hires_up1   = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
        hires_str1  = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")
        repair1     = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
        feather1    = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
        final_img1  = gr.Image(label="Final (or Hi-Res) image")
        low_img1    = gr.Image(label="Low-res (Stage A) preview")
        ctrl_img1   = gr.Image(label="Control QR used")
        gr.Button("Generate with stable-diffusion-v1-5").click(
            qr_txt2img_sd15,
            [url1, s_prompt1, s_negative1, steps1, cfg1, size1, border1, qr_w1, seed1,
             use_hires1, hires_up1, hires_str1, repair1, feather1],
            [final_img1, low_img1, ctrl_img1],
        )

    # ---- Tab 2: DreamShaper (general art/painterly) ----
    with gr.Tab("DreamShaper 8"):
        url2        = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
        s_prompt2   = gr.Textbox(label="Style prompt", value="ornate baroque palace interior, gilded details, chandeliers, volumetric light, ultra detailed, cinematic")
        s_negative2 = gr.Textbox(label="Negative prompt", value="lowres, low contrast, blurry, jpeg artifacts, watermark, text, bad anatomy")
        size2       = gr.Slider(384, 1024, value=512, step=64, label="Canvas (px)")
        steps2      = gr.Slider(10, 50, value=24, step=1, label="Steps")
        cfg2        = gr.Slider(1.0, 12.0, value=6.8, step=0.1, label="CFG")
        border2     = gr.Slider(2, 16, value=8, step=1, label="QR border (quiet zone)")
        qr_w2       = gr.Slider(0.6, 1.6, value=1.5, step=0.05, label="QR control weight")
        seed2       = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
        use_hires2  = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
        hires_up2   = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
        hires_str2  = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")
        repair2     = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
        feather2    = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
        final_img2  = gr.Image(label="Final (or Hi-Res) image")
        low_img2    = gr.Image(label="Low-res (Stage A) preview")
        ctrl_img2   = gr.Image(label="Control QR used")
        gr.Button("Generate with DreamShaper 8").click(
            qr_txt2img_dream,
            [url2, s_prompt2, s_negative2, steps2, cfg2, size2, border2, qr_w2, seed2,
             use_hires2, hires_up2, hires_str2, repair2, feather2],
            [final_img2, low_img2, ctrl_img2],
        )

if __name__ == "__main__":
    demo.queue(max_size=12).launch()