ilanser's picture
Create app.py
7cdd9b7
raw
history blame
1.45 kB
import base64
import functools
import io
import warnings

import cv2
import gradio as gr
import numpy as np
import torch
from controlnet_aux import HEDdetector
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from PIL import Image
# Hugging Face Hub IDs of the annotator, the base Stable Diffusion model and
# the scribble-conditioned ControlNet used to steer it.
ANNOTATOR_ID = "lllyasviel/Annotators"
MODEL_ID = "runwayml/stable-diffusion-v1-5"
CONTROLNET_ID = "lllyasviel/sd-controlnet-scribble"


@functools.lru_cache(maxsize=1)
def _load_hed():
    """Load the HED edge detector once; later calls reuse the same instance."""
    return HEDdetector.from_pretrained(ANNOTATOR_ID)


@functools.lru_cache(maxsize=1)
def _load_pipeline():
    """Build the ControlNet pipeline once and place it on the best device.

    Half precision is only used on CUDA: most float16 kernels are not
    implemented on CPU, so the CPU fallback loads the weights in float32.
    """
    use_cuda = torch.cuda.is_available()
    dtype = torch.float16 if use_cuda else torch.float32

    controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, torch_dtype=dtype)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        MODEL_ID, controlnet=controlnet, torch_dtype=dtype
    )
    # Swap in the UniPC multistep scheduler, keeping the original scheduler config.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

    if use_cuda:
        pipe = pipe.to("cuda")
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except (ModuleNotFoundError, ValueError) as err:
            # xformers is an optional speed/memory optimisation; keep serving
            # with the default attention implementation instead of failing.
            warnings.warn(f"xformers attention unavailable, using default: {err}")
    return pipe


def predict(sketch, description):
    """Generate an image from a hand-drawn sketch and a text prompt.

    The sketch is converted by the HED detector (``scribble=True``) into the
    control image for the scribble ControlNet, which then conditions
    Stable Diffusion v1.5 on ``description``. Models are loaded lazily on
    the first call and reused afterwards.

    Args:
        sketch: Image array from the Gradio sketchpad, or None when nothing
            was drawn (assumed to be a uint8 numpy array accepted by
            ``PIL.Image.fromarray`` — TODO confirm against the Gradio version).
        description: Text prompt passed to the diffusion pipeline.

    Returns:
        The first PIL image produced by the pipeline.

    Raises:
        gr.Error: If no sketch was provided.
    """
    if sketch is None:
        raise gr.Error("Please draw a sketch before submitting.")

    sketch_pil = Image.fromarray(sketch)
    control_image = _load_hed()(sketch_pil, scribble=True)

    pipe = _load_pipeline()
    result = pipe(description, image=control_image, num_inference_steps=20).images[0]
    return result
# Drawing canvas for the input sketch: shape (1024, 1024) and brush radius 5.
sketchpad = gr.Sketchpad(shape=(1024, 1024), brush_radius=5)
iface = gr.Interface(fn=predict, inputs=[sketchpad, "text"], outputs="image", live=False)

if __name__ == "__main__":
    # Queue requests: the first call downloads and loads the models and every
    # generation takes many seconds, which can exceed per-request timeouts
    # without queuing. Queuing also serialises access to the shared GPU
    # pipeline. share=True asks Gradio for a public share link.
    iface.queue().launch(share=True)