# scribble-sdxl / app.py
# Author: Deadmon — last commit "Update app.py" (69dc758, verified).
# File size: 10.4 kB.
#!/usr/bin/env python
import os
import random
import gradio as gr
import numpy as np
import PIL.Image
import torch
import torchvision.transforms.functional as TF
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
from controlnet_aux import PidiNetDetector, HEDdetector
from diffusers.utils import load_image
from huggingface_hub import HfApi
from pathlib import Path
from PIL import Image, ImageOps
import torch
import numpy as np
import cv2
import os
import random
import spaces
from gradio_imageslider import ImageSlider
# Client-side JavaScript wired to the "Refresh for Dark Theme" button in
# create_demo(): if the page URL's `__theme` query parameter is not 'dark',
# set it to 'dark' and navigate to the updated URL (a full page reload).
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""
def nms(x, t, s):
    """Thin a map to its directional local maxima and binarise the result.

    ``x`` is Gaussian-blurred with sigma ``s``. A pixel is kept if it equals
    the maximum of its 3-pixel neighbourhood along at least one of four lines
    (horizontal, vertical, and the two diagonals). Kept pixels whose blurred
    value exceeds ``t`` become 255; every other pixel becomes 0.

    Args:
        x: Input array (converted to float32 before blurring).
        t: Threshold applied to the surviving blurred values.
        s: Gaussian sigma passed to ``cv2.GaussianBlur``.

    Returns:
        A uint8 array with the same shape as ``x`` holding only 0 and 255.
    """
    blurred = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
    # 3x3 line-shaped structuring elements: horizontal, vertical,
    # main diagonal, anti-diagonal.
    line_kernels = (
        np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8),
        np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8),
        np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8),
        np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8),
    )
    is_peak = np.zeros(blurred.shape, dtype=bool)
    for kernel in line_kernels:
        # Dilation is a running max, so equality marks a local maximum
        # along this kernel's line.
        is_peak |= cv2.dilate(blurred, kernel=kernel) == blurred
    peaks = np.where(is_peak, blurred, np.float32(0))
    return np.where(peaks > t, 255, 0).astype(np.uint8)
def HWC3(x):
    """Coerce a uint8 image array to height x width x 3 channels.

    - 2-D arrays and single-channel arrays are replicated into three channels.
    - 4-channel (RGBA) arrays are alpha-composited over a white background.
    - 3-channel arrays are returned as-is (the very same object).

    Raises:
        AssertionError: If ``x`` is not uint8, is not 2-D/3-D, or has a
            channel count other than 1, 3 or 4.
    """
    assert x.dtype == np.uint8
    arr = x[:, :, None] if x.ndim == 2 else x
    assert arr.ndim == 3
    channels = arr.shape[2]
    assert channels in (1, 3, 4)
    if channels == 1:
        return np.repeat(arr, 3, axis=2)
    if channels == 4:
        rgb = arr[:, :, :3].astype(np.float32)
        weight = arr[:, :, 3:4].astype(np.float32) / 255.0
        # Blend toward white where the pixel is transparent.
        blended = rgb * weight + 255.0 * (1.0 - weight)
        return np.clip(blended, 0, 255).astype(np.uint8)
    return arr
# Markdown rendered at the top of the demo UI (passed to gr.Markdown in create_demo()).
DESCRIPTION = '''# Scribble SDXL 🖋️🌄
sketch to image with SDXL, using [@xinsir](https://huggingface.co/xinsir) [scribble sdxl controlnet](https://huggingface.co/xinsir/controlnet-scribble-sdxl-1.0), [sdxl controlnet canny](https://huggingface.co/xinsir/controlnet-canny-sdxl-1.0)
'''
# Tell visitors up front when no CUDA device is visible.
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
# Prompt templates offered in the "Select style" dropdown. Each "prompt" must
# contain the literal placeholder "{prompt}", which apply_style() replaces with
# the user's text; "negative_prompt" is used as the style's negative prompt.
style_list = [
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
        "negative_prompt": "photo, photorealistic, realism, ugly",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
        "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
    {
        "name": "Manga",
        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    },
]
# name -> (prompt template, negative prompt)
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "(No style)"


