"""Train a U-Net (ResNet-34 encoder) for binary mask segmentation.

The script pairs every image in ``IMG_DIR`` with a mask in ``MASK_DIR`` whose
file name is ``<image stem><MASK_SUFFIX>.<any extension>``, splits the pairs
into train/validation subsets, runs the Lightning learning-rate finder, trains
the model, and finally exports the best checkpoint's network weights as a
plain ``state_dict``.
"""

import os
import glob

from PIL import Image
import numpy as np
import cv2  # OpenCV for image loading/processing
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms.functional as TF
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.tuner import Tuner
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from torchmetrics import JaccardIndex
from torchmetrics.segmentation import DiceScore

# --- Configuration ---
IMG_DIR = "derm_images_flat"
MASK_DIR = "derm_mask_images_flat"
MASK_SUFFIX = "_segmentation"  # Part added to image name to get mask name
# Only files with these (case-insensitive) extensions are treated as images, so
# stray metadata files in IMG_DIR cannot abort training half-way through.
IMG_EXTENSIONS = (
    ".bmp", ".dib", ".jpeg", ".jpg", ".jpe", ".jp2", ".png", ".webp",
    ".pbm", ".pgm", ".ppm", ".tif", ".tiff",
)
IMG_SIZE = (256, 256)  # Resize images/masks to this size (height, width)
BATCH_SIZE = 8
# os.cpu_count() may return None when the count cannot be determined.
NUM_WORKERS = (os.cpu_count() or 1) // 2
LEARNING_RATE = 1e-4  # Initial LR, will be tuned
MAX_EPOCHS = 5
VAL_SPLIT = 0.15  # 15% for validation
# For early stopping. NOTE: with MAX_EPOCHS = 5 a patience of 5 can never be
# exhausted, so early stopping only becomes active if MAX_EPOCHS is raised.
PATIENCE = 5
ACCELERATOR = "gpu" if torch.cuda.is_available() else "cpu"
# One device on either accelerator. The Tuner API used below belongs to
# Lightning 2.x, which expects a positive int (or "auto") here, not None.
DEVICES = 1
PRECISION = 16 if torch.cuda.is_available() else 32  # Use mixed precision if GPU supports it


# --- Dataset ---
class DermDataset(Dataset):
    """Dataset of (image, binary mask) pairs read from disk with OpenCV.

    Args:
        image_paths: Paths of the input images.
        mask_dir: Directory that holds the mask files.
        mask_suffix: Text appended to an image's stem to form its mask's stem.
        transform: Optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` and returning a mapping with
            ``"image"`` and ``"mask"`` entries. When omitted, the raw arrays
            are converted to tensors without resizing or normalisation.
    """

    def __init__(self, image_paths, mask_dir, mask_suffix, transform=None):
        self.image_paths = image_paths
        self.mask_dir = mask_dir
        self.mask_suffix = mask_suffix
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def _resolve_mask_path(self, img_path):
        """Return the mask file that belongs to ``img_path``.

        Raises:
            FileNotFoundError: If no file ``<stem><suffix>.*`` exists in
                ``mask_dir``.
        """
        img_name_part, _ = os.path.splitext(os.path.basename(img_path))
        mask_filename_base = f"{img_name_part}{self.mask_suffix}"
        # Escape the literal parts so glob metacharacters ("[", "]", "*", "?")
        # in directory or file names are matched verbatim.
        pattern = os.path.join(
            glob.escape(self.mask_dir), f"{glob.escape(mask_filename_base)}.*"
        )
        # Sort so the choice is deterministic when several extensions exist.
        possible_mask_paths = sorted(glob.glob(pattern))
        if not possible_mask_paths:
            raise FileNotFoundError(
                f"Mask not found for image: {img_path}. "
                f"Tried pattern: {mask_filename_base}.* in {self.mask_dir}"
            )
        return possible_mask_paths[0]

    def __getitem__(self, idx):
        """Load, binarise and transform the sample at ``idx``.

        Returns:
            dict: ``{"image": tensor [C, H, W], "mask": tensor [1, H, W]}``.

        Raises:
            FileNotFoundError: If the mask file cannot be located.
            IOError: If OpenCV cannot decode the image or the mask.
        """
        img_path = self.image_paths[idx]
        mask_path = self._resolve_mask_path(img_path)

        # Load image (OpenCV decodes to BGR; convert to RGB)
        image = cv2.imread(img_path)
        if image is None:
            raise IOError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Load mask as a single-channel image
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise IOError(f"Could not read mask: {mask_path}")

        # Binarise: pixels above 128 become 1.0, everything else 0.0
        mask = (mask > 128).astype(np.float32)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        # Without a transform (or with one that lacks a to-tensor step) the
        # arrays are still NumPy; convert them so the output type is uniform.
        if not torch.is_tensor(image):
            # HWC -> CHW, dtype left untouched
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        # BCEWithLogitsLoss with a single-class output needs masks shaped
        # [B, 1, H, W], so add the channel dimension here -> [1, H, W]
        mask = mask.unsqueeze(0)

        return {"image": image, "mask": mask}


