import threading

import cv2
import numpy as np
import spaces
import torch
import torch.nn.functional as F
# FLUX imports
from diffusers import (AutoencoderKL, EulerAncestralDiscreteScheduler,
                       FluxControlNetModel, FluxControlNetPipeline)
from einops import rearrange
from PIL import Image
from torchvision.transforms import ToPILImage

import gradio as gr

from .controlnet_union import ControlNetModel_Union
from .pipeline_controlnet_union_sd_xl import (
    StableDiffusionXLControlNetUnionPipeline,
)
from .render_utils import get_silhouette_image
IMG_PIPE = None
IMG_PIPE_LOCK = threading.Lock()

# FLUX pipeline globals (lazily initialized)
FLUX_PIPE = None
FLUX_PIPE_LOCK = threading.Lock()
FLUX_SUFFIX = None
FLUX_NEGATIVE = None


def lazy_get_flux_pipe():
    """
    Lazily load the FLUX pipeline with ControlNet for image generation.

    Uses double-checked locking so that concurrent callers trigger only one load.
    """
    global FLUX_PIPE, FLUX_SUFFIX, FLUX_NEGATIVE
    if FLUX_PIPE is not None:
        return FLUX_PIPE
    gr.Info("First call: loading the FLUX pipeline... This may take about a minute.")
    with FLUX_PIPE_LOCK:
        if FLUX_PIPE is not None:
            return FLUX_PIPE
        FLUX_SUFFIX = ", albedo texture, high-quality, 8K, flat shaded, diffuse color only, orthographic view, seamless texture pattern, detailed surface texture."
        FLUX_NEGATIVE = "ugly, PBR, lighting, shadows, highlights, specular, reflections, ambient occlusion, global illumination, bloom, glare, lens flare, glow, shiny, glossy, noise, grain, blurry, bokeh, depth of field."
        base_model = 'black-forest-labs/FLUX.1-dev'
        controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0'
        controlnet = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
        FLUX_PIPE = FluxControlNetPipeline.from_pretrained(
            base_model,
            controlnet=controlnet,
            torch_dtype=torch.bfloat16,
        )
        # Model CPU offload keeps VRAM usage low: weights move to the GPU only during inference
        FLUX_PIPE.enable_model_cpu_offload()
    return FLUX_PIPE


def lazy_get_sdxl_pipe():
    """
    Lazily load the SDXL pipeline with ControlNet for image generation.

    Uses double-checked locking so that concurrent callers trigger only one load.
    """
    global IMG_PIPE
    if IMG_PIPE is not None:
        return IMG_PIPE
    gr.Info("First call: loading the SDXL pipeline... This may take about 20 seconds.")
    with IMG_PIPE_LOCK:
        if IMG_PIPE is not None:
            return IMG_PIPE
        eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
        # When testing with another base model, swap in a matching VAE as well.
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
        controlnet_model = ControlNetModel_Union.from_pretrained(
            "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True)
        IMG_PIPE = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            controlnet=controlnet_model,
            vae=vae,
            torch_dtype=torch.float16,
            scheduler=eulera_scheduler,
        )
        # Unlike FLUX, the SDXL pipeline stays resident on the GPU.
        IMG_PIPE = IMG_PIPE.to("cuda")
    return IMG_PIPE
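

# Both loaders above follow the double-checked locking pattern: the unlocked
# check keeps the hot path lock-free, and the re-check inside the lock ensures
# the expensive load runs exactly once even under concurrent calls. A minimal
# sketch of the pattern (load_resource and these names are illustrative, not
# part of this module):
#
#     _RESOURCE = None
#     _RESOURCE_LOCK = threading.Lock()
#
#     def lazy_get_resource():
#         global _RESOURCE
#         if _RESOURCE is not None:        # fast path: already loaded
#             return _RESOURCE
#         with _RESOURCE_LOCK:
#             if _RESOURCE is not None:    # re-check: another thread may have loaded it
#                 return _RESOURCE
#             _RESOURCE = load_resource()  # expensive; runs exactly once
#         return _RESOURCE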


def generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image:
    """
    Generate an image condition using the SDXL model with ControlNet, based on depth and normal images.

    :param depth_img: Depth image from the selected view.
    :param normal_img: Normal image (camera coordinate system) from the selected view.
    :param text_prompt: Text prompt for image generation.
    :param mask: Mask image that guides the subsequent pipeline to focus on the foreground.
    :param seed: Random seed for image generation.
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False).
    :param image_height: Height of the output image.
    :param image_width: Width of the output image.
    :param progress: Progress callback for Gradio.
    :return: Generated condition image (PIL Image).
    """
    progress(0.1, desc="Loading SDXL pipeline...")
    pipeline = lazy_get_sdxl_pipe()
    progress(0.3, desc="SDXL pipeline loaded successfully.")
    positive_prompt = text_prompt + ", photo-realistic style, high quality, 8K, highly detailed texture, soft lighting, uniform color, foreground"
    negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
    img_generation_resolution = 1024  # SDXL performs best at 1024x1024
    # Control slots of the union ControlNet: [openpose, depth, hed/scribble, canny, normal, segment];
    # only the depth (index 1) and normal (index 4) slots are active here.
    image = pipeline(prompt=[positive_prompt],
                     image_list=[0, depth_img, 0, 0, normal_img, 0],
                     negative_prompt=[negative_prompt],
                     generator=torch.Generator(device="cuda").manual_seed(seed),
                     width=img_generation_resolution,
                     height=img_generation_resolution,
                     num_inference_steps=50,
                     union_control=True,
                     union_control_type=torch.Tensor([0, 1, 0, 0, 1, 0]).to("cuda"),
                     progress=progress,
                     ).images[0]
    progress(0.9, desc="Applying mask and resizing...")
    rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to(pipeline.device)
    mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to(pipeline.device)  # (1, 1, H, W)
    mask_tensor = mask_tensor / 255.0  # Normalize mask to [0, 1]
    rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
    mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
    if edge_refinement:
        # Edge refinement runs on the CUDA device
        rgb_tensor_cuda = rgb_tensor.to("cuda")
        mask_tensor_cuda = mask_tensor.to("cuda")
        rgb_tensor_cuda = refine_image_edges(rgb_tensor_cuda, mask_tensor_cuda)
        rgb_tensor = rgb_tensor_cuda.to(pipeline.device)
    # Composite onto a black background: lerp(bg, rgb, mask) = bg * (1 - mask) + rgb * mask
    background_tensor = torch.zeros_like(rgb_tensor)
    rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor)
    rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W")
    rgb_tensor = rgb_tensor / 255.0
    to_img = ToPILImage()
    condition_image = to_img(rgb_tensor.cpu())
    progress(1, desc="Condition image generated successfully.")
    return condition_image


def generate_flux_condition(depth_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image:
    """
    Generate an image condition using the FLUX model with ControlNet, based on the depth image only.

    Note: FLUX.1-dev-ControlNet-Union-Pro-2.0 does not support normal control, only depth.

    :param depth_img: Depth image from the selected view.
    :param text_prompt: Text prompt for image generation.
    :param mask: Mask image that guides the subsequent pipeline to focus on the foreground.
    :param seed: Random seed for image generation.
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False).
    :param image_height: Height of the output image.
    :param image_width: Width of the output image.
    :param progress: Progress callback for Gradio.
    :return: Generated condition image (PIL Image).
    """
    progress(0.1, desc="Loading FLUX pipeline...")
    pipeline = lazy_get_flux_pipe()
    progress(0.3, desc="FLUX pipeline loaded successfully.")
    # Steer the prompt toward flat, albedo-style textures
    positive_prompt = text_prompt + FLUX_SUFFIX
    negative_prompt = FLUX_NEGATIVE
    # Generate at the depth map's native resolution
    width, height = depth_img.size
    progress(0.5, desc="Generating image with FLUX (including onload and CPU offload)...")
    # Generate with depth control; model CPU offload moves weights to the GPU as needed
    image = pipeline(
        prompt=positive_prompt,
        negative_prompt=negative_prompt,
        control_image=depth_img,
        width=width,
        height=height,
        controlnet_conditioning_scale=0.8,  # recommended for depth control
        control_guidance_end=0.8,
        num_inference_steps=30,
        guidance_scale=3.5,
        generator=torch.Generator(device="cuda").manual_seed(seed),
    ).images[0]
    progress(0.9, desc="Applying mask and resizing...")
    # Convert to tensors for masking
    rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to("cuda")
    mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to("cuda")  # (1, 1, H, W)
    mask_tensor = mask_tensor / 255.0  # Normalize mask to [0, 1]
    # Resize to the target dimensions
    rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
    mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
    if edge_refinement:
        # Replace boundary pixels with dilated interior values before compositing
        rgb_tensor = refine_image_edges(rgb_tensor, mask_tensor)
    # Composite onto a black background: lerp(bg, rgb, mask) = bg * (1 - mask) + rgb * mask
    background_tensor = torch.zeros_like(rgb_tensor)
    rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor)
    # Convert back to a PIL image
    rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W")
    rgb_tensor = rgb_tensor / 255.0
    to_img = ToPILImage()
    condition_image = to_img(rgb_tensor.cpu())
    progress(1, desc="FLUX condition image generated successfully.")
    return condition_image


def refine_image_edges(rgb_tensor, mask_tensor):
    """
    Refine image edges with morphological operations to remove white fringes while preserving object boundaries.

    Algorithm:
    1. Erode the mask to get eroded_mask.
    2. Erode again to get double_eroded_mask.
    3. XOR eroded_mask with double_eroded_mask to get circle_valid_mask (a thin ring just inside the boundary).
    4. Use circle_valid_mask to extract circle_rgb (clean interior edge colors).
    5. Dilate circle_rgb so it covers the original edge region.
    6. Composite: keep the original RGB where double_eroded_mask is valid, and use dilated_circle_rgb elsewhere.

    :param rgb_tensor: RGB image tensor of shape (1, C, H, W) on the CUDA device.
    :param mask_tensor: Mask tensor of shape (1, 1, H, W) on the CUDA device, normalized to [0, 1].
    :return: Refined RGB tensor of shape (1, C, H, W).
    """
    # Convert tensors to numpy for OpenCV processing
    rgb_np = rgb_tensor.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)  # (H, W, C)
    mask_np = mask_tensor.squeeze().cpu().numpy()  # Drop batch and channel dimensions
    original_mask_np = (mask_np * 255).astype(np.uint8)  # Convert to the 0-255 range
    # 3x3 elliptical structuring element used for all morphological operations
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    # Step 1: Erode the mask
    eroded_mask_np = cv2.erode(original_mask_np, kernel, iterations=3)
    # Step 2: Erode again
    double_eroded_mask_np = cv2.erode(eroded_mask_np, kernel, iterations=5)
    # Step 3: XOR the two eroded masks to isolate a thin ring of clean interior pixels
    circle_valid_mask_np = cv2.bitwise_xor(eroded_mask_np, double_eroded_mask_np)
    # Step 4: Extract the ring's RGB values
    circle_valid_mask_3c = cv2.cvtColor(circle_valid_mask_np, cv2.COLOR_GRAY2BGR) / 255.0
    circle_rgb_np = (rgb_np * circle_valid_mask_3c).astype(np.uint8)
    # Step 5: Dilate the ring colors outward so they cover the edge region
    dilated_circle_rgb_np = cv2.dilate(circle_rgb_np, kernel, iterations=8)
    # Step 6: Composite - original RGB where double_eroded_mask is valid, dilated ring colors elsewhere
    double_eroded_mask_3c = cv2.cvtColor(double_eroded_mask_np, cv2.COLOR_GRAY2BGR) / 255.0
    refined_rgb_np = (rgb_np * double_eroded_mask_3c +
                      dilated_circle_rgb_np * (1 - double_eroded_mask_3c)).astype(np.uint8)
    # Convert the refined RGB back to a tensor
    refined_rgb_tensor = torch.from_numpy(refined_rgb_np).float().permute(2, 0, 1).unsqueeze(0).to("cuda")
    return refined_rgb_tensor
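

# A minimal sanity-check sketch for refine_image_edges (illustrative only, not
# part of the app's control flow; assumes a CUDA device is available). It paints
# a white fringe just inside a circular mask's boundary and runs the refinement,
# which should overwrite the fringe with colors pulled from the clean interior:
def _demo_refine_image_edges():
    h = w = 256
    yy, xx = np.mgrid[0:h, 0:w]
    dist = np.sqrt((yy - h / 2) ** 2 + (xx - w / 2) ** 2)
    mask_np = (dist < 80).astype(np.float32)          # filled circular mask
    rgb_np = np.full((h, w, 3), 100, dtype=np.uint8)  # flat gray object
    rgb_np[(dist >= 77) & (dist < 80)] = 255          # white fringe at the boundary
    rgb = torch.from_numpy(rgb_np).float().permute(2, 0, 1).unsqueeze(0).to("cuda")
    mask = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).to("cuda")
    refined = refine_image_edges(rgb, mask)
    # The former fringe pixel should now be close to the interior gray (~100)
    print("fringe pixel after refinement:", refined[0, :, h // 2, w // 2 + 78].tolist())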


def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_prompt, selected_view="First View", seed=42, model="SDXL", edge_refinement=True, progress=gr.Progress()):
    """
    Generate the image condition based on the selected view's silhouette and text prompt.

    :param position_imgs: Position images from different views.
    :param normal_imgs: Normal images from different views.
    :param mask_imgs: Mask images from different views.
    :param w2c: World-to-camera transformation matrices.
    :param text_prompt: The text prompt for image generation.
    :param selected_view: The selected view for image generation.
    :param seed: Random seed for image generation.
    :param model: The image generation model, either "SDXL" or "FLUX".
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: True).
    :param progress: Progress callback for Gradio.
    :return: Generated condition image and a status message.
    """
    progress(0, desc="Handling geometry information...")
    silhouette = get_silhouette_image(position_imgs, normal_imgs, mask_imgs=mask_imgs, w2c=w2c, selected_view=selected_view)
    depth_img = silhouette[0]
    normal_img = silhouette[1]
    mask = silhouette[2]
    try:
        if model == "SDXL":
            condition = generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
            return condition, "SDXL condition generated successfully."
        elif model == "FLUX":
            # FLUX supports depth control only, not normal
            condition = generate_flux_condition(depth_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
            return condition, "FLUX condition generated successfully (depth-only control)."
        else:
            raise ValueError(f"Unsupported image generation model type: {model}. Supported models: 'SDXL', 'FLUX'.")
    finally:
        # Free cached GPU memory regardless of which branch ran
        torch.cuda.empty_cache()
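

# A minimal usage sketch for this module's entry point (hypothetical caller;
# position_imgs, normal_imgs, mask_imgs, and w2c are produced by the renderer
# upstream of this module):
#
#     condition, status = generate_image_condition(
#         position_imgs, normal_imgs, mask_imgs, w2c,
#         text_prompt="weathered bronze statue",
#         selected_view="First View",
#         seed=42,
#         model="FLUX",            # or "SDXL" for depth + normal control
#         edge_refinement=True,
#     )
#     condition.save("condition.png")
#     print(status)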