def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
    """Expand a user prompt with the template of the named style.

    Args:
        style_name: Key into ``styles``. Unknown names fall back to
            ``DEFAULT_STYLE_NAME``.
        positive: The user's prompt, substituted for ``{prompt}``.
        negative: Optional extra negative-prompt terms, appended after the
            style's own negative prompt.

    Returns:
        ``(prompt, negative_prompt)`` ready to pass to the pipeline.
    """
    template, style_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    if negative:
        # Separate the extra terms with a comma; plain concatenation fused the
        # style's last word with the caller's first one ("low qualityblurry").
        style_negative = f"{style_negative}, {negative}" if style_negative else negative
    return template.replace("{prompt}", positive), style_negative
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Euler-ancestral scheduler loaded from the SDXL base repo; this single instance
# is passed to both pipelines below.
eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
# Scribble ControlNet (fp16), used by `pipe` for the "scribble" and "hed" detectors.
controlnet = ControlNetModel.from_pretrained(
    "xinsir/controlnet-scribble-sdxl-1.0",
    torch_dtype=torch.float16
)
# Canny ControlNet (fp16), used by `pipe_canny` for the "canny" detector.
controlnet_canny = ControlNetModel.from_pretrained(
    "xinsir/controlnet-canny-sdxl-1.0",
    torch_dtype=torch.float16
)
# fp16 VAE passed to both pipelines.
# When testing with a different base model, change the VAE as well.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
# SDXL base + scribble ControlNet.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    scheduler=eulera_scheduler,
)
pipe.to(device)
# SDXL base + canny ControlNet.
# NOTE(review): this second from_pretrained call loads the base model again; only
# the VAE and scheduler objects are shared with `pipe`, so the remaining base
# weights are presumably held twice in memory — consider reusing pipe.components.
# NOTE(review): SDXL pipelines are not expected to have a safety checker, so
# safety_checker=None is presumably ignored by diffusers — confirm.
pipe_canny = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet_canny,
    vae=vae,
    safety_checker=None,
    torch_dtype=torch.float16,
    scheduler=eulera_scheduler,
)
pipe_canny.to(device)
# Pixel budget for uploads before resize_image() downscales them.
# NOTE(review): 100 MP is far above SDXL's usual ~1 MP working size; very large
# uploads may be slow or exhaust GPU memory — consider lowering.
MAX_IMAGE_PIXELS = 100000000  # Adjust if needed.


def resize_image(image, max_pixels=MAX_IMAGE_PIXELS):
    """Downscale ``image`` so that width * height <= ``max_pixels``.

    The aspect ratio is preserved. Images already within the budget are
    returned unchanged (the same object).

    Args:
        image: A PIL image.
        max_pixels: Maximum allowed pixel count.

    Returns:
        ``image`` itself, or a LANCZOS-resampled copy that fits the budget.
    """
    width, height = image.size
    pixel_count = width * height
    if pixel_count <= max_pixels:
        return image
    scale_factor = (max_pixels / pixel_count) ** 0.5
    # Clamp to >= 1 px so extreme aspect ratios never yield a zero-sized image.
    new_size = (max(1, int(width * scale_factor)), max(1, int(height * scale_factor)))
    # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
    return image.resize(new_size, Image.LANCZOS)
def _control_image(image, detector_name):
    """Build the 3-channel ControlNet conditioning image for ``detector_name``.

    Args:
        image: RGB PIL image.
        detector_name: "hed", "scribble" or "canny".

    Returns:
        A PIL image holding the detected control map.

    Raises:
        ValueError: If ``detector_name`` is not one of the supported names.
    """
    if detector_name == "hed":
        rgb = HWC3(np.array(image, dtype=np.uint8))
        with torch.no_grad():
            detected = hed(rgb, scribble=True)
        # Depending on the controlnet_aux version, hed() may return a PIL image
        # rather than an ndarray; HWC3 needs a uint8 ndarray either way.
        return Image.fromarray(HWC3(np.array(detected, dtype=np.uint8)))
    if detector_name == "scribble":
        rgb = HWC3(np.array(image, dtype=np.uint8))
        return Image.fromarray(HWC3(nms(rgb, 127, 3.0)))
    if detector_name == "canny":
        # cv2.Canny yields a single-channel edge map; HWC3 replicates it to 3 channels.
        edges = cv2.Canny(np.array(image, dtype=np.uint8), 100, 200)
        return Image.fromarray(HWC3(edges))
    raise ValueError(f"Unknown detector_name: {detector_name!r}")


