Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.utils.data import Dataset, DataLoader | |
from PIL import Image | |
import numpy as np | |
import albumentations as A | |
from albumentations.pytorch import ToTensorV2 | |
from tqdm import tqdm | |
import segmentation_models_pytorch as smp | |
import cv2 | |
# --- 1. Configuration ---
class CFG:
    """Central, class-level configuration for the segmentation training script.

    All values are plain class attributes; nothing is instantiated.
    """

    # --- Dataset locations ---
    DATA_DIR = r"SEN-2_LULC_preprocessed"
    TRAIN_IMG_DIR = os.path.join(DATA_DIR, "train_images")
    TRAIN_MASK_DIR = os.path.join(DATA_DIR, "train_masks")
    VAL_IMG_DIR = os.path.join(DATA_DIR, "val_images")
    VAL_MASK_DIR = os.path.join(DATA_DIR, "val_masks")

    # --- Output locations ---
    OUTPUT_DIR = "./outputs_rgb_optimized"
    # Weights of the best-scoring model, kept separately for later inference.
    MODEL_SAVE_PATH = os.path.join(OUTPUT_DIR, "best_model_optimized.pth")
    # Full training state used to resume an interrupted run.
    CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "checkpoint.pth")
    PREDICTION_SAVE_PATH = os.path.join(OUTPUT_DIR, "predictions_optimized")

    # --- Hardware ---
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Model and loss selection ---
    MODEL_NAME = "CustomDeepLabV3+"
    ENCODER_NAME = "timm-efficientnet-b2"
    LOSS_FN_NAME = "DiceFocal"

    # --- Data shape ---
    IN_CHANNELS = 3
    NUM_CLASSES = 8
    IMG_SIZE = 256

    # --- Optimisation schedule ---
    BATCH_SIZE = 4
    # Number of batches whose gradients are summed before one optimizer step.
    ACCUMULATION_STEPS = 4
    NUM_WORKERS = 8
    LEARNING_RATE = 1e-4
    EPOCHS = 50

    # --- Reproducibility and data subsetting ---
    SEED = 42
    # Fraction of the (sorted) file list that each dataset keeps.
    SUBSET_FRACTION = 0.75
