# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from PIL import Image
from torchvision import transforms as TF


def load_and_preprocess_images(image_path_list):
    """
    A quick-start function to load and preprocess images for model input.
    It assumes the images share the same shape for easier batching, but our
    model can also work well with different shapes.

    Args:
        image_path_list (list): List of paths to image files

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)

    Raises:
        ValueError: If the input list is empty

    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - The function ensures width=518px while maintaining aspect ratio
        - Height is adjusted to be divisible by 14 for compatibility with model requirements
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()

    # First process all images and collect their shapes
    for image_path in image_path_list:
        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto a white background
        if img.mode == "RGBA":
            # Create white background
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            # Alpha composite onto the white background
            img = Image.alpha_composite(background, img)

        # Now convert to "RGB" (this step assigns white to transparent areas)
        img = img.convert("RGB")

        width, height = img.size
        new_width = 518
        # Calculate height maintaining aspect ratio, divisible by 14
        new_height = round(height * (new_width / width) / 14) * 14

        # Resize with new dimensions (width, height)
        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # Convert to tensor with values in [0, 1]

        # Center-crop height if it's larger than 518
        if new_height > 518:
            start_y = (new_height - 518) // 2
            img = img[:, start_y : start_y + 518, :]

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")

        # Find maximum dimensions
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        # Pad images with white (value=1.0) to the maximum dimensions
        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
            padded_images.append(img)
        images = padded_images

    images = torch.stack(images)  # Stack into a single (N, 3, H, W) batch

    # Ensure correct shape when a single image is given
    if len(image_path_list) == 1:
        # Verify shape is (1, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)

    return images
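

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). The file
# paths below are hypothetical placeholders; the expected output shape
# follows from the resize logic above (width fixed at 518, height rounded
# to a multiple of 14, center-cropped to at most 518).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical image paths -- replace with real files on disk.
    example_paths = [
        "examples/frame_000.png",
        "examples/frame_001.png",
    ]

    batch = load_and_preprocess_images(example_paths)

    # For 1920x1080 inputs: 1080 * (518 / 1920) ~= 291.4, rounded to the
    # nearest multiple of 14 -> 294, so the batch is (2, 3, 294, 518).
    print(batch.shape)
    print(batch.min().item(), batch.max().item())  # values lie in [0, 1]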