def process(image, prompt, style, detector_name):
    """Generate one SDXL image conditioned on a control map derived from ``image``.

    Args:
        image: Uploaded PIL image, or ``None`` if nothing was uploaded.
        prompt: User prompt, expanded through ``apply_style``.
        style: Style name (a key of ``styles``).
        detector_name: "scribble", "hed" or "canny". "canny" uses the canny
            ControlNet pipeline; the others use the scribble pipeline.

    Returns:
        The first generated PIL image.

    Raises:
        gr.Error: If no image was uploaded (shown to the user by Gradio).
        ValueError: If ``detector_name`` is not recognised.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Convert image to RGB mode if it's not already
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = resize_image(image)
    width, height = image.size
    # SDXL operates on a latent grid 8x smaller than the image; round the output
    # size down to a multiple of 8 (minimum 8) so arbitrary upload sizes don't
    # trip the pipeline's size handling.
    width = max(8, width - width % 8)
    height = max(8, height - height % 8)
    prompt, negative_prompt = apply_style(style, prompt)
    control_image = _control_image(image, detector_name)
    pipeline = pipe_canny if detector_name == "canny" else pipe
    images = pipeline(
        prompt,
        negative_prompt=negative_prompt,
        image=control_image,
        height=height,
        width=width,
    ).images
    return images[0]
# Stretch the main controls to fill their layout cells. The ids match the
# elem_id values given to the components in create_demo().
# (Previously written as `block_css = (code := """...""")`, which also leaked an
# unused module-level name `code`.)
block_css = """
#image_upload {
height: 100% !important;
}
#prompt_input {
height: 100% !important;
}
#select_style {
height: 100% !important;
}
#detect_method {
height: 100% !important;
}
#submit_button {
height: 100% !important;
}
"""
def create_demo():
    """Build the Gradio Blocks UI for the scribble / HED / canny SDXL demo.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """

    def generate(image, prompt_text, style_name, detector_name):
        # gr.Gallery expects a list of images, but process() returns a single one.
        return [process(image, prompt_text, style_name, detector_name)]

    with gr.Blocks(css=block_css) as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column():
                # Upload is the default source. The Gradio-3-only source=/tool=
                # arguments are omitted because Gradio 4 (required by
                # gradio_imageslider) rejects them.
                input_image = gr.Image(elem_id="image_upload", type="pil")
                prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt", elem_id="prompt_input")
                style = gr.Dropdown(STYLE_NAMES, value=DEFAULT_STYLE_NAME, label="Select style", elem_id="select_style")
                detect_method = gr.Dropdown(
                    choices=["scribble", "hed", "canny"],
                    value="scribble",
                    label="Select Detect Method",
                    elem_id="detect_method",
                )
                submit_btn = gr.Button("Generate", elem_id="submit_button")
            with gr.Column():
                # Component.style() no longer exists in Gradio 4; layout options
                # are constructor arguments instead.
                gallery = gr.Gallery(
                    label="Generated images",
                    show_label=False,
                    elem_id="gallery",
                    columns=2,
                    height="auto",
                )
        # Keep the API endpoint named after the underlying handler.
        submit_btn.click(
            generate,
            inputs=[input_image, prompt, style, detect_method],
            outputs=[gallery],
            api_name="process",
        )
        # Refresh button to apply the dark theme
        refresh_btn = gr.Button("Refresh for Dark Theme")
        try:
            refresh_btn.click(None, None, None, js=js_func)  # Gradio 4 keyword
        except TypeError:
            refresh_btn.click(None, None, None, _js=js_func)  # Gradio 3 keyword
    return demo
# HED edge detector; called by process() when detector_name == "hed".
# NOTE(review): some controlnet_aux releases expect the HED weights from
# "lllyasviel/Annotators" — confirm this repo id works with the installed version.
hed = HEDdetector.from_pretrained('lllyasviel/ControlNet')
if __name__ == "__main__":
    # Build the UI and start the Gradio server.
    demo = create_demo()
    demo.launch(debug=True)