import torch
from PIL import Image
from torchvision import transforms as TF
|
|
def load_and_preprocess_images(image_path_list, mode="crop"):
    """
    A quick start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching, but our model can also work well with images of different shapes.

    Args:
        image_path_list (list): List of paths to image files
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to 518px and center-crops the height if needed.
            - "pad": Preserves all pixels by making the largest dimension 518px
              and padding the smaller dimension to reach a square shape.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - When mode="crop": the function sets width=518px while maintaining aspect ratio,
          and the height is center-cropped if larger than 518px
        - When mode="pad": the function makes the largest dimension 518px while maintaining aspect ratio,
          and the smaller dimension is padded to reach a square shape (518x518)
        - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
    """
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()
    target_size = 518  # target resolution in pixels, matching the 518px sizes described in the docstring
|
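    # Sizing note: new dimensions are snapped to multiples of 14, which is assumed to be the
    # backbone's patch size. For example, a 1920x1080 image in "crop" mode is resized to
    # width 518 and height round(1080 * (518 / 1920) / 14) * 14 = 294.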
    for image_path in image_path_list:
        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto a white background
        if img.mode == "RGBA":
            # Create a white background and composite the image over it
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)

        # Convert to "RGB" (transparent areas are now white)
        img = img.convert("RGB")

        width, height = img.size

        if mode == "pad":
            # Make the largest dimension 518px while maintaining aspect ratio
            if width >= height:
                new_width = target_size
                new_height = round(height * (new_width / width) / 14) * 14  # make divisible by 14
            else:
                new_height = target_size
                new_width = round(width * (new_height / height) / 14) * 14  # make divisible by 14
        else:  # mode == "crop"
            # Set the width to 518px and scale the height to keep the aspect ratio
            new_width = target_size
            new_height = round(height * (new_width / width) / 14) * 14  # make divisible by 14

        # Resize with the new dimensions (width, height) and convert to a tensor in [0, 1]
        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)

        # Center-crop the height if it is larger than the target (only in crop mode)
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]

        # For pad mode, pad up to a square of target_size x target_size with white
        if mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)
|
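    # Images can end up with different shapes in "crop" mode (heights vary with aspect ratio);
    # in that case, pad them all with white up to the largest height/width so torch.stack works.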
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")

        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
            padded_images.append(img)
        images = padded_images
|
    images = torch.stack(images)

    if len(image_path_list) == 1:
        if images.dim() == 3:
            images = images.unsqueeze(0)

    return images
|
|
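# A minimal usage sketch (not part of the library): the file paths below are hypothetical
# placeholders, and the shapes mentioned in the comments assume the default "crop" mode.
if __name__ == "__main__":
    example_paths = ["examples/frame_000.jpg", "examples/frame_001.jpg"]  # hypothetical paths
    batch = load_and_preprocess_images(example_paths, mode="crop")
    # With "crop", every image is 518 wide and at most 518 tall (both divisible by 14),
    # so batch has shape (N, 3, H, 518); with mode="pad" it would be (N, 3, 518, 518).
    print(batch.shape, batch.dtype)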