# SeqTex/utils/image_generation.py
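"""Image-condition generation utilities for SeqTex.

Lazily loads either an SDXL + ControlNet-Union pipeline or a FLUX.1-dev +
ControlNet-Union pipeline, renders a condition image from per-view depth and
normal maps, and masks it to the foreground.

Usage sketch (hypothetical inputs; requires a CUDA GPU and the model weights):

    condition, status = generate_image_condition(
        position_imgs, normal_imgs, mask_imgs, w2c,
        text_prompt="weathered wooden crate",
        selected_view="First View", seed=42, model="FLUX",
    )
    condition.save("condition.png")
"""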
import threading
import cv2
import numpy as np
import spaces
import torch
import torch.nn.functional as F
# Add FLUX imports
from diffusers import (AutoencoderKL, EulerAncestralDiscreteScheduler,
FluxControlNetModel, FluxControlNetPipeline)
from einops import rearrange
from PIL import Image
from torchvision.transforms import ToPILImage
import gradio as gr
from .controlnet_union import ControlNetModel_Union
from .pipeline_controlnet_union_sd_xl import \
StableDiffusionXLControlNetUnionPipeline
from .render_utils import get_silhouette_image
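
# Module-level pipeline caches; the locks make lazy initialization thread-safe
# so the heavyweight model download/setup happens at most once per process.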
IMG_PIPE = None
IMG_PIPE_LOCK = threading.Lock()
# Add FLUX pipeline variables
FLUX_PIPE = None
FLUX_PIPE_LOCK = threading.Lock()
FLUX_SUFFIX = None
FLUX_NEGATIVE = None
def lazy_get_flux_pipe():
"""
Lazy load the FLUX pipeline with ControlNet for image generation.
"""
global FLUX_PIPE, FLUX_SUFFIX, FLUX_NEGATIVE
if FLUX_PIPE is not None:
return FLUX_PIPE
gr.Info("First called, loading FLUX pipeline... It may take about 1 minute.")
with FLUX_PIPE_LOCK:
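        # Re-check inside the lock: another thread may have finished loading first.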
if FLUX_PIPE is not None:
return FLUX_PIPE
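        # These prompt add-ons bias FLUX toward flat albedo output with no baked-in lighting.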
FLUX_SUFFIX = ", albedo texture, high-quality, 8K, flat shaded, diffuse color only, orthographic view, seamless texture pattern, detailed surface texture."
FLUX_NEGATIVE = "ugly, PBR, lighting, shadows, highlights, specular, reflections, ambient occlusion, global illumination, bloom, glare, lens flare, glow, shiny, glossy, noise, grain, blurry, bokeh, depth of field."
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0'
controlnet = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
FLUX_PIPE = FluxControlNetPipeline.from_pretrained(
base_model,
controlnet=controlnet,
torch_dtype=torch.bfloat16
)
        # Offload submodules to CPU when idle to keep peak GPU memory low during inference
FLUX_PIPE.enable_model_cpu_offload()
return FLUX_PIPE
def lazy_get_sdxl_pipe():
"""
Lazy load the SDXL pipeline with ControlNet for image generation.
"""
global IMG_PIPE
if IMG_PIPE is not None:
return IMG_PIPE
gr.Info("First called, loading SDXL pipeline... It may take about 20 seconds.")
with IMG_PIPE_LOCK:
if IMG_PIPE is not None:
return IMG_PIPE
eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
        # When testing with another base model, remember to change the VAE as well.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
controlnet_model = ControlNetModel_Union.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True)
IMG_PIPE = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet_model,
vae=vae,
torch_dtype=torch.float16,
scheduler=eulera_scheduler,
)
# Move pipeline to CUDA device
IMG_PIPE = IMG_PIPE.to("cuda")
return IMG_PIPE
def generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image:
"""
Generate image condition using SDXL model with ControlNet based on depth and normal images.
:param depth_img: Depth image from the selected view.
:param normal_img: Normal image (Camera Coordinate System) from the selected view.
:param text_prompt: Text prompt for image generation.
:param mask: A mask image to apply to guide the subsequent pipeline to focus on the foreground.
:param seed: Random seed for image generation.
:param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False).
:param image_height: Height of the output image.
:param image_width: Width of the output image.
:param progress: Progress callback for Gradio.
    :return: Generated image condition (PIL Image).
"""
progress(0.1, desc="Loading SDXL pipeline...")
pipeline = lazy_get_sdxl_pipe()
progress(0.3, desc="SDXL pipeline loaded successfully.")
    positive_prompt = text_prompt + ", photo-realistic style, high quality, 8K, highly detailed texture, soft lighting, uniform color, foreground"
negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
img_generation_resolution = 1024 # SDXL performs better at 1024x1024
image = pipeline(prompt=[positive_prompt]*1,
image_list=[0, depth_img, 0, 0, normal_img, 0],
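                     # Slot order for ControlNet-Union: [openpose, depth, scribble/HED, canny/lineart, normal, segment]; zeros disable unused slots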
negative_prompt=[negative_prompt]*1,
generator=torch.Generator(device="cuda").manual_seed(seed),
width=img_generation_resolution,
height=img_generation_resolution,
num_inference_steps=50,
union_control=True,
union_control_type=torch.Tensor([0, 1, 0, 0, 1, 0]).to("cuda"), # use depth and normal images
progress=progress,
).images[0]
progress(0.9, desc="Condition tensor generated successfully.")
rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to(pipeline.device)
mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to(pipeline.device) # Ensure mask is in the correct shape
mask_tensor = mask_tensor / 255.0 # Normalize mask to [0, 1]
rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
# Apply edge refinement if enabled
if edge_refinement:
        # Move tensors to the CUDA device for edge refinement
rgb_tensor_cuda = rgb_tensor.to("cuda")
mask_tensor_cuda = mask_tensor.to("cuda")
rgb_tensor_cuda = refine_image_edges(rgb_tensor_cuda, mask_tensor_cuda)
rgb_tensor = rgb_tensor_cuda.to(pipeline.device)
background_tensor = torch.zeros_like(rgb_tensor)
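    # Composite onto a black background: lerp(bg, rgb, mask) = bg * (1 - mask) + rgb * mask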
rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor)
rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W")
rgb_tensor = rgb_tensor / 255.
to_img = ToPILImage()
condition_image = to_img(rgb_tensor.cpu())
progress(1, desc="Condition image generated successfully.")
return condition_image
def generate_flux_condition(depth_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image:
"""
Generate image condition using FLUX model with ControlNet based on depth image only.
Note: FLUX.1-dev-ControlNet-Union-Pro-2.0 does not support normal control, only depth.
:param depth_img: Depth image from the selected view.
:param text_prompt: Text prompt for image generation.
:param mask: A mask image to apply to guide the subsequent pipeline to focus on the foreground.
:param seed: Random seed for image generation.
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False).
    :param image_height: Height of the output image.
    :param image_width: Width of the output image.
    :param progress: Progress callback for Gradio.
:return: Generated image condition (PIL Image).
"""
progress(0.1, desc="Loading FLUX pipeline...")
pipeline = lazy_get_flux_pipe()
progress(0.3, desc="FLUX pipeline loaded successfully.")
# Enhanced prompt for better results
positive_prompt = text_prompt + FLUX_SUFFIX
negative_prompt = FLUX_NEGATIVE
# Get image dimensions
width, height = depth_img.size
progress(0.5, desc="Generating image with FLUX (including onload and cpu offload)...")
# Generate image using FLUX ControlNet with depth control
# model_cpu_offload handles GPU loading automatically
image = pipeline(
prompt=positive_prompt,
negative_prompt=negative_prompt,
control_image=depth_img,
width=width,
height=height,
controlnet_conditioning_scale=0.8, # Recommended for depth
control_guidance_end=0.8,
num_inference_steps=30,
guidance_scale=3.5,
generator=torch.Generator(device="cuda").manual_seed(seed),
).images[0]
progress(0.9, desc="Applying mask and resizing...")
# Convert to tensor and apply mask
rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to("cuda")
mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to("cuda")
mask_tensor = mask_tensor / 255.0 # Normalize mask to [0, 1]
# Resize to target dimensions
rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
# Apply mask (blend with black background)
background_tensor = torch.zeros_like(rgb_tensor)
if edge_refinement:
        # Replace edge pixels with dilated interior values to suppress white fringes
rgb_tensor = refine_image_edges(rgb_tensor, mask_tensor)
rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor)
# Convert back to PIL Image
rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W")
rgb_tensor = rgb_tensor / 255.0
to_img = ToPILImage()
condition_image = to_img(rgb_tensor.cpu())
progress(1, desc="FLUX condition image generated successfully.")
return condition_image
def refine_image_edges(rgb_tensor, mask_tensor):
"""
Refine image edges using advanced morphological operations to remove white edges while preserving object boundaries.
Algorithm:
1. Erode mask to get eroded_mask
2. Double erode mask to get double_eroded_mask
3. XOR eroded_mask and double_eroded_mask to get circle_valid_mask
4. Use circle_valid_mask to extract circle_rgb (clean edge values)
5. Dilate circle_rgb to cover the edge region
6. Final result: use double_eroded_mask for original RGB foreground, dilated_circle_rgb for background
:param rgb_tensor: RGB image tensor of shape (1, C, H, W) on CUDA device
:param mask_tensor: Mask tensor of shape (1, 1, H, W) on CUDA device, normalized to [0, 1]
:return: refined_rgb_tensor
"""
# Convert tensors to numpy for OpenCV processing
rgb_np = rgb_tensor.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8) # (H, W, C)
mask_np = mask_tensor.squeeze().cpu().numpy() # Remove batch and channel dimensions
original_mask_np = (mask_np * 255).astype(np.uint8) # Convert to 0-255 range
    # 3x3 elliptical structuring element shared by all morphological operations
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
# Step 1: Erode mask to get eroded_mask
eroded_mask_np = cv2.erode(original_mask_np, kernel, iterations=3)
# Step 2: Double erode mask to get double_eroded_mask
double_eroded_mask_np = cv2.erode(eroded_mask_np, kernel, iterations=5)
# Step 3: XOR eroded_mask and double_eroded_mask to get circle_valid_mask
circle_valid_mask_np = cv2.bitwise_xor(eroded_mask_np, double_eroded_mask_np)
# Step 4: Use circle_valid_mask to extract circle_rgb (clean edge values)
circle_valid_mask_3c = cv2.cvtColor(circle_valid_mask_np, cv2.COLOR_GRAY2BGR) / 255.0
circle_rgb_np = (rgb_np * circle_valid_mask_3c).astype(np.uint8)
    # Step 5: Dilate circle_rgb to cover the edge region (iterations=8)
dilated_circle_rgb_np = cv2.dilate(circle_rgb_np, kernel, iterations=8)
# Step 6: Final composition
# Use double_eroded_mask for original RGB foreground, dilated_circle_rgb for background
double_eroded_mask_3c = cv2.cvtColor(double_eroded_mask_np, cv2.COLOR_GRAY2BGR) / 255.0
# Final result: original RGB where double_eroded_mask is valid, dilated_circle_rgb elsewhere
refined_rgb_np = (rgb_np * double_eroded_mask_3c +
dilated_circle_rgb_np * (1 - double_eroded_mask_3c)).astype(np.uint8)
# Convert refined RGB back to tensor
refined_rgb_tensor = torch.from_numpy(refined_rgb_np).float().permute(2, 0, 1).unsqueeze(0).to("cuda")
return refined_rgb_tensor
@spaces.GPU(duration=120)
def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_prompt, selected_view="First View", seed=42, model="SDXL", edge_refinement=True, progress=gr.Progress()):
"""
Generate the image condition based on the selected view's silhouette and text prompt.
:param position_imgs: Position images from different views.
:param normal_imgs: Normal images from different views.
:param mask_imgs: Mask images from different views.
:param w2c: World-to-camera transformation matrices.
:param text_prompt: The text prompt for image generation.
:param selected_view: The selected view for image generation.
:param seed: Random seed for image generation.
:param model: The image generation model type, supports "SDXL" and "FLUX".
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: True).
    :param progress: Progress callback for Gradio.
:return: Generated condition image and status message.
"""
progress(0, desc="Handling geometry information...")
silhouette = get_silhouette_image(position_imgs, normal_imgs, mask_imgs=mask_imgs, w2c=w2c, selected_view=selected_view)
depth_img = silhouette[0]
normal_img = silhouette[1]
mask = silhouette[2]
try:
if model == "SDXL":
condition = generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
return condition, "SDXL condition generated successfully."
elif model == "FLUX":
# FLUX only supports depth control, not normal
condition = generate_flux_condition(depth_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
return condition, "FLUX condition generated successfully (depth-only control)."
else:
raise ValueError(f"Unsupported image generation model type: {model}. Supported models: 'SDXL', 'FLUX'.")
finally:
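        # Free cached GPU memory so subsequent Space invocations start from a clean slate.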
torch.cuda.empty_cache()