Johnyquest7 committed on
Commit
01ff706
·
verified ·
1 Parent(s): 68a3039

Upload train_unet.py

Browse files
Files changed (1) hide show
  1. train_unet.py +270 -0
train_unet.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ from PIL import Image
4
+ import numpy as np
5
+ import cv2 # OpenCV for image loading/processing
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.utils.data import Dataset, DataLoader, random_split
10
+ import torchvision.transforms.functional as TF
11
+
12
+ import pytorch_lightning as pl
13
+ from pytorch_lightning.loggers import TensorBoardLogger
14
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
15
+
16
+ import albumentations as A
17
+ from albumentations.pytorch import ToTensorV2
18
+
19
+ import segmentation_models_pytorch as smp
20
+ from torchmetrics import JaccardIndex
21
+ from torchmetrics.segmentation import DiceScore
22
+
23
# --- Configuration ---
IMG_DIR = "derm_images_flat"
MASK_DIR = "derm_mask_images_flat"
MASK_SUFFIX = "_segmentation"  # Appended to the image stem to get the mask stem
IMG_SIZE = (256, 256)  # (height, width) that images/masks are resized to
BATCH_SIZE = 8
# os.cpu_count() returns None when the CPU count cannot be determined,
# so fall back to a single CPU instead of failing on `None // 2`.
NUM_WORKERS = (os.cpu_count() or 1) // 2
LEARNING_RATE = 1e-4  # Initial LR; replaced by the LR finder's suggestion when one is produced
MAX_EPOCHS = 5
VAL_SPLIT = 0.15  # 15% for validation
# NOTE(review): with PATIENCE >= MAX_EPOCHS early stopping looks unable to
# trigger before training ends -- confirm this is intended.
PATIENCE = 5  # For early stopping
ACCELERATOR = "gpu" if torch.cuda.is_available() else "cpu"
# "auto" lets Lightning choose the device count on CPU; `None` is not an
# accepted value for `devices` in Lightning 2.x.
DEVICES = 1 if torch.cuda.is_available() else "auto"
PRECISION = 16 if torch.cuda.is_available() else 32  # Mixed precision only when a GPU is present
37
+
38
+ # --- Dataset ---
39
class DermDataset(Dataset):
    """Pairs each image file with its binary segmentation mask.

    The mask for ``<stem>.<ext>`` is looked up in ``mask_dir`` as
    ``<stem><mask_suffix>.*`` (any extension).

    Args:
        image_paths: Sequence of image file paths.
        mask_dir: Directory that holds the mask files.
        mask_suffix: Text appended to the image stem to form the mask stem.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, image_paths, mask_dir, mask_suffix, transform=None):
        self.image_paths = image_paths
        self.mask_dir = mask_dir
        self.mask_suffix = mask_suffix
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def _find_mask_path(self, img_path):
        """Return the mask path for ``img_path`` or raise ``FileNotFoundError``."""
        img_name_part, _ = os.path.splitext(os.path.basename(img_path))
        mask_filename_base = f"{img_name_part}{self.mask_suffix}"

        # Escape glob metacharacters so names containing [, ], * or ? are
        # matched literally; only the extension is a wildcard.
        pattern = os.path.join(
            glob.escape(self.mask_dir),
            f"{glob.escape(mask_filename_base)}.*",
        )
        # glob order is arbitrary; sort so the chosen mask is deterministic
        # when several extensions exist for the same stem.
        possible_mask_paths = sorted(glob.glob(pattern))

        if not possible_mask_paths:
            raise FileNotFoundError(f"Mask not found for image: {img_path}. Tried pattern: {mask_filename_base}.* in {self.mask_dir}")

        return possible_mask_paths[0]

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``{"image": image, "mask": mask}``. When a transform is set, the
            mask is a tensor with a leading channel dimension added.

        Raises:
            FileNotFoundError: No mask file matches the image.
            IOError: The image or mask file could not be decoded.
        """
        img_path = self.image_paths[idx]
        mask_path = self._find_mask_path(img_path)

        # OpenCV decodes to BGR; convert to RGB.
        image = cv2.imread(img_path)
        if image is None:
            raise IOError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise IOError(f"Could not read mask: {mask_path}")

        # Binarise: pixels above 128 become 1.0, everything else 0.0.
        mask = (mask > 128).astype(np.float32)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            # as_tensor is a no-op for tensors and also covers transforms
            # that return a numpy mask; unsqueeze adds the channel
            # dimension -> [1, H, W] so a batch is [B, 1, H, W].
            mask = torch.as_tensor(augmented['mask']).unsqueeze(0)

        return {"image": image, "mask": mask}