# --- Transforms ---
def get_transforms(img_size, is_train=True):
    """Build the albumentations pipeline for one data split.

    Args:
        img_size: ``(height, width)`` every image and mask is resized to.
        is_train: When true, geometric and photometric augmentations are
            added; otherwise the pipeline only resizes and normalises.

    Returns:
        A.Compose: Pipeline ending in ``ToTensorV2``.
    """
    if is_train:
        # Augmentations for training
        return A.Compose([
            A.Resize(height=img_size[0], width=img_size[1]),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.Affine(scale=(0.9, 1.1), translate_percent=0.0625, rotate=(-15, 15), p=0.5, cval=0),
            A.OneOf([
                A.ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05),
                A.GridDistortion(p=0.5),
                A.OpticalDistortion(distort_limit=0.5, p=1)
            ], p=0.3),
            A.RandomBrightnessContrast(p=0.3),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet stats
            ToTensorV2(),  # Converts image HWC->CHW, mask HW->HW (need to add C dim later)
        ])

    # Validation/Test: Just resize and normalize
    return A.Compose([
        A.Resize(height=img_size[0], width=img_size[1]),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])


# --- Lightning Module ---
class UNetLitModule(pl.LightningModule):
    """U-Net with a ResNet-34 encoder trained with BCE-with-logits loss.

    Args:
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, learning_rate=1e-4):
        super().__init__()
        self.learning_rate = learning_rate
        self.save_hyperparameters()  # Saves args like learning_rate to hparams

        # --- Model --- (segmentation_models_pytorch)
        # No output activation: the network emits raw logits and the sigmoid
        # is applied inside the loss / before the metric.
        self.model = smp.Unet(
            encoder_name="resnet34",  # Choose backbone
            encoder_weights="imagenet",  # Use pretrained weights
            in_channels=3,  # Input channels (RGB)
            classes=1,  # Output channels (binary mask)
        )

        # --- Loss Function ---
        # BCEWithLogitsLoss is numerically stable for binary classification
        self.loss_fn = nn.BCEWithLogitsLoss()

        # --- Metrics ---
        # Jaccard Index (IoU); probabilities are thresholded at 0.5
        self.iou_metric = JaccardIndex(task="binary", threshold=0.5)

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits of the network for ``x``."""
        return self.model(x)

    def _common_step(self, batch, batch_idx, stage):
        """Compute and log loss and IoU for one batch.

        Args:
            batch: Mapping with ``"image"`` and ``"mask"`` tensors.
            batch_idx: Index of the batch (unused).
            stage: Prefix for the logged metric names, e.g. ``"train"``.

        Returns:
            The loss tensor for the batch.
        """
        images = batch["image"]
        masks = batch["mask"]

        logits = self(images)  # Model output (before activation)
        loss = self.loss_fn(logits, masks)

        # The metric receives probabilities and integer targets.
        preds = torch.sigmoid(logits)
        # The value logged below is this call's return value, reduced over the
        # epoch by Lightning (presumably a mean of per-batch scores - verify
        # if a dataset-level IoU is required).
        iou = self.iou_metric(preds, masks.int())

        self.log(f"{stage}_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log(f"{stage}_iou", iou, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Run ``_common_step`` with the ``"train"`` prefix."""
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Run ``_common_step`` with the ``"val"`` prefix."""
        return self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Run ``_common_step`` with the ``"test"`` prefix (optional test set)."""
        return self._common_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        """Return an Adam optimizer using ``hparams.learning_rate``.

        No learning-rate scheduler is configured.
        """
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


def _find_image_paths(img_dir):
    """Return the sorted paths in ``img_dir`` whose extension is in IMG_EXTENSIONS."""
    candidates = glob.glob(os.path.join(glob.escape(img_dir), "*.*"))
    return sorted(p for p in candidates if p.lower().endswith(IMG_EXTENSIONS))


def main():
    """Prepare the data, tune the learning rate, train, and export weights.

    Raises:
        FileNotFoundError: If ``IMG_DIR`` contains no image files.
        ValueError: If the train or validation split would be empty.
    """
    pl.seed_everything(42)  # for reproducibility

    # --- Setup Data ---
    all_image_paths = _find_image_paths(IMG_DIR)
    if not all_image_paths:
        raise FileNotFoundError(f"No images found in {IMG_DIR}")

    # Split data
    n_total = len(all_image_paths)
    n_val = int(n_total * VAL_SPLIT)
    n_train = n_total - n_val

    if n_train == 0 or n_val == 0:
        raise ValueError(
            f"Train ({n_train}) or Val ({n_val}) split has 0 samples. "
            "Check VAL_SPLIT and dataset size."
        )

    # random_split returns index-based Subset views; materialise the paths.
    train_subset, val_subset = random_split(all_image_paths, [n_train, n_val])
    train_paths = [all_image_paths[i] for i in train_subset.indices]
    val_paths = [all_image_paths[i] for i in val_subset.indices]

    train_dataset = DermDataset(
        train_paths, MASK_DIR, MASK_SUFFIX,
        transform=get_transforms(IMG_SIZE, is_train=True),
    )
    val_dataset = DermDataset(
        val_paths, MASK_DIR, MASK_SUFFIX,
        transform=get_transforms(IMG_SIZE, is_train=False),
    )

    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS, pin_memory=pin_memory,
    )

    print(f"Found {n_total} images. Training on {len(train_dataset)}, Validating on {len(val_dataset)}.")

    # --- Initialize Model ---
    # Instantiate with a placeholder LR first for LR finder
    model = UNetLitModule(learning_rate=LEARNING_RATE)

    # --- Callbacks ---
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        filename="unet-derm-{epoch:02d}-{val_iou:.4f}",
        save_top_k=1,
        verbose=True,
        monitor="val_iou",  # Save based on validation IoU
        mode="max",  # Maximize IoU
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    early_stop_callback = EarlyStopping(
        monitor="val_iou",  # Monitor validation IoU
        patience=PATIENCE,
        verbose=True,
        mode="max",  # Stop if IoU stops improving
    )

    logger = TensorBoardLogger("tb_logs", name="unet_derm_resnet34")

    # --- Trainer ---
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[checkpoint_callback, lr_monitor, early_stop_callback],
        max_epochs=MAX_EPOCHS,
        accelerator=ACCELERATOR,
        devices=DEVICES,
        precision=PRECISION,
        log_every_n_steps=10,
    )

    # --- Find Optimal Learning Rate ---
    print("\nFinding optimal learning rate...")
    tuner = Tuner(trainer)
    lr_finder_result = tuner.lr_find(
        model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
        num_training=100,  # Run LR finder for 100 steps
    )

    # The finder may return nothing (its return type is optional), in which
    # case the initial learning rate is kept.
    suggested_lr = None
    if lr_finder_result is not None:
        # Inspect results and pick learning rate
        fig = lr_finder_result.plot(suggest=True)
        if fig is not None:
            fig.show()  # Display plot
        suggested_lr = lr_finder_result.suggestion()

    if suggested_lr is not None:
        print(f"Suggested LR: {suggested_lr:.2e}")
        # configure_optimizers reads hparams; keep the plain attribute in sync.
        model.hparams.learning_rate = suggested_lr
        model.learning_rate = suggested_lr
        print(f"Using LR: {model.hparams.learning_rate:.2e}")
    else:
        print(f"LR finder did not suggest a rate. Using initial LR: {model.hparams.learning_rate:.2e}")

    # --- Start Training ---
    print("\nStarting training...")
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    print("\nTraining finished.")
    print(f"Best model saved at: {checkpoint_callback.best_model_path}")

    # --- Save final model state dict separately (sometimes easier for inference) ---
    final_model_path = "unet_derm_final_model.pth"
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        # Load best model before saving state dict
        best_model = UNetLitModule.load_from_checkpoint(best_model_path)
        state_dict = best_model.model.state_dict()
    else:
        # best_model_path is empty when no checkpoint was written; fall back
        # to the weights currently held in memory instead of crashing.
        print("No checkpoint was saved; exporting the in-memory model weights instead.")
        state_dict = model.model.state_dict()
    torch.save(state_dict, final_model_path)
    print(f"Final model state_dict saved to: {final_model_path}")


# --- Main Training Script ---
if __name__ == "__main__":
    main()