# Copied-page header (not code): Tanut · Fix · d6fcceb · raw · history blame · 3.58 kB
import os, gc, random
import gradio as gr
import torch, spaces
from PIL import Image
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
# ---- config ----
# The original "runwayml/stable-diffusion-v1-5" repo was taken down from the Hugging Face Hub;
# "stable-diffusion-v1-5/stable-diffusion-v1-5" is the maintained mirror of the same weights.
# Set the MODEL_ID environment variable to load a different (compatible) SD1.x checkpoint.
MODEL_ID = os.getenv("MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5")
DTYPE = torch.float16  # weights are loaded in half precision to reduce VRAM use
HF_TOKEN = os.getenv("HF_TOKEN")  # optional (only needed for private/gated models)
# Extra kwargs for from_pretrained: only pass a token when one is configured.
AUTH = {"token": HF_TOKEN} if HF_TOKEN else {}

# cache for lazy loading: "sd" holds the pipeline once _get_pipe() has built it
_PIPE = {"sd": None}
def _get_pipe():
    """Return the shared SD1.5 pipeline, building it on the first call.

    The first call downloads/loads the weights, swaps in a DPM-Solver++
    (Karras sigmas) scheduler, and turns on attention slicing, VAE slicing
    and model CPU offload to keep GPU memory low. The pipeline is stored in
    the module-level ``_PIPE`` cache, and later calls return it unchanged.
    """
    cached = _PIPE["sd"]
    if cached is not None:
        return cached

    sd = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        safety_checker=None,
        use_safetensors=True,
        low_cpu_mem_usage=True,
        **AUTH,
    )

    # Replace the default scheduler, keeping its config but using DPM++ with Karras sigmas.
    sd.scheduler = DPMSolverMultistepScheduler.from_config(
        sd.scheduler.config,
        algorithm_type="dpmsolver++",
        use_karras_sigmas=True,
    )

    # Memory savers for the small ZeroGPU slice.
    sd.enable_attention_slicing()
    sd.enable_vae_slicing()
    sd.enable_model_cpu_offload()

    _PIPE["sd"] = sd
    return sd
def _snap_dim(x: int) -> int:
    """Clamp a requested side length to [256, 1024] and round it down to a multiple of 8.

    diffusers expects image dimensions divisible by 8; the upper bound keeps
    requests within a safe VRAM budget on ZeroGPU. ``x`` is truncated with
    ``int()`` first, so floats and numeric strings are accepted.
    """
    clamped = min(max(int(x), 256), 1024)
    return (clamped // 8) * 8
@spaces.GPU(duration=120)  # allocate a ZeroGPU slice only during generation
def generate(prompt: str,
             negative_prompt: str,
             steps: int,
             guidance_scale: float,
             width: int,
             height: int,
             seed: int):
    """Render one image from a text prompt.

    Args:
        prompt: Text describing the desired image (``None`` is treated as empty).
        negative_prompt: Concepts to steer away from; empty/``None`` means none.
        steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength.
        width: Requested width; clamped to [256, 1024] and floored to a multiple of 8.
        height: Requested height; same snapping as ``width``.
        seed: RNG seed. ``-1`` (or any negative value) or an emptied field (``None``)
            picks a random seed.

    Returns:
        ``(image, seed)``: the generated PIL image and the integer seed actually
        used, so the result can be reproduced.
    """
    pipe = _get_pipe()
    w = _snap_dim(width)
    h = _snap_dim(height)

    # An emptied gr.Number arrives as None; treat it like the -1 "random" sentinel.
    if seed is None or int(seed) < 0:
        seed = random.randint(0, 2**31 - 1)
    seed = int(seed)

    # Keep the CUDA generator when available (same noise for a given seed as before);
    # fall back to CPU rather than crashing when no GPU is present.
    cuda_ok = torch.cuda.is_available()
    g = torch.Generator(device="cuda" if cuda_ok else "cpu").manual_seed(seed)

    # Collect unreachable Python objects first so empty_cache() can actually
    # hand their CUDA memory back.
    gc.collect()
    if cuda_ok:
        torch.cuda.empty_cache()

    # No torch.autocast here: the pipeline is already loaded in float16, and diffusers
    # advises against autocast in pipelines (slower, and it can produce black images).
    out = pipe(
        prompt=str(prompt or ""),
        negative_prompt=str(negative_prompt or ""),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance_scale),
        width=w,
        height=h,
        generator=g,
    )
    img: Image.Image = out.images[0]
    return img, seed
# ---------- UI ----------
with gr.Blocks() as demo:
    gr.Markdown("# 🧩 Stable Diffusion 1.5 (ZeroGPU)\nText prompt → image, lean & fast.")

    with gr.Row():
        # Left column: prompt text and sampling controls.
        with gr.Column():
            prompt = gr.Textbox(
                value="a cozy reading nook with warm sunlight, soft textures, cinematic lighting, highly detailed",
                label="Prompt",
            )
            negative = gr.Textbox(
                value="lowres, blurry, watermark, text, logo, nsfw",
                label="Negative prompt",
            )
            steps = gr.Slider(4, 50, value=28, step=1, label="Steps")
            cfg = gr.Slider(1.0, 12.0, value=7.0, step=0.5, label="CFG scale")
            width = gr.Slider(256, 1024, value=640, step=16, label="Width")
            height = gr.Slider(256, 1024, value=640, step=16, label="Height")
            seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
            btn = gr.Button("Generate", variant="primary")

        # Right column: read-only results.
        with gr.Column():
            out_img = gr.Image(label="Result", interactive=False)
            out_seed = gr.Number(label="Used seed", interactive=False)

    # Input order must match generate()'s positional parameters;
    # outputs match its (image, seed) return tuple.
    btn.click(
        fn=generate,
        inputs=[prompt, negative, steps, cfg, width, height, seed],
        outputs=[out_img, out_seed],
    )
if __name__ == "__main__":
    # Enable request queuing with at most 8 pending jobs, then start the server.
    demo.queue(max_size=8)
    demo.launch()