# !pip install diffusers
"""Gradio demo: unconditional face generation with google/ddpm-celebahq-256.

One pretrained UNet is shared by three pipelines that differ only in the
scheduler (DDPM, DDIM, PNDM) used to denoise random noise into an image.
"""
import random

import gradio as gr
import numpy as np
import PIL.Image
import torch
from diffusers import DDIMPipeline, DDPMPipeline, PNDMPipeline
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
from diffusers import UNetUnconditionalModel

model_id = "google/ddpm-celebahq-256"

# Load the UNet once; every pipeline below reuses the same weights.
model = UNetUnconditionalModel.from_pretrained(model_id, subfolder="unet")

# load model and scheduler
ddpm_scheduler = DDPMScheduler.from_config(model_id, subfolder="scheduler")
ddpm_pipeline = DDPMPipeline(unet=model, scheduler=ddpm_scheduler)

ddim_scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
ddim_pipeline = DDIMPipeline(unet=model, scheduler=ddim_scheduler)

pndm_scheduler = PNDMScheduler.from_config(model_id, subfolder="scheduler")
pndm_pipeline = PNDMPipeline(unet=model, scheduler=pndm_scheduler)

# Scheduler names accepted by predict() and offered in the UI.
SCHEDULER_NAMES = ("ddim", "ddpm", "pndm")


def _to_pil(image):
    """Convert the first sample of a pipeline output batch to a PIL image.

    Assumes *image* is a tensor laid out as (batch, channels, height, width)
    with values nominally in [-1, 1] (TODO confirm against the pipeline docs).
    """
    # Channels-first -> channels-last, the layout PIL.Image.fromarray expects.
    pixels = image.detach().cpu().permute(0, 2, 3, 1)
    # Map [-1, 1] -> [0, 1] and clamp before scaling: values outside the range
    # would otherwise overflow when cast to uint8 and produce speckled pixels.
    pixels = torch.clamp((pixels + 1.0) / 2, 0.0, 1.0) * 255
    return PIL.Image.fromarray(pixels.numpy().astype(np.uint8)[0])


# run pipeline in inference (sample random noise and denoise)
def predict(steps=100, seed=42, scheduler="ddim"):
    """Generate one image from random noise with the chosen scheduler.

    Args:
        steps: Number of denoising steps for DDIM and PNDM. Ignored for DDPM,
            whose pipeline is called without ``num_inference_steps`` and so
            runs its default schedule.
        seed: Seed for torch's global RNG, so a given seed is reproducible.
        scheduler: One of ``"ddim"``, ``"ddpm"`` or ``"pndm"``.

    Returns:
        The first generated sample as a ``PIL.Image.Image``.

    Raises:
        ValueError: If *scheduler* is not a supported name.
    """
    if scheduler not in SCHEDULER_NAMES:
        raise ValueError(
            f"unknown scheduler {scheduler!r}; expected one of {SCHEDULER_NAMES}"
        )
    torch.cuda.empty_cache()
    # Gradio sliders can hand back floats; step counts must be integers.
    steps = int(steps)
    generator = torch.manual_seed(int(seed))

    if scheduler == "ddim":
        image = ddim_pipeline(generator=generator, num_inference_steps=steps)["sample"]
    elif scheduler == "ddpm":
        image = ddpm_pipeline(generator=generator)["sample"]
    else:  # "pndm" (validated above)
        image = pndm_pipeline(generator=generator, num_inference_steps=steps)["sample"]

    return _to_pil(image)


random_seed = random.randint(0, 2147483647)

gr.Interface(
    predict,
    inputs=[
        gr.inputs.Slider(1, 100, label='Inference Steps', default=20, step=1),
        gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed),
        gr.inputs.Radio(list(SCHEDULER_NAMES), default="ddpm", label="Diffusion scheduler"),
    ],
    outputs=gr.Image(shape=[256, 256], type="pil"),
).launch()