# Space file header: Tanut · "Test ZeroGPU" · commit 91f6a0e · 2.78 kB
import gc, random, os
import gradio as gr
import torch, spaces
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
# Hugging Face Hub repo id of the Stable Diffusion 1.5 weights. The former
# "runwayml/stable-diffusion-v1-5" repo was taken down from the Hub; the
# "stable-diffusion-v1-5" organisation hosts a mirror of the same checkpoint.
# Set the SD_MODEL_ID environment variable to load a different repo/path.
MODEL_ID = os.getenv("SD_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5")
# Dtype the pipeline weights are loaded in (half precision).
DTYPE = torch.float16
# Cached pipeline instance; stays None until get_pipe() builds it.
_PIPE = None
def get_pipe():
    """Return the shared Stable Diffusion pipeline, constructing it on first call.

    The first call loads ``MODEL_ID`` in ``DTYPE`` with the safety checker
    disabled. It then swaps in a DPM-Solver++ multistep scheduler (Karras
    sigmas) and enables attention slicing, VAE slicing and model CPU offload.
    The pipeline is stored in the module-level ``_PIPE``, and later calls return
    that same object.
    """
    global _PIPE
    if _PIPE is not None:
        return _PIPE

    # Weights are loaded on the CPU; placement is handled by CPU offload below.
    loaded = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        use_safetensors=True,
        low_cpu_mem_usage=True,
        safety_checker=None,
    )

    # Replace the default scheduler, starting from its existing config.
    base_config = loaded.scheduler.config
    loaded.scheduler = DPMSolverMultistepScheduler.from_config(
        base_config,
        algorithm_type="dpmsolver++",
        use_karras_sigmas=True,
    )

    # Memory-saving switches, applied in this order.
    for enable in (
        loaded.enable_attention_slicing,
        loaded.enable_vae_slicing,
        loaded.enable_model_cpu_offload,
    ):
        enable()

    _PIPE = loaded
    return _PIPE
def snap8(x: int) -> int:
    """Clamp ``x`` to [256, 1024] and round it down to a multiple of 8.

    ``x`` is passed through ``int()`` first, so floats are truncated and numeric
    strings are accepted. Used for the image width/height; presumably the
    multiple-of-8 rounding is there because the SD pipeline rejects other sizes.
    """
    lower, upper = 256, 1024
    clamped = min(max(int(x), lower), upper)
    return clamped // 8 * 8
def _resolve_seed(seed) -> int:
    """Return ``seed`` as an int, drawing a random one when it is missing or negative.

    ``None`` is treated like ``-1``. This is a defensive guard: presumably a
    cleared ``gr.Number`` can deliver ``None``, which used to raise ``TypeError``.
    """
    value = -1 if seed is None else int(seed)
    if value < 0:
        return random.randint(0, 2**31 - 1)
    return value


@spaces.GPU(duration=120)
def generate(prompt: str, negative: str, steps: int, cfg: float, width: int, height: int, seed: int):
    """Generate one image from a text prompt with the shared SD pipeline.

    Runs under a ZeroGPU allocation requested via ``spaces.GPU(duration=120)``.

    Args:
        prompt: Positive text prompt (``None`` is treated as empty).
        negative: Optional negative prompt (``None``/empty means none).
        steps: Number of inference steps.
        cfg: Guidance scale.
        width: Requested width; clamped to [256, 1024] and floored to a multiple of 8.
        height: Requested height; same treatment as ``width``.
        seed: RNG seed; a negative or missing value picks a random seed.

    Returns:
        The first image of the pipeline output (``out.images[0]``).
    """
    pipe = get_pipe()  # builds and caches the pipeline on the first call only
    w, h = snap8(width), snap8(height)
    resolved_seed = _resolve_seed(seed)

    # Choose the device once, instead of hard-coding "cuda" for the generator
    # and only checking availability afterwards.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen = torch.Generator(device=device).manual_seed(resolved_seed)

    # Collect unreachable Python objects first so their tensors are freed,
    # then return cached CUDA blocks to the driver.
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    # No torch.autocast here: the weights are already loaded in DTYPE (fp16),
    # and diffusers advises against wrapping pipelines in autocast (slower, and
    # can yield black images).
    out = pipe(
        prompt=str(prompt or ""),
        negative_prompt=str(negative or ""),
        num_inference_steps=int(steps),
        guidance_scale=float(cfg),
        width=w,
        height=h,
        generator=gen,
    )
    return out.images[0]
# Gradio UI: one "Text → Image" tab whose controls feed generate(); the
# resulting image is shown in out_img.
with gr.Blocks() as demo:
    gr.Markdown("# ZeroGPU + SD1.5 (minimal)")
    with gr.Tab("Text → Image"):
        prompt = gr.Textbox(
            label="Prompt",
            value="a cozy reading nook, warm sunlight, cinematic lighting, highly detailed",
        )
        negative = gr.Textbox(
            label="Negative (optional)",
            value="lowres, blurry, watermark, text",
        )
        steps = gr.Slider(minimum=8, maximum=40, value=28, step=1, label="Steps")
        cfg = gr.Slider(minimum=1.0, maximum=12.0, value=7.0, step=0.5, label="CFG")
        width = gr.Slider(minimum=256, maximum=1024, value=640, step=16, label="Width")
        height = gr.Slider(minimum=256, maximum=1024, value=640, step=16, label="Height")
        seed = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
        out_img = gr.Image(label="Image", interactive=False)

        # Input order must match generate()'s positional parameters.
        run_button = gr.Button("Generate")
        run_button.click(
            fn=generate,
            inputs=[prompt, negative, steps, cfg, width, height, seed],
            outputs=out_img,
        )
if __name__ == "__main__":
    # queue(max_size=12) limits how many requests may wait in the queue at once.
    # share=True is kept on purpose: it is the workaround for Gradio's
    # "localhost is not accessible" launch error.
    # NOTE(review): Gradio is believed to ignore share=True on Hugging Face
    # Spaces (it only warns); confirm whether the flag is still needed there.
    demo.queue(max_size=12).launch(share=True)