RAD / app.py
ashishtanwer's picture
Update app.py
b80257e
raw
history blame
2.93 kB
import torch
import os
import gradio as gr
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler
from IPython.display import display
from text_generation import Client, InferenceAPIClient
# Location of the trained Stable Diffusion weights to serve (for example the
# full path of a previously trained model saved on a mounted drive).
#
# WEIGHTS_DIR used to be a bare name that nothing in this file defines, so
# the script failed with NameError before the UI could start. It is now read
# from the environment, and a missing value is reported explicitly.
WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR")
if not WEIGHTS_DIR:
    raise RuntimeError(
        "WEIGHTS_DIR is not set; point it at the trained model weights "
        "before starting the app."
    )
model_path = WEIGHTS_DIR

# Load the pipeline in half precision with the safety checker disabled and
# move it to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    safety_checker=None,
    torch_dtype=torch.float16,
).to("cuda")

# Replace the pipeline's scheduler with DDIM, built from the existing
# scheduler's configuration.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()
# Fixed seed for reproducibility; change it here to get a different
# sequence of random draws.
seed = 52362

# CUDA random generator shared by every pipeline call in this module.
# manual_seed() returns the generator it was called on, so creation and
# seeding can be chained.
g_cuda = torch.Generator(device="cuda").manual_seed(seed)
# Settings for the sample render that runs once when this module is executed.
prompt = "photo of zwx dog in a bucket"
negative_prompt = ""
num_samples = 4
num_inference_steps = 24
guidance_scale = 7.5
height = 512
width = 512

# NOTE(review): this render and the display() loop below look like notebook
# leftovers; they run (and draw from g_cuda) before the UI is built.
# Confirm whether they are still wanted in the app.
with torch.autocast("cuda"):
    with torch.inference_mode():
        images = pipe(
            prompt,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_samples,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=g_cuda,
        ).images

# Hand each generated image to IPython's display().
for img in images:
    display(img)
def inference(prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50, guidance_scale=7.5):
    """Generate images for *prompt* with the module-level pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        num_samples: Number of images to generate; coerced with ``int()``.
        height: Image height; coerced with ``int()``.
        width: Image width; coerced with ``int()``.
        num_inference_steps: Inference step count; coerced with ``int()``.
        guidance_scale: Guidance scale, passed through unchanged.

    Returns:
        The ``images`` attribute of the pipeline's output.
    """
    # The size and count arguments are forced to int before the call
    # (presumably because UI widgets can deliver them as floats).
    call_kwargs = {
        "height": int(height),
        "width": int(width),
        "negative_prompt": negative_prompt,
        "num_images_per_prompt": int(num_samples),
        "num_inference_steps": int(num_inference_steps),
        "guidance_scale": guidance_scale,
        # Every call draws from the shared module-level generator.
        "generator": g_cuda,
    }
    with torch.autocast("cuda"):
        with torch.inference_mode():
            output = pipe(prompt, **call_kwargs)
    return output.images
# Web UI: generation controls in the left column, result gallery on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="photo of zwx dog in a bucket")
            negative_prompt = gr.Textbox(label="Negative Prompt", value="")
            run = gr.Button(value="Generate")
            with gr.Row():
                num_samples = gr.Number(label="Number of Samples", value=4)
                guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
            with gr.Row():
                height = gr.Number(label="Height", value=512)
                width = gr.Number(label="Width", value=512)
            # The slider's default range starts at 0, which let users submit
            # a step count of 0 that the pipeline cannot use. Bound it to
            # whole numbers from 1 to 100 instead.
            num_inference_steps = gr.Slider(
                label="Steps",
                minimum=1,
                maximum=100,
                step=1,
                value=24,
            )
        with gr.Column():
            gallery = gr.Gallery()

    # The order of `inputs` must match the positional parameters of
    # inference().
    run.click(
        inference,
        inputs=[
            prompt,
            negative_prompt,
            num_samples,
            height,
            width,
            num_inference_steps,
            guidance_scale,
        ],
        outputs=gallery,
    )

demo.launch(debug=True)