import os
from PIL import Image
from tqdm import tqdm
import glob

# --- 1. Centralized Configuration ---
class CFG:
    # --- Paths ---
    SOURCE_ROOT = "SEN-2 LULC"
    DEST_ROOT = "SEN-2_LULC_preprocessed1"
    
    # --- Image & File Settings ---
    IMG_SIZE = 256
    IMG_EXT = ".png"  # Assumes source images are PNGs
    MASK_EXT = ".tif" # Assumes source masks are TIFs

# --- 2. Define All Data Splits ---
SPLITS = {
    "train": {
        "images": os.path.join(CFG.SOURCE_ROOT, "train_images", "train"),
        "masks": os.path.join(CFG.SOURCE_ROOT, "train_masks", "train"),
    },
    "val": {
        "images": os.path.join(CFG.SOURCE_ROOT, "val_images", "val"),
        "masks": os.path.join(CFG.SOURCE_ROOT, "val_masks", "val"),
    },
    "test": {
        # Test set usually only has images
        "images": os.path.join(CFG.SOURCE_ROOT, "test_images", "test"),
        "masks": None, # Set to None if no masks exist
    },
}

def preprocess_and_resize():
    """
    Resizes images and their corresponding masks for all data splits.
    This version robustly matches images to masks by filename.
    """
    print(f"Starting preprocessing. Output will be saved to: {CFG.DEST_ROOT}\n")
    stats = {}

    # --- 3. Iterate Through Each Split (train, val, test) ---
    for split_name, split_paths in SPLITS.items():
        print(f"--- Processing split: {split_name} ---")
        
        # Get paths and create destination directories
        image_dir = split_paths["images"]
        mask_dir = split_paths["masks"]
        
        dest_img_dir = os.path.join(CFG.DEST_ROOT, f"{split_name}_images")
        os.makedirs(dest_img_dir, exist_ok=True)
        
        if mask_dir:
            dest_mask_dir = os.path.join(CFG.DEST_ROOT, f"{split_name}_masks")
            os.makedirs(dest_mask_dir, exist_ok=True)

        # Find all source images
        image_files = glob.glob(os.path.join(image_dir, f"*{CFG.IMG_EXT}"))
        
        if not image_files:
            print(f"Warning: No images found in {image_dir}. Skipping.")
            continue
            
        stats[split_name] = {"images": 0, "masks": 0}

        # --- 4. Robustly Process Each Image and Find its Mask ---
        for img_path in tqdm(image_files, desc=f"Resizing {split_name} data"):
            try:
                # --- Process the image ---
                base_name = os.path.basename(img_path)
                with Image.open(img_path) as img:
                    resized_img = img.resize((CFG.IMG_SIZE, CFG.IMG_SIZE), resample=Image.Resampling.LANCZOS)
                    resized_img.save(os.path.join(dest_img_dir, base_name))
                    stats[split_name]["images"] += 1

                # --- Find and process the corresponding mask (if applicable) ---
                if mask_dir:
                    mask_name = os.path.splitext(base_name)[0] + CFG.MASK_EXT
                    mask_path = os.path.join(mask_dir, mask_name)
                    
                    if os.path.exists(mask_path):
                        with Image.open(mask_path) as mask:
                            # CRITICAL: Use NEAREST resampling for masks to preserve class labels
                            # (see _demo_nearest_vs_lanczos below for a quick illustration)
                            resized_mask = mask.resize((CFG.IMG_SIZE, CFG.IMG_SIZE), resample=Image.Resampling.NEAREST)
                            resized_mask.save(os.path.join(dest_mask_dir, mask_name))
                            stats[split_name]["masks"] += 1

            except Exception as e:
                print(f"Error processing {img_path}: {e}")
    
    # --- 5. Final Verification Summary ---
    print("\n--- Preprocessing Complete! ---")
    for split_name, counts in stats.items():
        print(f"Split '{split_name}': Processed {counts['images']} images and {counts['masks']} masks.")
    print(f"Resized data is in '{CFG.DEST_ROOT}'")


if __name__ == "__main__":
    preprocess_and_resize()