import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import segmentation_models_pytorch as smp
import cv2


# --- 1. Configuration ---
class CFG:
    DATA_DIR = r"SEN-2_LULC_preprocessed"
    TRAIN_IMG_DIR = os.path.join(DATA_DIR, "train_images")
    TRAIN_MASK_DIR = os.path.join(DATA_DIR, "train_masks")
    VAL_IMG_DIR = os.path.join(DATA_DIR, "val_images")
    VAL_MASK_DIR = os.path.join(DATA_DIR, "val_masks")
    OUTPUT_DIR = "./outputs_rgb_optimized"
    # The path for the 'best' model, for inference later
    MODEL_SAVE_PATH = os.path.join(OUTPUT_DIR, "best_model_optimized.pth")
    # --- NEW: Path for the resumable checkpoint file ---
    CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "checkpoint.pth")
    PREDICTION_SAVE_PATH = os.path.join(OUTPUT_DIR, "predictions_optimized")
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    MODEL_NAME = "CustomDeepLabV3+"
    ENCODER_NAME = "timm-efficientnet-b2"
    LOSS_FN_NAME = "DiceFocal"
    IN_CHANNELS = 3
    NUM_CLASSES = 8
    IMG_SIZE = 256
    BATCH_SIZE = 4
    ACCUMULATION_STEPS = 4
    NUM_WORKERS = 8
    LEARNING_RATE = 1e-4
    EPOCHS = 50
    SEED = 42
    SUBSET_FRACTION = 0.75


# --- ARCHITECTURE and LOSS CLASSES (Unchanged) ---
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class CustomDeepLabV3Plus(nn.Module):
    def __init__(self, encoder_name, in_channels, classes):
        super().__init__()
        self.smp_model = smp.DeepLabV3Plus(
            encoder_name=encoder_name,
            encoder_weights="imagenet",
            in_channels=in_channels,
            classes=classes,
        )
        decoder_channels = self.smp_model.segmentation_head[0].in_channels
        self.se_layer = SELayer(decoder_channels)
        self.segmentation_head = self.smp_model.segmentation_head
        self.smp_model.segmentation_head = nn.Identity()

    def forward(self, x):
        decoder_features = self.smp_model(x)
        attended_features = self.se_layer(decoder_features)
        output = self.segmentation_head(attended_features)
        return output
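
# Illustrative shape check (not executed): with the configuration above, the wrapper maps a
# (1, 3, 256, 256) input to (1, NUM_CLASSES, 256, 256) logits, because the SE block sits
# between the decoder output and the original upsampling segmentation head.
#   m = CustomDeepLabV3Plus("timm-efficientnet-b2", in_channels=3, classes=8)
#   m(torch.randn(1, 3, 256, 256)).shape  # -> torch.Size([1, 8, 256, 256])
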

class CombinedLoss(nn.Module):
    def __init__(self, loss1, loss2, alpha=0.5):
        super(CombinedLoss, self).__init__()
        self.loss1 = loss1
        self.loss2 = loss2
        self.alpha = alpha
        self.name = f"{alpha}*{self.loss1.__class__.__name__} + {1-alpha}*{self.loss2.__class__.__name__}"

    def forward(self, prediction, target):
        loss1_val = self.loss1(prediction, target)
        loss2_val = self.loss2(prediction, target)
        return self.alpha * loss1_val + (1 - self.alpha) * loss2_val


# --- DATASET and TRANSFORMS (Unchanged) ---
class LULCDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, subset_fraction=1.0):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        all_images = sorted([f for f in os.listdir(image_dir) if f.endswith('.png')])
        all_masks = sorted([f for f in os.listdir(mask_dir) if f.endswith('.tif')])
        num_samples = int(len(all_images) * subset_fraction)
        self.images = all_images[:num_samples]
        self.masks = all_masks[:num_samples]
        assert len(self.images) == len(self.masks), "Mismatch between number of images and masks"
        print(f"Found {len(all_images)} total images, USING {len(self.images)} samples ({subset_fraction*100}%) from {image_dir}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])
        image = np.array(Image.open(img_path).convert("RGB"), dtype=np.float32)
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']
        return image, mask


def get_transforms(img_size):
    DATASET_MEAN = [0.485, 0.456, 0.406]
    DATASET_STD = [0.229, 0.224, 0.225]
    train_transform = A.Compose([
        A.Resize(img_size, img_size),
        A.Rotate(limit=35, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Normalize(mean=DATASET_MEAN, std=DATASET_STD),
        ToTensorV2(),
    ])
    val_transform = A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(mean=DATASET_MEAN, std=DATASET_STD),
        ToTensorV2(),
    ])
    return train_transform, val_transform


# --- GET MODEL AND LOSS (Unchanged) ---
def get_model():
    if CFG.MODEL_NAME == "CustomDeepLabV3+":
        model = CustomDeepLabV3Plus(encoder_name=CFG.ENCODER_NAME, in_channels=CFG.IN_CHANNELS, classes=CFG.NUM_CLASSES)
    else:
        model = smp.DeepLabV3Plus(encoder_name=CFG.ENCODER_NAME, encoder_weights="imagenet", in_channels=CFG.IN_CHANNELS, classes=CFG.NUM_CLASSES)
    return model.to(CFG.DEVICE)


def get_loss_fn():
    if CFG.LOSS_FN_NAME == "DiceFocal":
        dice = smp.losses.DiceLoss(mode='multiclass')
        focal = smp.losses.FocalLoss(mode='multiclass')
        return CombinedLoss(focal, dice, alpha=0.5)
    else:
        return smp.losses.DiceLoss(mode='multiclass')


# --- Training and Evaluation Functions (Unchanged) ---
def train_one_epoch(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader, desc="Training")
    model.train()
    optimizer.zero_grad()
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(CFG.DEVICE, non_blocking=True, memory_format=torch.channels_last)
        targets = targets.long().to(CFG.DEVICE, non_blocking=True)
        with torch.amp.autocast(device_type=CFG.DEVICE, dtype=torch.bfloat16, enabled=(CFG.DEVICE == "cuda")):
            predictions = model(data)
            loss = loss_fn(predictions, targets) / CFG.ACCUMULATION_STEPS
        scaler.scale(loss).backward()
        if (batch_idx + 1) % CFG.ACCUMULATION_STEPS == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        loop.set_postfix(loss=loss.item() * CFG.ACCUMULATION_STEPS)
    # Flush gradients left over when the number of batches is not a multiple of ACCUMULATION_STEPS.
    if len(loader) % CFG.ACCUMULATION_STEPS != 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()


def evaluate_model(loader, model, loss_fn):
    model.eval()
    intersection = torch.zeros(CFG.NUM_CLASSES, device=CFG.DEVICE)
    union = torch.zeros(CFG.NUM_CLASSES, device=CFG.DEVICE)
    pixel_correct, pixel_total, total_loss = 0, 0, 0
    with torch.no_grad():
        loop = tqdm(loader, desc="Evaluating")
        for x, y in loop:
            x = x.to(CFG.DEVICE, non_blocking=True, memory_format=torch.channels_last)
            y = y.to(CFG.DEVICE, non_blocking=True).long()
            with torch.amp.autocast(device_type=CFG.DEVICE, dtype=torch.bfloat16, enabled=(CFG.DEVICE == "cuda")):
                preds = model(x)
                loss = loss_fn(preds, y)
                total_loss += loss.item()
            pred_labels = torch.argmax(preds, dim=1)
            pixel_correct += (pred_labels == y).sum()
            pixel_total += torch.numel(y)
            for cls in range(CFG.NUM_CLASSES):
                pred_mask = (pred_labels == cls)
                true_mask = (y == cls)
                intersection[cls] += (pred_mask & true_mask).sum()
                union[cls] += (pred_mask | true_mask).sum()
    pixel_acc = (pixel_correct / pixel_total) * 100
    iou_per_class = (intersection + 1e-6) / (union + 1e-6)
    mean_iou = iou_per_class.mean()
    avg_loss = total_loss / len(loader)
    print(f"Validation Results -> Avg Loss: {avg_loss:.4f}, Pixel Acc: {pixel_acc:.2f}%, mIoU: {mean_iou:.4f}")
    for i, iou in enumerate(iou_per_class):
        print(f"  Class {i} IoU: {iou:.4f}")
    return mean_iou


def save_predictions_as_images(loader, model):
    # Visualization helper; not part of the training loop. The original snippet stubbed it out.
    pass
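    # Minimal sketch of the intended behaviour (an assumption, since the original body was
    # omitted): write each predicted mask to CFG.PREDICTION_SAVE_PATH as a single-channel PNG
    # of raw class indices (0..NUM_CLASSES-1); colorize separately for visualization if needed.
    os.makedirs(CFG.PREDICTION_SAVE_PATH, exist_ok=True)
    model.eval()
    with torch.no_grad():
        for batch_idx, (x, _) in enumerate(tqdm(loader, desc="Saving predictions")):
            x = x.to(CFG.DEVICE, non_blocking=True, memory_format=torch.channels_last)
            preds = torch.argmax(model(x), dim=1).cpu().numpy().astype(np.uint8)
            for i, pred in enumerate(preds):
                out_path = os.path.join(CFG.PREDICTION_SAVE_PATH, f"pred_{batch_idx * loader.batch_size + i}.png")
                cv2.imwrite(out_path, pred)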

# --- NEW: Helper function to save a checkpoint ---
def save_checkpoint(state, filename="checkpoint.pth"):
    print("=> Saving checkpoint")
    torch.save(state, filename)


def main():
    torch.manual_seed(CFG.SEED)
    np.random.seed(CFG.SEED)
    os.makedirs(CFG.OUTPUT_DIR, exist_ok=True)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    train_transform, val_transform = get_transforms(CFG.IMG_SIZE)
    train_ds = LULCDataset(CFG.TRAIN_IMG_DIR, CFG.TRAIN_MASK_DIR, transform=train_transform, subset_fraction=CFG.SUBSET_FRACTION)
    val_ds = LULCDataset(CFG.VAL_IMG_DIR, CFG.VAL_MASK_DIR, transform=val_transform, subset_fraction=CFG.SUBSET_FRACTION)
    train_loader = DataLoader(train_ds, batch_size=CFG.BATCH_SIZE, num_workers=CFG.NUM_WORKERS, pin_memory=True, shuffle=True, persistent_workers=True)
    val_loader = DataLoader(val_ds, batch_size=CFG.BATCH_SIZE, num_workers=CFG.NUM_WORKERS, pin_memory=True, shuffle=False, persistent_workers=True)

    model = get_model()
    model = model.to(memory_format=torch.channels_last)
    loss_fn = get_loss_fn()
    optimizer = optim.AdamW(model.parameters(), lr=CFG.LEARNING_RATE)
    # With bfloat16 autocast the GradScaler is effectively a no-op (bf16 has the same exponent
    # range as fp32), but it is kept so the checkpoint format below stays consistent.
    scaler = torch.amp.GradScaler(enabled=(CFG.DEVICE == "cuda"))
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.EPOCHS, eta_min=1e-6)

    # --- NEW: Logic to load checkpoint and resume training ---
    start_epoch = 0
    best_val_miou = -1.0
    if os.path.exists(CFG.CHECKPOINT_PATH):
        print(f"=> Loading checkpoint '{CFG.CHECKPOINT_PATH}'")
        checkpoint = torch.load(CFG.CHECKPOINT_PATH, map_location=CFG.DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_miou = checkpoint['best_val_miou']
        print(f"=> Resuming training from epoch {start_epoch}")
    else:
        print("=> No checkpoint found, starting new training session.")

    # --- MODIFIED: Main training loop now starts from the correct epoch ---
    for epoch in range(start_epoch, CFG.EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{CFG.EPOCHS} ---")
        train_one_epoch(train_loader, model, optimizer, loss_fn, scaler)
        current_miou = evaluate_model(val_loader, model, loss_fn)
        scheduler.step()

        # Create the checkpoint dictionary with the complete state
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_miou': best_val_miou,
        }
        if current_miou > best_val_miou:
            best_val_miou = current_miou
            checkpoint['best_val_miou'] = best_val_miou  # Update best score in checkpoint
            print(f"🎉 New best mIoU: {best_val_miou:.4f}! Saving best model to {CFG.MODEL_SAVE_PATH}")
            torch.save(model.state_dict(), CFG.MODEL_SAVE_PATH)  # Save just the model for easy inference

        # Save the full state checkpoint after every epoch
        save_checkpoint(checkpoint, filename=CFG.CHECKPOINT_PATH)

    print("\n--- Training Complete. Saving final predictions. ---")
    # Load the best performing model for final predictions
    model.load_state_dict(torch.load(CFG.MODEL_SAVE_PATH, map_location=CFG.DEVICE))
    # Note: You may want a separate test_loader for a final unbiased evaluation
    save_predictions_as_images(val_loader, model)


if __name__ == "__main__":
    main()
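
# --- Example usage (illustrative, not executed): loading the best model for inference ---
# A minimal sketch; `image_batch` is a placeholder for a tensor of RGB tiles shaped
# (N, 3, IMG_SIZE, IMG_SIZE), preprocessed with the same val_transform used above.
#   model = get_model()
#   model.load_state_dict(torch.load(CFG.MODEL_SAVE_PATH, map_location=CFG.DEVICE))
#   model.eval()
#   with torch.no_grad():
#       logits = model(image_batch.to(CFG.DEVICE))
#       class_map = torch.argmax(logits, dim=1)  # (N, IMG_SIZE, IMG_SIZE) class indices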