import os
import glob
import numpy as np
import cv2  # OpenCV for image loading/processing
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from torchmetrics import JaccardIndex

# --- Configuration ---
IMG_DIR = "derm_images_flat"
MASK_DIR = "derm_mask_images_flat"
MASK_SUFFIX = "_segmentation"  # Part appended to an image name to get its mask name
IMG_SIZE = (256, 256)  # Resize images/masks to this size
BATCH_SIZE = 8
NUM_WORKERS = max(1, os.cpu_count() // 2)
LEARNING_RATE = 1e-4  # Initial LR; the LR finder below may override it
MAX_EPOCHS = 5
VAL_SPLIT = 0.15  # 15% of the data for validation
PATIENCE = 5  # For early stopping (only effective when MAX_EPOCHS > PATIENCE)
ACCELERATOR = "gpu" if torch.cuda.is_available() else "cpu"
DEVICES = 1 if torch.cuda.is_available() else "auto"
PRECISION = "16-mixed" if torch.cuda.is_available() else "32-true"  # Mixed precision on GPU

# --- Dataset ---
class DermDataset(Dataset):
    def __init__(self, image_paths, mask_dir, mask_suffix, transform=None):
        self.image_paths = image_paths
        self.mask_dir = mask_dir
        self.mask_suffix = mask_suffix
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img_filename = os.path.basename(img_path)
        img_name_part, _ = os.path.splitext(img_filename)
        # Construct the mask path - the extension may differ (e.g. .png), so glob for it
        mask_filename_base = f"{img_name_part}{self.mask_suffix}"
        possible_mask_paths = glob.glob(os.path.join(self.mask_dir, f"{mask_filename_base}.*"))
        if not possible_mask_paths:
            raise FileNotFoundError(
                f"Mask not found for image: {img_path}. "
                f"Tried pattern: {mask_filename_base}.* in {self.mask_dir}"
            )
        mask_path = possible_mask_paths[0]  # Assume the first match is the correct one
        # Load image (convert OpenCV's BGR to RGB)
        image = cv2.imread(img_path)
        if image is None:
            raise IOError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Load mask as grayscale
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise IOError(f"Could not read mask: {mask_path}")
        # Preprocess mask: threshold to binary 0/1 and convert to float
        mask = (mask > 128).astype(np.float32)
        # Apply transformations (albumentations applies the same geometry to image and mask)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]
        # ToTensorV2 leaves the mask as [H, W]; BCEWithLogitsLoss with a single-class
        # output expects [B, 1, H, W], so add the channel dimension here
        mask = mask.unsqueeze(0)  # -> [1, H, W]
        return {"image": image, "mask": mask}

# --- Transforms ---
def get_transforms(img_size, is_train=True):
    if is_train:
        # Augmentations for training
        return A.Compose([
            A.Resize(height=img_size[0], width=img_size[1]),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.Affine(scale=(0.9, 1.1), translate_percent=0.0625, rotate=(-15, 15), p=0.5, cval=0),
            A.OneOf([
                A.ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05),
                A.GridDistortion(p=0.5),
                A.OpticalDistortion(distort_limit=0.5, p=1),
            ], p=0.3),  # Inner p values act as relative selection weights within OneOf
            A.RandomBrightnessContrast(p=0.3),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet stats
            ToTensorV2(),  # Image HWC -> CHW tensor; mask stays [H, W] (channel dim added in the dataset)
        ])
    else:
        # Validation/Test: just resize and normalize
        return A.Compose([
            A.Resize(height=img_size[0], width=img_size[1]),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])
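
# A small sanity-check helper (an addition for illustration, not part of the original
# pipeline): builds one validation-style sample and prints its tensor shapes, which
# should come out as image [3, H, W] float32 and mask [1, H, W] float32.
def sanity_check_one_sample(img_dir=IMG_DIR, mask_dir=MASK_DIR):
    paths = sorted(glob.glob(os.path.join(img_dir, "*.*")))
    ds = DermDataset(paths, mask_dir, MASK_SUFFIX, transform=get_transforms(IMG_SIZE, is_train=False))
    sample = ds[0]
    print(f"image: {tuple(sample['image'].shape)} {sample['image'].dtype}, "
          f"mask: {tuple(sample['mask'].shape)} {sample['mask'].dtype}")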

# --- Lightning Module ---
class UNetLitModule(pl.LightningModule):
    def __init__(self, learning_rate=1e-4):
        super().__init__()
        self.learning_rate = learning_rate
        self.save_hyperparameters()  # Saves init args like learning_rate to self.hparams

        # --- Model ---
        # U-Net from segmentation_models_pytorch
        self.model = smp.Unet(
            encoder_name="resnet34",     # Backbone
            encoder_weights="imagenet",  # Pretrained encoder weights
            in_channels=3,               # Input channels (RGB)
            classes=1,                   # Output channels (binary mask)
            # activation=None: sigmoid is applied after the loss, not inside the model
        )

        # --- Loss Function ---
        # BCEWithLogitsLoss is numerically stable for binary segmentation on raw logits
        self.loss_fn = nn.BCEWithLogitsLoss()

        # --- Metrics ---
        # Jaccard Index (IoU); thresholds predicted probabilities at 0.5
        self.iou_metric = JaccardIndex(task="binary", threshold=0.5)

    def forward(self, x):
        return self.model(x)

    def _common_step(self, batch, batch_idx, stage):
        images = batch["image"]
        masks = batch["mask"]
        logits = self(images)  # Raw model output (before activation)
        loss = self.loss_fn(logits, masks)
        # Apply sigmoid before the metric, which expects probabilities;
        # the per-batch IoU value is logged and Lightning averages it over the epoch
        preds = torch.sigmoid(logits)
        iou = self.iou_metric(preds, masks.int())  # JaccardIndex expects integer targets
        self.log(f"{stage}_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log(f"{stage}_iou", iou, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        # Optional: only used if a separate test set is evaluated
        return self._common_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        # Optional: add a learning rate scheduler, e.g.
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=3)
        # return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}
        return optimizer
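
# Quick structural check (a sketch added for illustration; not in the original script):
# a random batch through an untrained U-Net of the same configuration should yield one
# logit map per image, i.e. shape [B, 1, H, W]. encoder_weights=None avoids the
# ImageNet download for this shape-only test.
def check_model_output_shape():
    m = smp.Unet(encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=1)
    m.eval()
    with torch.no_grad():
        logits = m(torch.randn(2, 3, IMG_SIZE[0], IMG_SIZE[1]))
    assert logits.shape == (2, 1, IMG_SIZE[0], IMG_SIZE[1]), f"Unexpected shape: {logits.shape}"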

# --- Main Training Script ---
if __name__ == "__main__":
    pl.seed_everything(42)  # For reproducibility

    # --- Setup Data ---
    all_image_paths = sorted(glob.glob(os.path.join(IMG_DIR, "*.*")))  # Find all image files
    if not all_image_paths:
        raise FileNotFoundError(f"No images found in {IMG_DIR}")

    # Split data
    n_total = len(all_image_paths)
    n_val = int(n_total * VAL_SPLIT)
    n_train = n_total - n_val
    if n_train == 0 or n_val == 0:
        raise ValueError(f"Train ({n_train}) or val ({n_val}) split has 0 samples. Check VAL_SPLIT and dataset size.")
    train_paths, val_paths = random_split(all_image_paths, [n_train, n_val])

    train_dataset = DermDataset(list(train_paths), MASK_DIR, MASK_SUFFIX, transform=get_transforms(IMG_SIZE, is_train=True))
    val_dataset = DermDataset(list(val_paths), MASK_DIR, MASK_SUFFIX, transform=get_transforms(IMG_SIZE, is_train=False))
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Found {n_total} images. Training on {len(train_dataset)}, validating on {len(val_dataset)}.")
    # --- Initialize Model ---
    # Instantiate with the initial LR; the LR finder below may update it
    model = UNetLitModule(learning_rate=LEARNING_RATE)

    # --- Callbacks ---
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        filename="unet-derm-{epoch:02d}-{val_iou:.4f}",
        save_top_k=1,
        verbose=True,
        monitor="val_iou",  # Save based on validation IoU
        mode="max",         # Maximize IoU
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    early_stop_callback = EarlyStopping(
        monitor="val_iou",  # Monitor validation IoU
        patience=PATIENCE,
        verbose=True,
        mode="max",         # Stop when IoU stops improving
    )
    logger = TensorBoardLogger("tb_logs", name="unet_derm_resnet34")

    # --- Trainer ---
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[checkpoint_callback, lr_monitor, early_stop_callback],
        max_epochs=MAX_EPOCHS,
        accelerator=ACCELERATOR,
        devices=DEVICES,
        precision=PRECISION,
        log_every_n_steps=10,
        # deterministic=True,  # Fully deterministic ops; may slow down training
    )
    # --- Find Optimal Learning Rate ---
    print("\nFinding optimal learning rate...")
    tuner = pl.tuner.Tuner(trainer)
    lr_finder_result = tuner.lr_find(model, train_dataloaders=train_loader, val_dataloaders=val_loader, num_training=100)  # Run the LR range test for 100 steps

    # Inspect the results and pick a learning rate
    suggested_lr = lr_finder_result.suggestion() if lr_finder_result is not None else None
    if suggested_lr is not None:
        fig = lr_finder_result.plot(suggest=True)
        fig.show()  # Display the loss-vs-LR curve
        print(f"Suggested LR: {suggested_lr:.2e}")
        model.hparams.learning_rate = suggested_lr  # configure_optimizers reads this at fit time
        print(f"Using LR: {model.hparams.learning_rate:.2e}")
    else:
        print(f"LR finder did not suggest a rate. Using initial LR: {model.hparams.learning_rate:.2e}")

    # --- Start Training ---
    print("\nStarting training...")
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    print("\nTraining finished.")
    print(f"Best model saved at: {checkpoint_callback.best_model_path}")

    # --- Save the best model's state dict separately (optional; often easier for inference) ---
    final_model_path = "unet_derm_final_model.pth"
    # Load the best checkpoint before saving its state dict
    best_model = UNetLitModule.load_from_checkpoint(checkpoint_callback.best_model_path)
    torch.save(best_model.model.state_dict(), final_model_path)
    print(f"Final model state_dict saved to: {final_model_path}")