import os

import huggingface_hub
import numpy as np
import spaces
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
from diffusers.utils import load_image
from PIL import Image
from transformers import pipeline

import config
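
# `config` is expected to expose at least CONTROLNET_MODEL and BASE_MODEL.
# A minimal sketch of such a module, with illustrative model IDs assumed
# rather than taken from this project's actual settings:
#
#     # config.py
#     CONTROLNET_MODEL = "lllyasviel/sd-controlnet-depth"  # depth-conditioned ControlNet
#     BASE_MODEL = "runwayml/stable-diffusion-v1-5"        # Stable Diffusion 1.5 base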


class ControlNetPipeline:
    def __init__(self):
        """Initialize the ControlNet pipeline with lazy loading."""
        self.depth_estimator = None
        self.pipe = None
        self.controlnet = None
        self.is_initialized = False

    @spaces.GPU
    def initialize(self):
        """Initialize the models with GPU acceleration."""
        if self.is_initialized:
            return

        # Depth estimator used to build the ControlNet conditioning image
        self.depth_estimator = pipeline('depth-estimation')

        self.controlnet = ControlNetModel.from_pretrained(
            config.CONTROLNET_MODEL,
            torch_dtype=torch.float16
        )

        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            config.BASE_MODEL,
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float16
        )

        # Faster sampling with the UniPC scheduler
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)

        # xformers is optional; fall back to the default attention implementation
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            print("xformers not available, using default attention mechanism")

        self.pipe.enable_model_cpu_offload()
        self.is_initialized = True

    @spaces.GPU
    def process_image(self, image):
        """Process the input image to generate a depth map."""
        if not self.is_initialized:
            self.initialize()

        # Estimate depth, then replicate the single channel three times so the
        # result is an RGB image suitable as ControlNet conditioning input
        depth = self.depth_estimator(image)['depth']
        depth_array = np.array(depth)
        depth_array = depth_array[:, :, None]
        depth_array = np.concatenate([depth_array, depth_array, depth_array], axis=2)
        depth_image = Image.fromarray(depth_array)

        return depth_image
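
    # Note: `load_image` from diffusers.utils (imported above) accepts a local
    # path or URL and returns a PIL image, so a caller could do, e.g.
    # (illustrative URL and instance name, not part of this project):
    #
    #     image = load_image("https://example.com/room.png")
    #     depth_map = controlnet_pipeline.process_image(image)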

    @spaces.GPU
    def generate(self, prompt, image, negative_prompt=None, guidance_scale=7.5, num_inference_steps=20):
        """Generate an image using ControlNet with the provided prompt and input image."""
        if not self.is_initialized:
            self.initialize()

        depth_image = self.process_image(image)

        output = self.pipe(
            prompt=prompt,
            image=depth_image,
            negative_prompt=negative_prompt,
            guidance_scale=float(guidance_scale),
            num_inference_steps=int(num_inference_steps)
        )

        return output.images[0]
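

# A minimal usage sketch. The input/output file names and the prompt below are
# illustrative assumptions, not part of this project:
if __name__ == "__main__":
    controlnet_pipeline = ControlNetPipeline()
    source = Image.open("input.jpg")  # hypothetical example input
    result = controlnet_pipeline.generate(
        prompt="a cozy living room, soft natural light",
        image=source,
        negative_prompt="blurry, low quality",
        guidance_scale=7.5,
        num_inference_steps=20,
    )
    result.save("output.png")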