# --- ARCHITECTURE and LOSS CLASSES (Unchanged) ---
class SELayer(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Each channel of a 4-D feature map is averaged to a single value, passed
    through a two-layer bottleneck MLP ending in a sigmoid, and the resulting
    per-channel weight is multiplied back onto the input.

    Args:
        channel: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``channel`` to size the bottleneck.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        bottleneck = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` with every channel scaled by its learned gate."""
        # Four-way unpacking keeps the requirement that the input is 4-D.
        batch, channels, _height, _width = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)
class CustomDeepLabV3Plus(nn.Module):
    """DeepLabV3+ with an ``SELayer`` inserted in front of the segmentation head.

    The wrapped ``smp.DeepLabV3Plus`` has its own head replaced by
    ``nn.Identity`` so that calling it yields the decoder features; those
    features are gated by the SE layer and only then fed to the original head.

    Args:
        encoder_name: Encoder identifier forwarded to ``smp.DeepLabV3Plus``.
        in_channels: Number of input image channels.
        classes: Number of output classes.
    """

    def __init__(self, encoder_name, in_channels, classes):
        super().__init__()
        self.smp_model = smp.DeepLabV3Plus(
            encoder_name=encoder_name,
            encoder_weights="imagenet",
            in_channels=in_channels,
            classes=classes,
        )
        stock_head = self.smp_model.segmentation_head
        # The head's first layer tells us how many channels the decoder emits.
        self.se_layer = SELayer(stock_head[0].in_channels)
        # Re-home the head on this module, then neutralise it inside smp_model.
        self.segmentation_head = stock_head
        self.smp_model.segmentation_head = nn.Identity()

    def forward(self, x):
        """Run backbone+decoder, apply channel attention, then the head."""
        features = self.smp_model(x)
        gated = self.se_layer(features)
        return self.segmentation_head(gated)
class CombinedLoss(nn.Module):
    """Convex combination of two loss modules.

    The result is ``alpha * loss1 + (1 - alpha) * loss2``.

    Args:
        loss1: First loss callable, weighted by ``alpha``.
        loss2: Second loss callable, weighted by ``1 - alpha``.
        alpha: Weight of ``loss1``.
    """

    def __init__(self, loss1, loss2, alpha=0.5):
        super().__init__()
        self.loss1 = loss1
        self.loss2 = loss2
        self.alpha = alpha
        # Human-readable description of the weighting.
        self.name = f"{alpha}*{self.loss1.__class__.__name__} + {1-alpha}*{self.loss2.__class__.__name__}"

    def forward(self, prediction, target):
        """Evaluate both losses on the same inputs and blend them."""
        weighted_first = self.alpha * self.loss1(prediction, target)
        weighted_second = (1 - self.alpha) * self.loss2(prediction, target)
        return weighted_first + weighted_second
# --- DATASET and TRANSFORMS (Unchanged) ---
class LULCDataset(Dataset):
    """Paired image/mask dataset read from two directories.

    ``.png`` files in ``image_dir`` and ``.tif`` files in ``mask_dir`` are each
    sorted by filename and paired by position, so the two directories must
    contain the same number of files in matching order.

    Args:
        image_dir: Directory holding the ``.png`` input images.
        mask_dir: Directory holding the ``.tif`` label masks.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
        subset_fraction: Fraction of the sorted file list to keep (taken from
            the front of the list).

    Raises:
        ValueError: If the two directories hold different numbers of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None, subset_fraction=1.0):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

        all_images = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))
        all_masks = sorted(f for f in os.listdir(mask_dir) if f.endswith('.tif'))

        # Compare the FULL listings. Checking after truncation (as before) can
        # never detect surplus masks, and misses absent masks beyond the
        # subset, leaving position-based pairs silently misaligned.
        if len(all_images) != len(all_masks):
            raise ValueError(
                f"Mismatch: {len(all_images)} images in {image_dir} "
                f"but {len(all_masks)} masks in {mask_dir}"
            )

        num_samples = int(len(all_images) * subset_fraction)
        self.images = all_images[:num_samples]
        self.masks = all_masks[:num_samples]
        print(f"Found {len(all_images)} total images, USING {len(self.images)} samples ({subset_fraction*100}%) from {image_dir}")

    def __len__(self):
        """Return the number of image/mask pairs kept after subsetting."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load pair ``idx`` as float32 arrays and apply the optional transform."""
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Context managers close the underlying files promptly instead of
        # leaving that to garbage collection inside DataLoader workers.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"), dtype=np.float32)
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']
        return image, mask
def get_transforms(img_size):
    """Build the training and validation augmentation pipelines.

    Both pipelines resize to ``img_size`` x ``img_size``, normalise with fixed
    per-channel statistics and convert to tensors; the training pipeline
    additionally applies a random rotation and random flips.

    Args:
        img_size: Target height and width in pixels.

    Returns:
        Tuple ``(train_transform, val_transform)`` of ``A.Compose`` objects.
    """
    # Per-channel statistics (these values match the usual ImageNet ones).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    def _resize():
        return A.Resize(img_size, img_size)

    def _normalise_and_tensorise():
        # Fresh instances per pipeline so the two Compose objects share nothing.
        return [A.Normalize(mean=channel_mean, std=channel_std), ToTensorV2()]

    random_augmentations = [
        A.Rotate(limit=35, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ]

    train_transform = A.Compose(
        [_resize(), *random_augmentations, *_normalise_and_tensorise()]
    )
    val_transform = A.Compose([_resize(), *_normalise_and_tensorise()])
    return train_transform, val_transform
# --- GET MODEL AND LOSS (Unchanged) ---
def get_model():
    """Instantiate the network selected by ``CFG.MODEL_NAME`` on ``CFG.DEVICE``.

    ``"CustomDeepLabV3+"`` selects ``CustomDeepLabV3Plus``; any other value
    falls back to a stock ``smp.DeepLabV3Plus`` with ImageNet encoder weights.

    Returns:
        The constructed model, already moved to ``CFG.DEVICE``.
    """
    shared_kwargs = dict(
        encoder_name=CFG.ENCODER_NAME,
        in_channels=CFG.IN_CHANNELS,
        classes=CFG.NUM_CLASSES,
    )
    if CFG.MODEL_NAME == "CustomDeepLabV3+":
        network = CustomDeepLabV3Plus(**shared_kwargs)
    else:
        network = smp.DeepLabV3Plus(encoder_weights="imagenet", **shared_kwargs)
    return network.to(CFG.DEVICE)
def get_loss_fn():
    """Return the loss selected by ``CFG.LOSS_FN_NAME``.

    ``"DiceFocal"`` yields an equal-weight ``CombinedLoss`` of focal and Dice
    loss; any other value yields plain multiclass Dice loss.
    """
    dice = smp.losses.DiceLoss(mode='multiclass')
    if CFG.LOSS_FN_NAME != "DiceFocal":
        return dice
    focal = smp.losses.FocalLoss(mode='multiclass')
    return CombinedLoss(focal, dice, alpha=0.5)
# --- Training and Evaluation Functions (Unchanged) ---
def train_one_epoch(loader, model, optimizer, loss_fn, scaler):
    """Train ``model`` for one pass over ``loader`` with gradient accumulation.

    Gradients of ``CFG.ACCUMULATION_STEPS`` consecutive batches are summed
    (each batch loss is pre-divided by that count) before the optimizer is
    stepped through ``scaler``.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches.
        model: Network to train; switched to training mode here.
        optimizer: Optimizer updating ``model``'s parameters.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar.
        scaler: Gradient scaler used for ``scale``/``step``/``update``.
    """

    def _apply_accumulated_gradients():
        # One optimizer update from whatever gradients are currently stored.
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    loop = tqdm(loader, desc="Training")
    model.train()
    optimizer.zero_grad()

    has_unapplied_grads = False
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(CFG.DEVICE, non_blocking=True, memory_format=torch.channels_last)
        targets = targets.long().to(CFG.DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=CFG.DEVICE, dtype=torch.bfloat16, enabled=(CFG.DEVICE == "cuda")):
            predictions = model(data)
            loss = loss_fn(predictions, targets) / CFG.ACCUMULATION_STEPS

        scaler.scale(loss).backward()
        has_unapplied_grads = True

        if (batch_idx + 1) % CFG.ACCUMULATION_STEPS == 0:
            _apply_accumulated_gradients()
            has_unapplied_grads = False

        # Undo the accumulation division so the bar shows the per-batch loss.
        loop.set_postfix(loss=loss.item() * CFG.ACCUMULATION_STEPS)

    # When the batch count is not a multiple of ACCUMULATION_STEPS the tail
    # batches were previously back-propagated but never stepped, and their
    # gradients were wiped by the next epoch's zero_grad(). Apply them now.
    # (The tail group is still divided by the full ACCUMULATION_STEPS, so its
    # gradient magnitude is somewhat smaller than that of a complete group.)
    if has_unapplied_grads:
        _apply_accumulated_gradients()
def evaluate_model(loader, model, loss_fn):
    """Evaluate ``model`` on ``loader`` and print loss, pixel accuracy and IoU.

    Intersection and union pixel counts are accumulated per class over the
    whole loader. The mean IoU is taken only over classes that occur in the
    predictions or the ground truth at least once; a class that never occurs
    has an undefined IoU and previously counted as a perfect 1.0
    (``1e-6 / 1e-6``), which inflated the reported mIoU.

    Args:
        loader: Iterable yielding ``(x, y)`` batches.
        model: Network to evaluate; switched to eval mode here.
        loss_fn: Callable ``loss_fn(preds, y)`` returning a scalar tensor.

    Returns:
        Mean IoU as a 0-dim tensor.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    intersection = torch.zeros(CFG.NUM_CLASSES, device=CFG.DEVICE)
    union = torch.zeros(CFG.NUM_CLASSES, device=CFG.DEVICE)
    pixel_correct, pixel_total = 0, 0
    total_loss, num_batches = 0.0, 0

    with torch.no_grad():
        loop = tqdm(loader, desc="Evaluating")
        for x, y in loop:
            x = x.to(CFG.DEVICE, non_blocking=True, memory_format=torch.channels_last)
            y = y.to(CFG.DEVICE, non_blocking=True).long()

            with torch.amp.autocast(device_type=CFG.DEVICE, dtype=torch.bfloat16, enabled=(CFG.DEVICE == "cuda")):
                preds = model(x)
                loss = loss_fn(preds, y)
            total_loss += loss.item()
            num_batches += 1

            pred_labels = torch.argmax(preds, dim=1)
            pixel_correct += (pred_labels == y).sum()
            pixel_total += torch.numel(y)

            for cls in range(CFG.NUM_CLASSES):
                pred_mask = pred_labels == cls
                true_mask = y == cls
                intersection[cls] += (pred_mask & true_mask).sum()
                union[cls] += (pred_mask | true_mask).sum()

    if num_batches == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError.
        raise ValueError("evaluate_model received an empty loader")

    pixel_acc = (pixel_correct / pixel_total) * 100
    iou_per_class = (intersection + 1e-6) / (union + 1e-6)
    # argmax always assigns some class to every pixel, so at least one class
    # has a non-zero union once any batch has been processed.
    class_seen = union > 0
    mean_iou = iou_per_class[class_seen].mean()
    avg_loss = total_loss / num_batches

    print(f"Validation Results -> Avg Loss: {avg_loss:.4f}, Pixel Acc: {pixel_acc:.2f}%, mIoU: {mean_iou:.4f}")
    # tolist() copies the values to the host once rather than once per class.
    for i, (iou, seen) in enumerate(zip(iou_per_class.tolist(), class_seen.tolist())):
        if seen:
            print(f" Class {i} IoU: {iou:.4f}")
        else:
            print(f" Class {i} IoU: n/a (class absent, excluded from mIoU)")
    return mean_iou
def save_predictions_as_images(loader, model):
    """Placeholder for exporting predictions; currently does nothing.

    The function is called after training finishes but has no body, so both
    arguments are ignored and no files are written.

    Args:
        loader: Unused.
        model: Unused.

    Returns:
        None.
    """
    return None
# --- NEW: Helper function to save a checkpoint ---
def save_checkpoint(state, filename="checkpoint.pth"):
    """Serialise ``state`` to ``filename`` without risking a torn file.

    The data is first written to ``<filename>.tmp`` and then moved over the
    destination with ``os.replace``. If the process is interrupted mid-write,
    the previous checkpoint therefore stays intact instead of being left
    half-written and unloadable.

    Args:
        state: Object to serialise with ``torch.save``.
        filename: Destination path. Non-path targets (e.g. open file objects)
            are passed straight to ``torch.save`` as before.
    """
    print("=> Saving checkpoint")

    if not isinstance(filename, (str, os.PathLike)):
        # File-like targets cannot be renamed; keep the direct write for them.
        torch.save(state, filename)
        return

    target = os.fspath(filename)
    tmp_path = f"{target}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, target)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up the partial file left behind by a failed or interrupted write.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def main():
    """Train the segmentation model, resuming from a checkpoint if one exists.

    Steps: seed the RNGs, build datasets/loaders, model, loss, optimizer,
    scaler and LR scheduler; restore all of them from ``CFG.CHECKPOINT_PATH``
    when that file exists; then for every remaining epoch train, evaluate,
    step the scheduler, save the best model weights on improvement and write a
    full checkpoint. Finally reload the best weights and call
    ``save_predictions_as_images``.
    """
    torch.manual_seed(CFG.SEED)
    np.random.seed(CFG.SEED)
    os.makedirs(CFG.OUTPUT_DIR, exist_ok=True)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    train_transform, val_transform = get_transforms(CFG.IMG_SIZE)
    train_ds = LULCDataset(CFG.TRAIN_IMG_DIR, CFG.TRAIN_MASK_DIR, transform=train_transform, subset_fraction=CFG.SUBSET_FRACTION)
    val_ds = LULCDataset(CFG.VAL_IMG_DIR, CFG.VAL_MASK_DIR, transform=val_transform, subset_fraction=CFG.SUBSET_FRACTION)

    loader_kwargs = dict(
        batch_size=CFG.BATCH_SIZE,
        num_workers=CFG.NUM_WORKERS,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=CFG.NUM_WORKERS > 0,
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    model = get_model()
    model = model.to(memory_format=torch.channels_last)
    loss_fn = get_loss_fn()
    optimizer = optim.AdamW(model.parameters(), lr=CFG.LEARNING_RATE)
    scaler = torch.amp.GradScaler(enabled=(CFG.DEVICE == "cuda"))
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.EPOCHS, eta_min=1e-6)

    # --- Resume from a previous run when a checkpoint is present ---
    start_epoch = 0
    best_val_miou = -1.0
    if os.path.exists(CFG.CHECKPOINT_PATH):
        print(f"=> Loading checkpoint '{CFG.CHECKPOINT_PATH}'")
        checkpoint = torch.load(CFG.CHECKPOINT_PATH, map_location=CFG.DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        # float() accepts both plain numbers and 0-dim tensors, so checkpoints
        # written by earlier versions (which stored a tensor) still load.
        best_val_miou = float(checkpoint['best_val_miou'])
        print(f"=> Resuming training from epoch {start_epoch}")
    else:
        print("=> No checkpoint found, starting new training session.")

    for epoch in range(start_epoch, CFG.EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{CFG.EPOCHS} ---")
        train_one_epoch(train_loader, model, optimizer, loss_fn, scaler)
        # Keep the score as a Python float: it is compared, formatted and
        # checkpointed, none of which needs a device tensor.
        current_miou = float(evaluate_model(val_loader, model, loss_fn))
        scheduler.step()

        if current_miou > best_val_miou:
            best_val_miou = current_miou
            print(f"🎉 New best mIoU: {best_val_miou:.4f}! Saving best model to {CFG.MODEL_SAVE_PATH}")
            # Weights only, for easy inference later.
            torch.save(model.state_dict(), CFG.MODEL_SAVE_PATH)

        # Full training state, written after every epoch so a run can resume.
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_miou': best_val_miou,
        }
        save_checkpoint(checkpoint, filename=CFG.CHECKPOINT_PATH)

    print("\n--- Training Complete. Saving final predictions. ---")
    # The best-model file is missing if no epoch ever improved on the initial
    # score (for example when every mIoU was NaN); do not crash in that case.
    if os.path.exists(CFG.MODEL_SAVE_PATH):
        # map_location keeps the load working when the weights were saved on
        # a different device than the one available now.
        model.load_state_dict(torch.load(CFG.MODEL_SAVE_PATH, map_location=CFG.DEVICE))
    else:
        print(f"=> No best model found at '{CFG.MODEL_SAVE_PATH}', using current weights.")
    # Note: a separate test loader would give an unbiased final evaluation.
    save_predictions_as_images(val_loader, model)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()