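# NOTE: the next three lines are web-page header residue from the file viewer,
# kept as comments so the module parses.
# Vishalpainjane's picture
# added files
# 8f5f46d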
import os
from PIL import Image
from tqdm import tqdm
import glob
# --- 1. Centralized Configuration ---
class CFG:
    """Central configuration: source/destination roots and image settings."""

    # --- Paths ---
    SOURCE_ROOT = "SEN-2 LULC"  # root directory of the raw dataset
    DEST_ROOT = "SEN-2_LULC_preprocessed1"  # root directory for resized output

    # --- Image & File Settings ---
    IMG_SIZE = 256  # output images/masks are IMG_SIZE x IMG_SIZE pixels
    IMG_EXT = ".png"  # Assumes source images are PNGs
    MASK_EXT = ".tif"  # Assumes source masks are TIFs
# --- 2. Define All Data Splits ---
# Maps each split name to its source directories. Every split stores its
# images under "<split>_images/<split>"; splits that ship with masks store
# them under "<split>_masks/<split>". A split without masks (the test set
# usually only has images) gets "masks": None.
SPLITS = {
    split: {
        "images": os.path.join(CFG.SOURCE_ROOT, f"{split}_images", split),
        "masks": (
            os.path.join(CFG.SOURCE_ROOT, f"{split}_masks", split)
            if has_masks
            else None
        ),
    }
    for split, has_masks in (("train", True), ("val", True), ("test", False))
}
def _resize_and_save(src_path, dest_path, resample):
    """Resize the image at src_path to CFG.IMG_SIZE square and save it to dest_path."""
    with Image.open(src_path) as src:
        resized = src.resize((CFG.IMG_SIZE, CFG.IMG_SIZE), resample=resample)
        resized.save(dest_path)


def preprocess_and_resize():
    """
    Resize images and their corresponding masks for every split in SPLITS.

    For each split, every file in the split's image directory ending in
    CFG.IMG_EXT is resized to CFG.IMG_SIZE x CFG.IMG_SIZE (LANCZOS) and saved
    under "<CFG.DEST_ROOT>/<split>_images". When the split has a mask
    directory, the mask with the same base name and CFG.MASK_EXT is resized
    with NEAREST resampling and saved under "<CFG.DEST_ROOT>/<split>_masks".

    Processing is best-effort: a file that fails is reported and counted, and
    the run continues. An image whose mask is missing is still written, but
    the gap is reported and counted so unpaired outputs do not go unnoticed.

    Returns:
        dict: maps each processed split name to its counters
        ("images", "masks", "missing_masks", "errors"). Splits with no source
        images are skipped and do not appear in the result.
    """
    print(f"Starting preprocessing. Output will be saved to: {CFG.DEST_ROOT}\n")
    stats = {}

    # --- 3. Iterate Through Each Split (train, val, test) ---
    for split_name, split_paths in SPLITS.items():
        print(f"--- Processing split: {split_name} ---")

        # Get paths and create destination directories
        image_dir = split_paths["images"]
        mask_dir = split_paths["masks"]

        dest_img_dir = os.path.join(CFG.DEST_ROOT, f"{split_name}_images")
        os.makedirs(dest_img_dir, exist_ok=True)

        dest_mask_dir = None
        if mask_dir:
            dest_mask_dir = os.path.join(CFG.DEST_ROOT, f"{split_name}_masks")
            os.makedirs(dest_mask_dir, exist_ok=True)

        # Find all source images. The directory is escaped so glob
        # metacharacters in the path are taken literally, and the result is
        # sorted so the processing order is reproducible across runs.
        pattern = os.path.join(glob.escape(image_dir), f"*{CFG.IMG_EXT}")
        image_files = sorted(glob.glob(pattern))
        if not image_files:
            print(f"Warning: No images found in {image_dir}. Skipping.")
            continue

        counts = {"images": 0, "masks": 0, "missing_masks": 0, "errors": 0}
        stats[split_name] = counts

        # --- 4. Process Each Image and Find its Mask by Filename ---
        for img_path in tqdm(image_files, desc=f"Resizing {split_name} data"):
            base_name = os.path.basename(img_path)
            try:
                # --- Process the image ---
                _resize_and_save(
                    img_path,
                    os.path.join(dest_img_dir, base_name),
                    Image.Resampling.LANCZOS,
                )
                counts["images"] += 1

                # --- Find and process the corresponding mask (if applicable) ---
                if not mask_dir:
                    continue

                mask_name = os.path.splitext(base_name)[0] + CFG.MASK_EXT
                mask_path = os.path.join(mask_dir, mask_name)
                if not os.path.exists(mask_path):
                    counts["missing_masks"] += 1
                    # tqdm.write keeps the message from garbling the progress bar.
                    tqdm.write(
                        f"Warning: No mask found for {img_path} (expected {mask_path})"
                    )
                    continue

                # CRITICAL: Use NEAREST resampling for masks to preserve class labels
                _resize_and_save(
                    mask_path,
                    os.path.join(dest_mask_dir, mask_name),
                    Image.Resampling.NEAREST,
                )
                counts["masks"] += 1
            except Exception as e:
                # Per-file boundary: report, count, and carry on with the next file.
                counts["errors"] += 1
                tqdm.write(f"Error processing {img_path}: {e}")

    # --- 5. Final Verification Summary ---
    print("\n--- Preprocessing Complete! ---")
    for split_name, counts in stats.items():
        print(f"Split '{split_name}': Processed {counts['images']} images and {counts['masks']} masks.")
        if counts["missing_masks"] or counts["errors"]:
            print(
                f"  {counts['missing_masks']} image(s) had no matching mask; "
                f"{counts['errors']} file(s) failed."
            )
    print(f"Resized data is in '{CFG.DEST_ROOT}'")
    return stats
# Run the preprocessing only when executed as a script, not on import.
if __name__ == "__main__":
    preprocess_and_resize()