# Spaces:
# Runtime error
# Runtime error
import glob
import os

from PIL import Image
from tqdm import tqdm
# --- 1. Centralized Configuration ---
class CFG:
    """Static settings for the resize/preprocess script.

    All values are plain class attributes and are read directly off the
    class, e.g. ``CFG.IMG_SIZE``.
    """

    # Filesystem roots: raw data is read from SOURCE_ROOT and resized copies
    # are written under DEST_ROOT.
    SOURCE_ROOT = "SEN-2 LULC"
    DEST_ROOT = "SEN-2_LULC_preprocessed1"

    # Edge length in pixels of the square output (IMG_SIZE x IMG_SIZE).
    IMG_SIZE = 256

    # Extensions used to discover source images and to derive mask filenames.
    IMG_EXT = ".png"   # source images are assumed to be PNGs
    MASK_EXT = ".tif"  # source masks are assumed to be TIFs
# --- 2. Define All Data Splits ---
# Source directories for every split, keyed by split name. Each entry gives
# the folder holding the split's input images ("images") and the folder
# holding the matching masks ("masks"). A "masks" value of None marks a split
# without masks; mask handling is then skipped for that split.
SPLITS = {
    "train": {
        "images": os.path.join(CFG.SOURCE_ROOT, "train_images", "train"),
        "masks": os.path.join(CFG.SOURCE_ROOT, "train_masks", "train"),
    },
    "val": {
        "images": os.path.join(CFG.SOURCE_ROOT, "val_images", "val"),
        "masks": os.path.join(CFG.SOURCE_ROOT, "val_masks", "val"),
    },
    "test": {
        # The test split usually ships images only, so there is no mask folder.
        "images": os.path.join(CFG.SOURCE_ROOT, "test_images", "test"),
        "masks": None,
    },
}
def _resize_and_save(src_path, dst_path, resample):
    """Resize the image at ``src_path`` to CFG.IMG_SIZE x CFG.IMG_SIZE and save it.

    ``resample`` is a Pillow resampling filter (LANCZOS for images, NEAREST
    for masks). The output format is inferred by Pillow from ``dst_path``.
    """
    with Image.open(src_path) as src:
        resized = src.resize((CFG.IMG_SIZE, CFG.IMG_SIZE), resample=resample)
    resized.save(dst_path)


def _process_split(split_name, split_paths):
    """Resize every image of one split, plus its matching mask if the split has masks.

    Returns a dict with "images", "masks" and "missing_masks" counts, or
    None when no source images were found (the split is skipped).
    """
    print(f"--- Processing split: {split_name} ---")

    image_dir = split_paths["images"]
    mask_dir = split_paths["masks"]

    dest_img_dir = os.path.join(CFG.DEST_ROOT, f"{split_name}_images")
    os.makedirs(dest_img_dir, exist_ok=True)
    dest_mask_dir = None
    if mask_dir:
        dest_mask_dir = os.path.join(CFG.DEST_ROOT, f"{split_name}_masks")
        os.makedirs(dest_mask_dir, exist_ok=True)
        if not os.path.isdir(mask_dir):
            # Otherwise every image would silently end up without a mask.
            print(f"Warning: Mask directory {mask_dir} does not exist; no masks will be found.")

    # glob() returns files in arbitrary, filesystem-dependent order; sort so
    # runs are reproducible and progress/errors are reported in a stable order.
    image_files = sorted(glob.glob(os.path.join(image_dir, f"*{CFG.IMG_EXT}")))
    if not image_files:
        print(f"Warning: No images found in {image_dir}. Skipping.")
        return None

    counts = {"images": 0, "masks": 0, "missing_masks": 0}
    for img_path in tqdm(image_files, desc=f"Resizing {split_name} data"):
        base_name = os.path.basename(img_path)

        # Per-file best effort: a corrupt or unreadable file is reported and
        # skipped instead of aborting the whole split. tqdm.write is used so
        # the message does not break the progress bar.
        try:
            _resize_and_save(
                img_path,
                os.path.join(dest_img_dir, base_name),
                Image.Resampling.LANCZOS,
            )
        except Exception as e:
            tqdm.write(f"Error processing {img_path}: {e}")
            continue
        counts["images"] += 1

        if not mask_dir:
            continue

        # Masks are matched by filename stem: "<stem>CFG.IMG_EXT" -> "<stem>CFG.MASK_EXT".
        mask_name = os.path.splitext(base_name)[0] + CFG.MASK_EXT
        mask_path = os.path.join(mask_dir, mask_name)
        if not os.path.exists(mask_path):
            counts["missing_masks"] += 1
            continue

        try:
            # CRITICAL: NEAREST resampling preserves class labels; interpolating
            # filters would blend neighbouring labels into invalid values.
            _resize_and_save(
                mask_path,
                os.path.join(dest_mask_dir, mask_name),
                Image.Resampling.NEAREST,
            )
        except Exception as e:
            tqdm.write(f"Error processing {mask_path}: {e}")
            continue
        counts["masks"] += 1

    return counts


def preprocess_and_resize(splits=None):
    """Resize images and their corresponding masks for all data splits.

    Images are matched to masks by filename stem. Images are resampled with
    LANCZOS and masks with NEAREST so label values are preserved. Results are
    written to ``CFG.DEST_ROOT/<split>_images`` and ``CFG.DEST_ROOT/<split>_masks``.

    Args:
        splits: Optional mapping of split name -> {"images": dir, "masks": dir or None}.
            Defaults to the module-level ``SPLITS``; pass a subset to process
            only some splits.
    """
    if splits is None:
        splits = SPLITS

    print(f"Starting preprocessing. Output will be saved to: {CFG.DEST_ROOT}\n")

    stats = {}
    # --- 3. Iterate Through Each Split (train, val, test) ---
    for split_name, split_paths in splits.items():
        counts = _process_split(split_name, split_paths)
        if counts is not None:
            stats[split_name] = counts

    # --- 5. Final Verification Summary ---
    print("\n--- Preprocessing Complete! ---")
    for split_name, counts in stats.items():
        print(f"Split '{split_name}': Processed {counts['images']} images and {counts['masks']} masks.")
        if counts["missing_masks"]:
            print(
                f"  Warning: {counts['missing_masks']} image(s) in '{split_name}' "
                f"had no matching mask."
            )
    print(f"Resized data is in '{CFG.DEST_ROOT}'")
if __name__ == "__main__":
    # Only run the preprocessing pass when executed as a script, not when the
    # module is imported.
    preprocess_and_resize()