import os, gc, random
import gradio as gr
import torch, spaces
from PIL import Image
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
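# Dependencies implied by the imports and pipeline options below (a sketch of a
# requirements.txt for this Space; versions are not pinned here):
#   gradio, spaces, torch, diffusers, transformers, accelerate, safetensors, Pillow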

# ---- config ----
MODEL_ID = "runwayml/stable-diffusion-v1-5"
DTYPE = torch.float16
HF_TOKEN = os.getenv("HF_TOKEN")  # optional (only needed for private models)
AUTH = {"token": HF_TOKEN} if HF_TOKEN else {}

# cache for lazy loading
_PIPE = {"sd": None}

def _get_pipe():
    """Lazy-load SD1.5 and enable memory savers for ZeroGPU."""
    if _PIPE["sd"] is None:
        pipe = StableDiffusionPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=DTYPE,
            safety_checker=None,
            use_safetensors=True,
            low_cpu_mem_usage=True,
            **AUTH
        )
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
        )
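        # Memory savers: attention/VAE slicing lower peak VRAM at a small speed cost;
        # model CPU offload (needs accelerate) keeps each sub-module on the CPU until
        # the pipeline actually needs it on the GPU.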
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()
        pipe.enable_model_cpu_offload()
        _PIPE["sd"] = pipe
    return _PIPE["sd"]

def _snap_dim(x: int) -> int:
    # diffusers likes multiples of 8; stay within safe VRAM for ZeroGPU
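    # e.g. _snap_dim(300) -> 296, _snap_dim(1030) -> 1024, _snap_dim(100) -> 256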
    x = max(256, min(1024, int(x)))
    return x - (x % 8)

@spaces.GPU(duration=120)  # allocate a ZeroGPU slice only during generation
def generate(prompt: str,
             negative_prompt: str,
             steps: int,
             guidance_scale: float,
             width: int,
             height: int,
             seed: int):
    pipe = _get_pipe()

    w = _snap_dim(width)
    h = _snap_dim(height)

    # seed handling (reproducible on CUDA)
    g = torch.Generator(device="cuda")
    if int(seed) == -1:
        seed = random.randint(0, 2**31 - 1)
    g = g.manual_seed(int(seed))
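    # the used seed is returned to the UI so a result can be reproduced exactly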

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    with torch.autocast(device_type="cuda", dtype=DTYPE):
        out = pipe(
            prompt=str(prompt),
            negative_prompt=str(negative_prompt or ""),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance_scale),
            width=w, height=h,
            generator=g,
        )
    img: Image.Image = out.images[0]
    return img, seed

# ---------- UI ----------
with gr.Blocks() as demo:
    gr.Markdown("# 🧩 Stable Diffusion 1.5 (ZeroGPU)\nText prompt → image, lean & fast.")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                value="a cozy reading nook with warm sunlight, soft textures, cinematic lighting, highly detailed"
            )
            negative = gr.Textbox(
                label="Negative prompt",
                value="lowres, blurry, watermark, text, logo, nsfw"
            )
            steps = gr.Slider(4, 50, value=28, step=1, label="Steps")
            cfg = gr.Slider(1.0, 12.0, value=7.0, step=0.5, label="CFG scale")
            width = gr.Slider(256, 1024, value=640, step=16, label="Width")
            height = gr.Slider(256, 1024, value=640, step=16, label="Height")
            seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")

            btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            out_img = gr.Image(label="Result", interactive=False)
            out_seed = gr.Number(label="Used seed", interactive=False)

    btn.click(generate, [prompt, negative, steps, cfg, width, height, seed], [out_img, out_seed])

if __name__ == "__main__":
    demo.queue(max_size=8).launch()