# Wuerstchen / app.py
# svjack: "Create app.py" (commit 7591849, 3.81 kB)
import os
import sys
import gradio as gr
import numpy as np
import torch
import random
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen.pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Load the text-to-image pipeline once at import time so every request
# handled by the UI reuses the same weights.
pipe = AutoPipelineForText2Image.from_pretrained(
    "warp-ai/wuerstchen",
    torch_dtype=torch.float32,
)
pipe.to(device)
pipe.safety_checker = None

# The bare string below is a disabled reference call kept as a note (it is
# never executed); it records the settings and timing the author measured.
'''
#### 9min a sample (2 cores)
caption = "Anthropomorphic cat dressed as a fire fighter"
images = pipe(
caption,
width=512,
height=512,
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, #### length of 30
prior_guidance_scale=4.0,
num_images_per_prompt=1,
num_inference_steps = 6, #### default num of 12, 6 favour
).images
'''
def process(prompt, num_samples, image_resolution, sample_steps, seed):
    """Generate images for ``prompt`` with the module-level Wuerstchen pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        num_samples: Number of images to generate; the pipeline is called
            once per image.
        image_resolution: Side length in pixels, used for both height and
            width.
        sample_steps: Value passed to the pipeline as ``num_inference_steps``.
        seed: Seed for the random generator. ``-1`` selects a random seed in
            ``[0, 65535]``.

    Returns:
        A list with one ``numpy`` array per generated image.
    """
    # UI sliders may hand numbers over as floats; ``range`` and the seed
    # need real integers.
    num_samples = int(num_samples)
    side = int(image_resolution)
    sample_steps = int(sample_steps)
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, 65535)

    # The seed was previously computed but never handed to the pipeline, so
    # the Seed control had no effect. A single generator is shared by all
    # samples: the whole run is reproducible for a given seed, while the
    # individual images in the run still differ from each other.
    generator = torch.Generator(device=device).manual_seed(seed)

    images = []
    with torch.no_grad():
        for _ in range(num_samples):
            image = pipe(
                prompt,
                prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
                prior_guidance_scale=4.0,
                num_inference_steps=sample_steps,
                num_images_per_prompt=1,
                height=side,
                width=side,
                generator=generator,
            ).images[0]
            images.append(np.asarray(image))
    return images
# Assemble the demo page: a title row, then a controls column placed next to
# an output gallery. Queueing is enabled on the app before it is populated.
with gr.Blocks().queue() as block:
    with gr.Row():
        gr.Markdown("## Rapid Diffusion model from warp-ai/wuerstchen")

    with gr.Row():
        # Left column: prompt, run button and the collapsible tuning options.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                value="Anthropomorphic cat dressed as a fire fighter",
            )
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(
                    label="Images",
                    minimum=1,
                    maximum=12,
                    value=1,
                    step=1,
                )
                # A step of 256 limits the choices to 256, 512 and 768.
                image_resolution = gr.Slider(
                    label="Image Resolution",
                    minimum=256,
                    maximum=768,
                    value=512,
                    step=256,
                )
                sample_steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=100,
                    value=6,
                    step=1,
                )
                # The range starts at -1, which `process` treats as
                # "pick a random seed".
                seed = gr.Slider(
                    label="Seed",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )

        # Right column: gallery that receives the generated images.
        with gr.Column():
            result_gallery = gr.Gallery(
                label='Output',
                show_label=False,
                elem_id="gallery",
            ).style(grid=2, height='auto')

    # Input components, listed in the positional order `process` expects.
    ips = [prompt, num_samples, image_resolution, sample_steps, seed]
    run_button.click(
        fn=process,
        inputs=ips,
        outputs=[result_gallery],
        show_progress=True,
    )

# Listen on all network interfaces.
block.launch(server_name='0.0.0.0')