import threading import cv2 import numpy as np import spaces import torch import torch.nn.functional as F # Add FLUX imports from diffusers import (AutoencoderKL, EulerAncestralDiscreteScheduler, FluxControlNetModel, FluxControlNetPipeline) from einops import rearrange from PIL import Image from torchvision.transforms import ToPILImage import gradio as gr from .controlnet_union import ControlNetModel_Union from .pipeline_controlnet_union_sd_xl import \ StableDiffusionXLControlNetUnionPipeline from .render_utils import get_silhouette_image IMG_PIPE = None IMG_PIPE_LOCK = threading.Lock() # Add FLUX pipeline variables FLUX_PIPE = None FLUX_PIPE_LOCK = threading.Lock() FLUX_SUFFIX = None FLUX_NEGATIVE = None def lazy_get_flux_pipe(): """ Lazy load the FLUX pipeline with ControlNet for image generation. """ global FLUX_PIPE, FLUX_SUFFIX, FLUX_NEGATIVE if FLUX_PIPE is not None: return FLUX_PIPE gr.Info("First called, loading FLUX pipeline... It may take about 1 minute.") with FLUX_PIPE_LOCK: if FLUX_PIPE is not None: return FLUX_PIPE FLUX_SUFFIX = ", albedo texture, high-quality, 8K, flat shaded, diffuse color only, orthographic view, seamless texture pattern, detailed surface texture." FLUX_NEGATIVE = "ugly, PBR, lighting, shadows, highlights, specular, reflections, ambient occlusion, global illumination, bloom, glare, lens flare, glow, shiny, glossy, noise, grain, blurry, bokeh, depth of field." base_model = 'black-forest-labs/FLUX.1-dev' controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0' controlnet = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16) FLUX_PIPE = FluxControlNetPipeline.from_pretrained( base_model, controlnet=controlnet, torch_dtype=torch.bfloat16 ) # Use model CPU offload for better GPU utilization during inference FLUX_PIPE.enable_model_cpu_offload() return FLUX_PIPE def lazy_get_sdxl_pipe(): """ Lazy load the SDXL pipeline with ControlNet for image generation. """ global IMG_PIPE if IMG_PIPE is not None: return IMG_PIPE gr.Info("First called, loading SDXL pipeline... It may take about 20 seconds.") with IMG_PIPE_LOCK: if IMG_PIPE is not None: return IMG_PIPE eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler") # when test with other base model, you need to change the vae also. vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) controlnet_model = ControlNetModel_Union.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True) IMG_PIPE = StableDiffusionXLControlNetUnionPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet_model, vae=vae, torch_dtype=torch.float16, scheduler=eulera_scheduler, ) # Move pipeline to CUDA device IMG_PIPE = IMG_PIPE.to("cuda") return IMG_PIPE def generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image: """ Generate image condition using SDXL model with ControlNet based on depth and normal images. :param depth_img: Depth image from the selected view. :param normal_img: Normal image (Camera Coordinate System) from the selected view. :param text_prompt: Text prompt for image generation. :param mask: A mask image to apply to guide the subsequent pipeline to focus on the foreground. :param seed: Random seed for image generation. :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False). :param image_height: Height of the output image. :param image_width: Width of the output image. :param progress: Progress callback for Gradio. :return: Generated image condition (e.g., PIL Image). """ progress(0.1, desc="Loading SDXL pipeline...") pipeline = lazy_get_sdxl_pipe() progress(0.3, desc="SDXL pipeline loaded successfully.") positive_prompt = text_prompt + ", photo-realistic style, high quality, 8K, highly detailed texture, soft lightning, uniform color, foreground" negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' img_generation_resolution = 1024 # SDXL performs better at 1024x1024 image = pipeline(prompt=[positive_prompt]*1, image_list=[0, depth_img, 0, 0, normal_img, 0], negative_prompt=[negative_prompt]*1, generator=torch.Generator(device="cuda").manual_seed(seed), width=img_generation_resolution, height=img_generation_resolution, num_inference_steps=50, union_control=True, union_control_type=torch.Tensor([0, 1, 0, 0, 1, 0]).to("cuda"), # use depth and normal images progress=progress, ).images[0] progress(0.9, desc="Condition tensor generated successfully.") rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to(pipeline.device) mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to(pipeline.device) # Ensure mask is in the correct shape mask_tensor = mask_tensor / 255.0 # Normalize mask to [0, 1] rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False) mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False) # Apply edge refinement if enabled if edge_refinement: # Convert to CUDA device for edge refinement rgb_tensor_cuda = rgb_tensor.to("cuda") mask_tensor_cuda = mask_tensor.to("cuda") rgb_tensor_cuda = refine_image_edges(rgb_tensor_cuda, mask_tensor_cuda) rgb_tensor = rgb_tensor_cuda.to(pipeline.device) background_tensor = torch.zeros_like(rgb_tensor) rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor) rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W") rgb_tensor = rgb_tensor / 255. to_img = ToPILImage() condition_image = to_img(rgb_tensor.cpu()) progress(1, desc="Condition image generated successfully.") return condition_image def generate_flux_condition(depth_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image: """ Generate image condition using FLUX model with ControlNet based on depth image only. Note: FLUX.1-dev-ControlNet-Union-Pro-2.0 does not support normal control, only depth. :param depth_img: Depth image from the selected view. :param text_prompt: Text prompt for image generation. :param mask: A mask image to apply to guide the subsequent pipeline to focus on the foreground. :param seed: Random seed for image generation. :param image_height: Height of the output image. :param image_width: Width of the output image. :param progress: Progress callback for Gradio. :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False). :return: Generated image condition (PIL Image). """ progress(0.1, desc="Loading FLUX pipeline...") pipeline = lazy_get_flux_pipe() progress(0.3, desc="FLUX pipeline loaded successfully.") # Enhanced prompt for better results positive_prompt = text_prompt + FLUX_SUFFIX negative_prompt = FLUX_NEGATIVE # Get image dimensions width, height = depth_img.size progress(0.5, desc="Generating image with FLUX (including onload and cpu offload)...") # Generate image using FLUX ControlNet with depth control # model_cpu_offload handles GPU loading automatically image = pipeline( prompt=positive_prompt, negative_prompt=negative_prompt, control_image=depth_img, width=width, height=height, controlnet_conditioning_scale=0.8, # Recommended for depth control_guidance_end=0.8, num_inference_steps=30, guidance_scale=3.5, generator=torch.Generator(device="cuda").manual_seed(seed), ).images[0] progress(0.9, desc="Applying mask and resizing...") # Convert to tensor and apply mask rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to("cuda") mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to("cuda") mask_tensor = mask_tensor / 255.0 # Normalize mask to [0, 1] # Resize to target dimensions rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False) mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False) # Apply mask (blend with black background) background_tensor = torch.zeros_like(rgb_tensor) if edge_refinement: # replace edge with inner values rgb_tensor = refine_image_edges(rgb_tensor, mask_tensor) rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor) # Convert back to PIL Image rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W") rgb_tensor = rgb_tensor / 255.0 to_img = ToPILImage() condition_image = to_img(rgb_tensor.cpu()) progress(1, desc="FLUX condition image generated successfully.") return condition_image def refine_image_edges(rgb_tensor, mask_tensor): """ Refine image edges using advanced morphological operations to remove white edges while preserving object boundaries. Algorithm: 1. Erode mask to get eroded_mask 2. Double erode mask to get double_eroded_mask 3. XOR eroded_mask and double_eroded_mask to get circle_valid_mask 4. Use circle_valid_mask to extract circle_rgb (clean edge values) 5. Dilate circle_rgb to cover the edge region 6. Final result: use double_eroded_mask for original RGB foreground, dilated_circle_rgb for background :param rgb_tensor: RGB image tensor of shape (1, C, H, W) on CUDA device :param mask_tensor: Mask tensor of shape (1, 1, H, W) on CUDA device, normalized to [0, 1] :return: refined_rgb_tensor """ # Convert tensors to numpy for OpenCV processing rgb_np = rgb_tensor.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8) # (H, W, C) mask_np = mask_tensor.squeeze().cpu().numpy() # Remove batch and channel dimensions original_mask_np = (mask_np * 255).astype(np.uint8) # Convert to 0-255 range # Create morphological kernel (3x3 as requested) kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) # Step 1: Erode mask to get eroded_mask eroded_mask_np = cv2.erode(original_mask_np, kernel, iterations=3) # Step 2: Double erode mask to get double_eroded_mask double_eroded_mask_np = cv2.erode(eroded_mask_np, kernel, iterations=5) # Step 3: XOR eroded_mask and double_eroded_mask to get circle_valid_mask circle_valid_mask_np = cv2.bitwise_xor(eroded_mask_np, double_eroded_mask_np) # Step 4: Use circle_valid_mask to extract circle_rgb (clean edge values) circle_valid_mask_3c = cv2.cvtColor(circle_valid_mask_np, cv2.COLOR_GRAY2BGR) / 255.0 circle_rgb_np = (rgb_np * circle_valid_mask_3c).astype(np.uint8) # Step 5: Dilate circle_rgb to cover the edge region (using iterations=6 directly) dilated_circle_rgb_np = cv2.dilate(circle_rgb_np, kernel, iterations=8) # Step 6: Final composition # Use double_eroded_mask for original RGB foreground, dilated_circle_rgb for background double_eroded_mask_3c = cv2.cvtColor(double_eroded_mask_np, cv2.COLOR_GRAY2BGR) / 255.0 # Final result: original RGB where double_eroded_mask is valid, dilated_circle_rgb elsewhere refined_rgb_np = (rgb_np * double_eroded_mask_3c + dilated_circle_rgb_np * (1 - double_eroded_mask_3c)).astype(np.uint8) # Convert refined RGB back to tensor refined_rgb_tensor = torch.from_numpy(refined_rgb_np).float().permute(2, 0, 1).unsqueeze(0).to("cuda") return refined_rgb_tensor @spaces.GPU(duration=120) def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_prompt, selected_view="First View", seed=42, model="SDXL", edge_refinement=True, progress=gr.Progress()): """ Generate the image condition based on the selected view's silhouette and text prompt. :param position_imgs: Position images from different views. :param normal_imgs: Normal images from different views. :param mask_imgs: Mask images from different views. :param w2c: World-to-camera transformation matrices. :param text_prompt: The text prompt for image generation. :param selected_view: The selected view for image generation. :param seed: Random seed for image generation. :param model: The image generation model type, supports "SDXL" and "FLUX". :param progress: Progress callback for Gradio. :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: True). :return: Generated condition image and status message. """ progress(0, desc="Handling geometry information...") silhouette = get_silhouette_image(position_imgs, normal_imgs, mask_imgs=mask_imgs, w2c=w2c, selected_view=selected_view) depth_img = silhouette[0] normal_img = silhouette[1] mask = silhouette[2] try: if model == "SDXL": condition = generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress) return condition, "SDXL condition generated successfully." elif model == "FLUX": # FLUX only supports depth control, not normal condition = generate_flux_condition(depth_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress) return condition, "FLUX condition generated successfully (depth-only control)." else: raise ValueError(f"Unsupported image generation model type: {model}. Supported models: 'SDXL', 'FLUX'.") finally: torch.cuda.empty_cache()