# Hugging Face Space — status: Running on Zero (ZeroGPU hardware).
import gc, random | |
import gradio as gr | |
import torch, spaces | |
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler | |
# ---- config ----
# Hub repo for the SD 1.5 weights. The original "runwayml/stable-diffusion-v1-5"
# repo was removed from the Hugging Face Hub; this is the maintained mirror.
MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# dtype the pipeline weights are loaded in (passed as torch_dtype in get_pipe).
DTYPE = torch.float16  # ZeroGPU slice runs fp16 nicely

# Lazily-built pipeline cache: None until get_pipe() builds it, then reused.
_PIPE = None
def get_pipe():
    """Return the shared StableDiffusionPipeline, building it on first use.

    The first call loads ``MODEL_ID`` in ``DTYPE`` (safetensors, no safety
    checker), swaps the scheduler for a DPM-Solver++ multistep scheduler with
    Karras sigmas, and enables attention slicing, VAE slicing and model CPU
    offload. The result is stored in the module-level ``_PIPE`` and every later
    call returns that same object.
    """
    global _PIPE
    if _PIPE is not None:
        return _PIPE

    sd = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        safety_checker=None,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )

    # Replace the default scheduler, reusing its config (fast, stable sampling).
    base_cfg = sd.scheduler.config
    sd.scheduler = DPMSolverMultistepScheduler.from_config(
        base_cfg,
        use_karras_sigmas=True,
        algorithm_type="dpmsolver++",
    )

    # Memory savers, applied in a fixed order: attention slicing, VAE slicing,
    # then model CPU offload.
    for enable in (
        sd.enable_attention_slicing,
        sd.enable_vae_slicing,
        sd.enable_model_cpu_offload,
    ):
        enable()

    _PIPE = sd
    return _PIPE
def snap8(x: int) -> int:
    """Clamp *x* into [256, 1024] and round it down to a multiple of 8.

    diffusers' Stable Diffusion pipeline requires width/height divisible by 8.
    *x* may be anything ``int()`` accepts (e.g. a float from a slider).
    """
    clamped = min(max(int(x), 256), 1024)
    return clamped // 8 * 8
@spaces.GPU
def generate(prompt: str, negative: str, steps: int, cfg: float, width: int, height: int, seed: int):
    """Generate one image from a text prompt with the cached SD 1.5 pipeline.

    ``@spaces.GPU`` is what attaches a GPU to this call on ZeroGPU hardware;
    outside ZeroGPU the decorator has no effect.

    Args:
        prompt: Text prompt.
        negative: Negative prompt; ``None``/empty means no negative prompt.
        steps: Number of denoising steps.
        cfg: Classifier-free guidance scale.
        width: Requested width; clamped to [256, 1024] and floored to a
            multiple of 8 via ``snap8``.
        height: Requested height; same treatment as ``width``.
        seed: RNG seed. A negative value picks a random seed; ``None`` is
            treated the same way (defensive: an emptied Number input may
            arrive as None).

    Returns:
        The first image of the pipeline output (``out.images[0]``).
    """
    # NOTE(review): the first call also builds the pipeline inside the GPU
    # window; with a cold model cache this may eat into ZeroGPU's time budget.
    pipe = get_pipe()
    w, h = snap8(width), snap8(height)

    # Resolve the seed up front so the generator always gets a concrete int.
    if seed is None or int(seed) < 0:
        seed = random.randint(0, 2**31 - 1)
    seed = int(seed)

    # Decide the device once so the generator and the cache housekeeping agree.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen = torch.Generator(device=device).manual_seed(seed)

    # Collect unreachable Python objects first so the tensors they held are
    # freed, then hand the cached CUDA blocks back to the driver.
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    # Weights are already loaded in DTYPE (fp16), so no torch.autocast wrapper:
    # diffusers recommends running fp16 pipelines without autocast.
    out = pipe(
        prompt=str(prompt),
        negative_prompt=str(negative or ""),
        num_inference_steps=int(steps),
        guidance_scale=float(cfg),
        width=w,
        height=h,
        generator=gen,
    )
    return out.images[0]
# -------- UI --------
# Two-column layout: generation controls on the left, result image on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🎨 Stable Diffusion 1.5 — ZeroGPU (public, minimal)")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="a cozy reading nook, warm sunlight, cinematic lighting, highly detailed")
            negative = gr.Textbox(label="Negative (optional)", value="lowres, blurry, watermark, text")
            steps = gr.Slider(8, 40, value=28, step=1, label="Steps")
            cfg = gr.Slider(1.0, 12.0, value=7.0, step=0.5, label="CFG")
            # generate() clamps/snaps these to [256, 1024] in multiples of 8.
            width = gr.Slider(256, 1024, value=640, step=16, label="Width")
            height = gr.Slider(256, 1024, value=640, step=16, label="Height")
            # precision=0 keeps the seed integral; a negative value means "random".
            seed = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            out = gr.Image(label="Result", interactive=False)
    # Inputs are passed positionally, in generate()'s parameter order.
    btn.click(generate, [prompt, negative, steps, cfg, width, height, seed], out)

if __name__ == "__main__":
    # Keep it plain so the Space builds cleanly
    demo.launch()