88
+
89
+ # --- Transforms ---
90
def get_transforms(img_size, is_train=True):
    """Build the albumentations pipeline for one data split.

    Args:
        img_size: Pair whose first item is the target height and second the
            target width.
        is_train: When true, random augmentations are inserted between the
            resize and the normalisation; otherwise only resize, normalise
            and tensor conversion are applied.

    Returns:
        An ``A.Compose`` pipeline.
    """
    target_h, target_w = img_size[0], img_size[1]

    # Steps shared by both pipelines: resize first, normalise + tensorise last.
    head = [A.Resize(height=target_h, width=target_w)]
    tail = [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet stats
        ToTensorV2(),  # Image HWC->CHW; the mask stays HW (channel dim is added by the dataset)
    ]

    if not is_train:
        # Validation/Test: no augmentation.
        return A.Compose(head + tail)

    # At most one of the three distortions is applied, 30% of the time.
    distortion = A.OneOf(
        [
            A.ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05),
            A.GridDistortion(p=0.5),
            A.OpticalDistortion(distort_limit=0.5, p=1),
        ],
        p=0.3,
    )
    augmentations = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Affine(scale=(0.9, 1.1), translate_percent=0.0625, rotate=(-15, 15), p=0.5, cval=0),
        distortion,
        A.RandomBrightnessContrast(p=0.3),
    ]
    return A.Compose(head + augmentations + tail)
