"""Resize the images and masks of a dataset to a fixed square size.

For every split listed in ``SPLITS`` the script reads the source images,
resizes them to ``CFG.IMG_SIZE`` x ``CFG.IMG_SIZE`` and writes them under
``CFG.DEST_ROOT``.  When a split has a mask directory, the mask sharing the
image's file stem is resized and written as well.
"""

import glob
import os

from PIL import Image
from tqdm import tqdm


# --- 1. Centralized Configuration ---
class CFG:
    """Paths and file settings shared by the whole script."""

    # --- Paths ---
    SOURCE_ROOT = "SEN-2 LULC"
    DEST_ROOT = "SEN-2_LULC_preprocessed1"

    # --- Image & File Settings ---
    IMG_SIZE = 256  # Output width and height, in pixels
    IMG_EXT = ".png"  # Assumes source images are PNGs
    MASK_EXT = ".tif"  # Assumes source masks are TIFs


# --- 2. Define All Data Splits ---
# Maps split name -> source directories. "masks" is None for a split that
# has no masks.
SPLITS = {
    "train": {
        "images": os.path.join(CFG.SOURCE_ROOT, "train_images", "train"),
        "masks": os.path.join(CFG.SOURCE_ROOT, "train_masks", "train"),
    },
    "val": {
        "images": os.path.join(CFG.SOURCE_ROOT, "val_images", "val"),
        "masks": os.path.join(CFG.SOURCE_ROOT, "val_masks", "val"),
    },
    "test": {
        # Test set usually only has images
        "images": os.path.join(CFG.SOURCE_ROOT, "test_images", "test"),
        "masks": None,  # Set to None if no masks exist
    },
}


def _resize_and_save(src_path, dest_path, resample):
    """Open *src_path*, resize it to the configured square size, save to *dest_path*."""
    with Image.open(src_path) as src:
        resized = src.resize((CFG.IMG_SIZE, CFG.IMG_SIZE), resample=resample)
        resized.save(dest_path)


def preprocess_and_resize():
    """Resize images and their corresponding masks for all data splits.

    Each image matching ``*{CFG.IMG_EXT}`` in a split's image directory is
    resized with LANCZOS resampling and saved under
    ``{CFG.DEST_ROOT}/{split}_images``.  If the split defines a mask
    directory, the mask named ``<image stem> + CFG.MASK_EXT`` is resized with
    NEAREST resampling and saved under ``{CFG.DEST_ROOT}/{split}_masks``.

    A file that fails to process is reported and skipped; images without a
    matching mask file are counted and reported once per split.

    Returns:
        dict: ``{split_name: {"images": int, "masks": int,
        "missing_masks": int}}`` for every split in which at least one
        source image was found.
    """
    print(f"Starting preprocessing. Output will be saved to: {CFG.DEST_ROOT}\n")

    stats = {}

    # --- 3. Iterate Through Each Split (train, val, test) ---
    for split_name, split_paths in SPLITS.items():
        print(f"--- Processing split: {split_name} ---")

        # Get paths and create destination directories
        image_dir = split_paths["images"]
        mask_dir = split_paths["masks"]

        dest_img_dir = os.path.join(CFG.DEST_ROOT, f"{split_name}_images")
        os.makedirs(dest_img_dir, exist_ok=True)

        dest_mask_dir = None
        if mask_dir:
            dest_mask_dir = os.path.join(CFG.DEST_ROOT, f"{split_name}_masks")
            os.makedirs(dest_mask_dir, exist_ok=True)

        # Find all source images. glob returns files in arbitrary order, so
        # sort to make the processing order reproducible between runs.
        image_files = sorted(
            glob.glob(os.path.join(image_dir, f"*{CFG.IMG_EXT}"))
        )
        if not image_files:
            print(f"Warning: No images found in {image_dir}. Skipping.")
            continue

        counts = {"images": 0, "masks": 0, "missing_masks": 0}
        stats[split_name] = counts

        # --- 4. Process Each Image and Find its Mask ---
        for img_path in tqdm(image_files, desc=f"Resizing {split_name} data"):
            base_name = os.path.basename(img_path)
            # Best-effort per file: one unreadable image or mask must not
            # abort the whole run, so any failure is reported and skipped.
            try:
                # --- Process the image ---
                _resize_and_save(
                    img_path,
                    os.path.join(dest_img_dir, base_name),
                    Image.Resampling.LANCZOS,
                )
                counts["images"] += 1

                # --- Find and process the corresponding mask (if applicable) ---
                if not mask_dir:
                    continue

                mask_name = os.path.splitext(base_name)[0] + CFG.MASK_EXT
                mask_path = os.path.join(mask_dir, mask_name)
                if not os.path.exists(mask_path):
                    # Counted here and reported after the loop so the
                    # progress bar is not interleaved with warnings.
                    counts["missing_masks"] += 1
                    continue

                # CRITICAL: Use NEAREST resampling for masks to preserve class labels
                _resize_and_save(
                    mask_path,
                    os.path.join(dest_mask_dir, mask_name),
                    Image.Resampling.NEAREST,
                )
                counts["masks"] += 1
            except Exception as e:
                print(f"Error processing {img_path}: {e}")

        if counts["missing_masks"]:
            print(
                f"Warning: {counts['missing_masks']} image(s) in split "
                f"'{split_name}' had no matching mask in {mask_dir}."
            )

    # --- 5. Final Verification Summary ---
    print("\n--- Preprocessing Complete! ---")
    for split_name, counts in stats.items():
        print(f"Split '{split_name}': Processed {counts['images']} images and {counts['masks']} masks.")
    print(f"Resized data is in '{CFG.DEST_ROOT}'")

    return stats


if __name__ == "__main__":
    preprocess_and_resize()