import numpy as np
import PIL
from PIL import Image


class Croper:
    """Cut the masked region out of an image and centre it on a square canvas.

    The instance keeps the source image and mask. Calling
    :meth:`corp_mask_image` computes the crop and stores the resulting
    coordinates and images as attributes on the instance.
    """

    def __init__(
        self,
        input_image: Image.Image,
        target_mask: np.ndarray,
    ):
        """Store the image and the mask that selects the region to crop.

        Args:
            input_image: Source image; its ``size`` is read as (width, height).
            target_mask: Array indexed as ``[row, column]``; non-zero entries
                mark the region of interest.
        """
        self.input_image = input_image
        self.target_mask = target_mask

    def corp_mask_image(self, crop_length: int = 512, expand_size: int = 40):
        """Crop around the mask, pad to a square, and resize.

        The bounding box of the non-zero mask entries is grown by
        ``expand_size`` pixels on every side and clamped to the image bounds.
        That crop (image and mask) is centred on a square whose side is the
        longer edge of the crop, and both squares are resized to
        ``crop_length`` x ``crop_length``.

        Sets these attributes on the instance:
            ``origin_start_x/y``, ``origin_end_x/y``: crop box in the input
                image (end is exclusive).
            ``square_start_x/y``, ``square_end_x/y``: location of the crop
                inside the square canvas (end is exclusive).
            ``square_length``: side length of the square canvas.
            ``square_mask_image``, ``square_image``: square mask / image.
            ``resized_square_mask_image``, ``resized_square_image``: the same
                two images resized to ``crop_length``.

        Args:
            crop_length: Side length, in pixels, of the resized outputs.
            expand_size: Padding, in pixels, added around the mask's bounding
                box before clamping to the image.

        Returns:
            Tuple ``(square_image, resized_square_image)``.

        Raises:
            ValueError: If ``target_mask`` has no non-zero entries, if
                ``crop_length`` is not positive, or if ``expand_size`` is
                negative.
        """
        if crop_length <= 0:
            raise ValueError(f"crop_length must be positive, got {crop_length}")
        if expand_size < 0:
            raise ValueError(f"expand_size must be non-negative, got {expand_size}")

        target_mask = self.target_mask
        input_image = self.input_image
        original_width, original_height = input_image.size

        mask_indices = np.where(target_mask)
        if mask_indices[0].size == 0:
            # np.min/np.max on an empty selection would fail with an opaque
            # "zero-size array" error; report the real cause instead.
            raise ValueError("target_mask has no non-zero pixels; nothing to crop")

        # Cast to plain int so Pillow never receives NumPy scalar types.
        # The "+ 1" turns the inclusive max index into an exclusive slice end,
        # so the padding is the same on all four sides.
        start_y = max(int(mask_indices[0].min()) - expand_size, 0)
        end_y = min(int(mask_indices[0].max()) + 1 + expand_size, original_height)
        start_x = max(int(mask_indices[1].min()) - expand_size, 0)
        end_x = min(int(mask_indices[1].max()) + 1 + expand_size, original_width)

        mask_height = end_y - start_y
        mask_width = end_x - start_x

        # The square canvas takes the longer edge of the crop as its side.
        max_side_length = max(mask_height, mask_width)

        # Offsets that centre the crop inside the square canvas.
        crop_mask = target_mask[start_y:end_y, start_x:end_x]
        crop_mask_start_y = (max_side_length - mask_height) // 2
        crop_mask_end_y = crop_mask_start_y + mask_height
        crop_mask_start_x = (max_side_length - mask_width) // 2
        crop_mask_end_x = crop_mask_start_x + mask_width

        # Zero-padded square mask with the cropped mask pasted in the centre.
        square_mask = np.zeros(
            (max_side_length, max_side_length), dtype=target_mask.dtype
        )
        square_mask[
            crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x
        ] = crop_mask

        # Scale in float64 and clip: multiplying a uint8 mask that is already
        # 0/255 by 255 would wrap around instead of saturating.
        mask_pixels = np.clip(square_mask.astype(np.float64) * 255, 0, 255)
        square_mask_image = Image.fromarray(mask_pixels.astype(np.uint8))

        # Same layout for the image: crop, then paste onto a new RGB canvas.
        crop_image = input_image.crop((start_x, start_y, end_x, end_y))
        square_image = Image.new("RGB", (max_side_length, max_side_length))
        square_image.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))

        self.origin_start_x = start_x
        self.origin_start_y = start_y
        self.origin_end_x = end_x
        self.origin_end_y = end_y

        self.square_start_x = crop_mask_start_x
        self.square_start_y = crop_mask_start_y
        self.square_end_x = crop_mask_end_x
        self.square_end_y = crop_mask_end_y

        self.square_length = max_side_length
        self.square_mask_image = square_mask_image
        self.square_image = square_image
        self.resized_square_mask_image = square_mask_image.resize(
            (crop_length, crop_length)
        )
        self.resized_square_image = square_image.resize((crop_length, crop_length))

        return self.square_image, self.resized_square_image