115
+
116
+ # --- Lightning Module ---
117
class UNetLitModule(pl.LightningModule):
    """Lightning wrapper around an SMP U-Net for binary segmentation.

    The network emits one logit channel, is trained with
    ``BCEWithLogitsLoss``, and logs ``<stage>_loss`` and ``<stage>_iou``
    for the ``train``, ``val`` and ``test`` stages.

    Args:
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, learning_rate=1e-4):
        super().__init__()
        self.learning_rate = learning_rate
        self.save_hyperparameters()  # Saves args like learning_rate to hparams

        # --- Model ---
        self.model = smp.Unet(
            encoder_name="resnet34",     # Backbone
            encoder_weights="imagenet",  # Pretrained encoder weights
            in_channels=3,               # RGB input
            classes=1,                   # Single-channel (binary) output
        )

        # --- Loss Function ---
        # Operates on raw logits, so the model has no final activation.
        self.loss_fn = nn.BCEWithLogitsLoss()

        # --- Metrics ---
        # One metric object per stage so accumulated state is never shared
        # between training, validation and testing.
        self.train_iou = JaccardIndex(task="binary", threshold=0.5)
        self.val_iou = JaccardIndex(task="binary", threshold=0.5)
        self.test_iou = JaccardIndex(task="binary", threshold=0.5)

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits for a batch of images."""
        return self.model(x)

    def _common_step(self, batch, batch_idx, stage):
        """Compute the loss, update the stage's IoU metric and log both.

        Args:
            batch: Mapping with ``"image"`` and ``"mask"`` entries.
            batch_idx: Index of the batch (unused).
            stage: ``"train"``, ``"val"`` or ``"test"``; used as the prefix
                of the logged names and to select the metric.

        Returns:
            The loss tensor.
        """
        images = batch["image"]
        masks = batch["mask"]

        logits = self(images)
        loss = self.loss_fn(logits, masks)

        # The metric thresholds probabilities, so squash the logits first.
        preds = torch.sigmoid(logits)
        iou_metric = getattr(self, f"{stage}_iou")
        iou_metric.update(preds, masks.int())

        self.log(f"{stage}_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # Logging the metric object lets Lightning compute the value over the
        # whole epoch and reset it afterwards, rather than averaging
        # per-batch IoU values.
        self.log(f"{stage}_iou", iou_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        # Only used when a separate test set is evaluated.
        return self._common_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        """Return an Adam optimizer using ``hparams.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer
177
+
178
+ # --- Main Training Script ---
179
# Script entry point: split the data, run the LR finder, train, then export
# the best weights as a plain state dict.
if __name__ == "__main__":
    pl.seed_everything(42)  # for reproducibility

    # --- Setup Data ---
    all_image_paths = sorted(glob.glob(os.path.join(IMG_DIR, "*.*")))  # Find all image files
    if not all_image_paths:
        raise FileNotFoundError(f"No images found in {IMG_DIR}")

    # Split data
    n_total = len(all_image_paths)
    n_val = int(n_total * VAL_SPLIT)
    n_train = n_total - n_val

    if n_train == 0 or n_val == 0:
        raise ValueError(f"Train ({n_train}) or Val ({n_val}) split has 0 samples. Check VAL_SPLIT and dataset size.")

    train_paths, val_paths = random_split(all_image_paths, [n_train, n_val])

    train_dataset = DermDataset(list(train_paths), MASK_DIR, MASK_SUFFIX, transform=get_transforms(IMG_SIZE, is_train=True))
    val_dataset = DermDataset(list(val_paths), MASK_DIR, MASK_SUFFIX, transform=get_transforms(IMG_SIZE, is_train=False))

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    print(f"Found {n_total} images. Training on {len(train_dataset)}, Validating on {len(val_dataset)}.")

    # --- Initialize Model ---
    # LEARNING_RATE is only the starting point; the LR finder may replace it.
    model = UNetLitModule(learning_rate=LEARNING_RATE)

    # --- Callbacks ---
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        filename="unet-derm-{epoch:02d}-{val_iou:.4f}",
        save_top_k=1,
        verbose=True,
        monitor="val_iou",  # Save based on validation IoU
        mode="max",         # Maximize IoU
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    early_stop_callback = EarlyStopping(
        monitor="val_iou",  # Monitor validation IoU
        patience=PATIENCE,
        verbose=True,
        mode="max",         # Stop if IoU stops improving
    )
    logger = TensorBoardLogger("tb_logs", name="unet_derm_resnet34")

    # --- Trainer ---
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[checkpoint_callback, lr_monitor, early_stop_callback],
        max_epochs=MAX_EPOCHS,
        accelerator=ACCELERATOR,
        devices=DEVICES,
        precision=PRECISION,
        log_every_n_steps=10,
    )

    # --- Find Optimal Learning Rate ---
    print("\nFinding optimal learning rate...")
    tuner = pl.tuner.Tuner(trainer)
    lr_finder_result = tuner.lr_find(model, train_dataloaders=train_loader, val_dataloaders=val_loader, num_training=100)  # Run LR finder for 100 steps

    # lr_find may return None, so only plot and query a suggestion when
    # there is a result to inspect.
    suggested_lr = None
    if lr_finder_result is not None:
        fig = lr_finder_result.plot(suggest=True)
        fig.show()  # Display plot
        suggested_lr = lr_finder_result.suggestion()

    if suggested_lr is not None:
        print(f"Suggested LR: {suggested_lr:.2e}")
        # Keep the hparam (read by configure_optimizers) and the plain
        # attribute in agreement.
        model.hparams.learning_rate = suggested_lr
        model.learning_rate = suggested_lr
        print(f"Using LR: {model.hparams.learning_rate:.2e}")
    else:
        print(f"LR finder did not suggest a rate. Using initial LR: {model.hparams.learning_rate:.2e}")

    # --- Start Training ---
    print("\nStarting training...")
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    print("\nTraining finished.")
    best_model_path = checkpoint_callback.best_model_path
    print(f"Best model saved at: {best_model_path}")

    # --- Save final model state dict separately (optional, sometimes easier for inference) ---
    final_model_path = "unet_derm_final_model.pth"
    if best_model_path:
        # Export the best checkpoint rather than the last-epoch weights.
        best_model = UNetLitModule.load_from_checkpoint(best_model_path)
    else:
        # best_model_path is empty when no checkpoint was written; fall back
        # to the in-memory model instead of failing on an empty path.
        print("No checkpoint was saved; exporting the weights from the last training state instead.")
        best_model = model
    torch.save(best_model.model.state_dict(), final_model_path)
    print(f"Final model state_dict saved to: {final_model_path}")