| """Functions for training and running segmentation.""" | |
| import math | |
| import os | |
| import time | |
| import click | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import scipy.signal | |
| import skimage.draw | |
| import torch | |
| import torchvision | |
| import tqdm | |
| import echonet | |


def run(
    data_dir=None,
    output=None,
    model_name="deeplabv3_resnet50",
    pretrained=False,
    weights=None,
    run_test=False,
    save_video=False,
    num_epochs=50,
    lr=1e-5,
    weight_decay=1e-5,
    lr_step_period=None,
    num_train_patients=None,
    num_workers=4,
    batch_size=20,
    device=None,
    seed=0,
):
| """Trains/tests segmentation model. | |
| Args: | |
| data_dir (str, optional): Directory containing dataset. Defaults to | |
| `echonet.config.DATA_DIR`. | |
| output (str, optional): Directory to place outputs. Defaults to | |
| output/segmentation/<model_name>_<pretrained/random>/. | |
| model_name (str, optional): Name of segmentation model. One of ``deeplabv3_resnet50'', | |
| ``deeplabv3_resnet101'', ``fcn_resnet50'', or ``fcn_resnet101'' | |
| (options are torchvision.models.segmentation.<model_name>) | |
| Defaults to ``deeplabv3_resnet50''. | |
| pretrained (bool, optional): Whether to use pretrained weights for model | |
| Defaults to False. | |
| weights (str, optional): Path to checkpoint containing weights to | |
| initialize model. Defaults to None. | |
| run_test (bool, optional): Whether or not to run on test. | |
| Defaults to False. | |
| save_video (bool, optional): Whether to save videos with segmentations. | |
| Defaults to False. | |
| num_epochs (int, optional): Number of epochs during training | |
| Defaults to 50. | |
| lr (float, optional): Learning rate for SGD | |
| Defaults to 1e-5. | |
| weight_decay (float, optional): Weight decay for SGD | |
| Defaults to 0. | |
| lr_step_period (int or None, optional): Period of learning rate decay | |
| (learning rate is decayed by a multiplicative factor of 0.1) | |
| Defaults to math.inf (never decay learning rate). | |
| num_train_patients (int or None, optional): Number of training patients | |
| for ablations. Defaults to all patients. | |
| num_workers (int, optional): Number of subprocesses to use for data | |
| loading. If 0, the data will be loaded in the main process. | |
| Defaults to 4. | |
| device (str or None, optional): Name of device to run on. Options from | |
| https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device | |
| Defaults to ``cuda'' if available, and ``cpu'' otherwise. | |
| batch_size (int, optional): Number of samples to load per batch | |
| Defaults to 20. | |
| seed (int, optional): Seed for random number generator. Defaults to 0. | |
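
    Example:
        A minimal sketch of a typical invocation (assumes the EchoNet-Dynamic
        dataset is available at ``echonet.config.DATA_DIR''):

            import echonet
            echonet.utils.segmentation.run(num_epochs=50, run_test=True, save_video=True)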
| """ | |

    # Seed RNGs
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Set default output directory
    if output is None:
        output = os.path.join("output", "segmentation", "{}_{}".format(model_name, "pretrained" if pretrained else "random"))
    os.makedirs(output, exist_ok=True)

    # Set device for computations
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set up model
    model = torchvision.models.segmentation.__dict__[model_name](pretrained=pretrained, aux_loss=False)
    model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)  # change number of outputs to 1
    if device.type == "cuda":
        model = torch.nn.DataParallel(model)
    model.to(device)

    if weights is not None:
        checkpoint = torch.load(weights)
        model.load_state_dict(checkpoint['state_dict'])

    # Set up optimizer
    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    if lr_step_period is None:
        lr_step_period = math.inf
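    # With a step size of math.inf, StepLR never reaches a decay step, so the
    # learning rate stays constant for the entire run.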
    scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)

    # Compute mean and std
    mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train"))
    tasks = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"]
    kwargs = {"target_type": tasks,
              "mean": mean,
              "std": std
              }
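    # The Echo dataset normalizes each frame channel-wise with these
    # statistics; a sketch of the transform, for a frame of shape
    # (channels=3, height, width):
    #
    #     frame = (frame - mean.reshape(3, 1, 1)) / std.reshape(3, 1, 1)
    #
    # The same mean/std are used below to un-normalize frames before videos
    # are saved.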

    # Set up datasets and dataloaders
    dataset = {}
    dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs)
    if num_train_patients is not None and len(dataset["train"]) > num_train_patients:
        # Subsample patients (used for ablation experiment)
        indices = np.random.choice(len(dataset["train"]), num_train_patients, replace=False)
        dataset["train"] = torch.utils.data.Subset(dataset["train"], indices)
    dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs)

    # Run training and testing loops
    with open(os.path.join(output, "log.csv"), "a") as f:
        epoch_resume = 0
        bestLoss = float("inf")
        try:
            # Attempt to load checkpoint
            checkpoint = torch.load(os.path.join(output, "checkpoint.pt"))
            model.load_state_dict(checkpoint['state_dict'])
            optim.load_state_dict(checkpoint['opt_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_dict'])
            epoch_resume = checkpoint["epoch"] + 1
            bestLoss = checkpoint["best_loss"]
            f.write("Resuming from epoch {}\n".format(epoch_resume))
        except FileNotFoundError:
            f.write("Starting run from scratch\n")

        for epoch in range(epoch_resume, num_epochs):
            print("Epoch #{}".format(epoch), flush=True)
            for phase in ['train', 'val']:
                start_time = time.time()
                for i in range(torch.cuda.device_count()):
                    torch.cuda.reset_peak_memory_stats(i)

                ds = dataset[phase]
                dataloader = torch.utils.data.DataLoader(
                    ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))

                loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model, dataloader, phase == "train", optim, device)
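                # Dice from the accumulated pixel counts: with I pixels in both
                # masks (intersection) and U pixels in either (union),
                # |pred| + |label| = U + I, so
                #     Dice = 2*I / (|pred| + |label|) = 2 * inter / (union + inter)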
                overall_dice = 2 * (large_inter.sum() + small_inter.sum()) / (large_union.sum() + large_inter.sum() + small_union.sum() + small_inter.sum())
                large_dice = 2 * large_inter.sum() / (large_union.sum() + large_inter.sum())
                small_dice = 2 * small_inter.sum() / (small_union.sum() + small_inter.sum())
| f.write("{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch, | |
| phase, | |
| loss, | |
| overall_dice, | |
| large_dice, | |
| small_dice, | |
| time.time() - start_time, | |
| large_inter.size, | |
| sum(torch.cuda.max_memory_allocated() for i in range(torch.cuda.device_count())), | |
| sum(torch.cuda.max_memory_reserved() for i in range(torch.cuda.device_count())), | |
| batch_size)) | |
                f.flush()
            scheduler.step()

            # Save checkpoint
            save = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_loss': bestLoss,
                'loss': loss,
                'opt_dict': optim.state_dict(),
                'scheduler_dict': scheduler.state_dict(),
            }
            torch.save(save, os.path.join(output, "checkpoint.pt"))
            if loss < bestLoss:
                torch.save(save, os.path.join(output, "best.pt"))
                bestLoss = loss

        # Load best weights
        if num_epochs != 0:
            checkpoint = torch.load(os.path.join(output, "best.pt"))
            model.load_state_dict(checkpoint['state_dict'])
            f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"]))

        if run_test:
            # Run on validation and test
            for split in ["val", "test"]:
                dataset = echonet.datasets.Echo(root=data_dir, split=split, **kwargs)
                dataloader = torch.utils.data.DataLoader(dataset,
                                                         batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"))
                loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model, dataloader, False, None, device)

                overall_dice = 2 * (large_inter + small_inter) / (large_union + large_inter + small_union + small_inter)
                large_dice = 2 * large_inter / (large_union + large_inter)
                small_dice = 2 * small_inter / (small_union + small_inter)
                with open(os.path.join(output, "{}_dice.csv".format(split)), "w") as g:
                    g.write("Filename, Overall, Large, Small\n")
                    for (filename, overall, large, small) in zip(dataset.fnames, overall_dice, large_dice, small_dice):
                        g.write("{},{},{},{}\n".format(filename, overall, large, small))
                f.write("{} dice (overall): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(np.concatenate((large_inter, small_inter)), np.concatenate((large_union, small_union)), echonet.utils.dice_similarity_coefficient)))
                f.write("{} dice (large):   {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(large_inter, large_union, echonet.utils.dice_similarity_coefficient)))
                f.write("{} dice (small):   {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(small_inter, small_union, echonet.utils.dice_similarity_coefficient)))
                f.flush()

            # Saving videos with segmentations
            dataset = echonet.datasets.Echo(root=data_dir, split="test",
                                            target_type=["Filename", "LargeIndex", "SmallIndex"],  # Need filename for saving, and human-selected frames to annotate
                                            mean=mean, std=std,  # Normalization
                                            length=None, max_length=None, period=1  # Take all frames
                                            )
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=num_workers, shuffle=False, pin_memory=False, collate_fn=_video_collate_fn)

            # Save videos with segmentation
            if save_video and not all(os.path.isfile(os.path.join(output, "videos", f)) for f in dataloader.dataset.fnames):
                # Only run if missing videos
                model.eval()

                os.makedirs(os.path.join(output, "videos"), exist_ok=True)
                os.makedirs(os.path.join(output, "size"), exist_ok=True)
                echonet.utils.latexify()

                with torch.no_grad():
                    with open(os.path.join(output, "size.csv"), "w") as g:
                        g.write("Filename,Frame,Size,HumanLarge,HumanSmall,ComputerSmall\n")
                        for (x, (filenames, large_index, small_index), length) in tqdm.tqdm(dataloader):
                            # Run segmentation model on blocks of frames one-by-one
                            # The whole concatenated video may be too long to run together
                            y = np.concatenate([model(x[i:(i + batch_size), :, :, :].to(device))["out"].detach().cpu().numpy() for i in range(0, x.shape[0], batch_size)])

                            start = 0
                            x = x.numpy()
                            for (i, (filename, offset)) in enumerate(zip(filenames, length)):
                                # Extract one video and segmentation predictions
                                video = x[start:(start + offset), ...]
                                logit = y[start:(start + offset), 0, :, :]

                                # Un-normalize video
                                video *= std.reshape(1, 3, 1, 1)
                                video += mean.reshape(1, 3, 1, 1)

                                # Get frames, channels, height, and width
                                f, c, h, w = video.shape  # pylint: disable=W0612
                                assert c == 3

                                # Put two copies of the video side by side
                                video = np.concatenate((video, video), 3)

                                # If a pixel is in the segmentation, saturate blue channel
                                # Leave alone otherwise
                                video[:, 0, :, w:] = np.maximum(255. * (logit > 0), video[:, 0, :, w:])  # pylint: disable=E1111

                                # Add blank canvas under pair of videos
                                video = np.concatenate((video, np.zeros_like(video)), 2)

                                # Compute size of segmentation per frame
                                size = (logit > 0).sum((1, 2))

                                # Identify systole frames with peak detection
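                                # The area trace oscillates once per beat, and its local
                                # minima correspond to end-systole. find_peaks on the negated
                                # trace locates those minima, requiring candidates to be at
                                # least 20 frames apart with prominence of at least half the
                                # percentile-trimmed range of sizes.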
                                trim_min = sorted(size)[round(len(size) * 0.05)]  # 5th percentile of sizes
                                trim_max = sorted(size)[round(len(size) * 0.95)]  # 95th percentile of sizes
                                trim_range = trim_max - trim_min
                                systole = set(scipy.signal.find_peaks(-size, distance=20, prominence=(0.50 * trim_range))[0])

                                # Write sizes and frames to file
                                for (frame, s) in enumerate(size):
                                    g.write("{},{},{},{},{},{}\n".format(filename, frame, s, 1 if frame == large_index[i] else 0, 1 if frame == small_index[i] else 0, 1 if frame in systole else 0))

                                # Plot sizes
                                fig = plt.figure(figsize=(size.shape[0] / 50 * 1.5, 3))
                                plt.scatter(np.arange(size.shape[0]) / 50, size, s=1)
                                ylim = plt.ylim()
                                for s in systole:
                                    plt.plot(np.array([s, s]) / 50, ylim, linewidth=1)
                                plt.ylim(ylim)
                                plt.title(os.path.splitext(filename)[0])
                                plt.xlabel("Seconds")
                                plt.ylabel("Size (pixels)")
                                plt.tight_layout()
                                plt.savefig(os.path.join(output, "size", os.path.splitext(filename)[0] + ".pdf"))
                                plt.close(fig)

                                # Normalize size to [0, 1]
                                size -= size.min()
                                size = size / size.max()
                                size = 1 - size

                                # Iterate the frames in this video
                                for (f, s) in enumerate(size):
                                    # On all frames, mark a pixel for the size of the frame
                                    video[:, :, int(round(115 + 100 * s)), int(round(f / len(size) * 200 + 10))] = 255.

                                    if f in systole:
                                        # If frame is computer-selected systole, mark with a line
                                        video[:, :, 115:224, int(round(f / len(size) * 200 + 10))] = 255.

                                    def dash(start, stop, on=10, off=10):
                                        buf = []
                                        x = start
                                        while x < stop:
                                            buf.extend(range(x, x + on))
                                            x += on
                                            x += off
                                        buf = np.array(buf)
                                        buf = buf[buf < stop]
                                        return buf
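                                    # dash(a, b) returns row indices for a dashed vertical line
                                    # (10 pixels on, 10 off, clipped below b); e.g. dash(115, 224)
                                    # yields 115..124, 135..144, ..., 215..223.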
                                    d = dash(115, 224)

                                    if f == large_index[i]:
                                        # If frame is human-selected diastole, mark with green dashed line on all frames
                                        video[:, :, d, int(round(f / len(size) * 200 + 10))] = np.array([0, 225, 0]).reshape((1, 3, 1))
                                    if f == small_index[i]:
                                        # If frame is human-selected systole, mark with red dashed line on all frames
                                        video[:, :, d, int(round(f / len(size) * 200 + 10))] = np.array([0, 0, 225]).reshape((1, 3, 1))

                                    # Get pixels for a circle centered on the pixel
                                    r, c = skimage.draw.disk((int(round(115 + 100 * s)), int(round(f / len(size) * 200 + 10))), 4.1)

                                    # On the frame that's being shown, put a circle over the pixel
                                    video[f, :, r, c] = 255.

                                # Rearrange dimensions and save
                                video = video.transpose(1, 0, 2, 3)
                                video = video.astype(np.uint8)
                                echonet.utils.savevideo(os.path.join(output, "videos", filename), video, 50)

                                # Move to next video
                                start += offset


def run_epoch(model, dataloader, train, optim, device):
    """Run one epoch of training/evaluation for segmentation.

    Args:
        model (torch.nn.Module): Model to train/evaluate.
        dataloader (torch.utils.data.DataLoader): Dataloader for dataset.
        train (bool): Whether or not to train model.
        optim (torch.optim.Optimizer): Optimizer.
        device (torch.device): Device to run on.
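
    Returns:
        Tuple of (loss, large_inter, large_union, small_inter, small_union):
        the average pixel-wise BCE loss, and numpy arrays with the per-video
        intersection and union pixel counts for the diastolic (large) and
        systolic (small) frames.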
| """ | |

    total = 0.
    n = 0

    pos = 0
    neg = 0
    pos_pix = 0
    neg_pix = 0

    model.train(train)

    large_inter = 0
    large_union = 0
    small_inter = 0
    small_union = 0
    large_inter_list = []
    large_union_list = []
    small_inter_list = []
    small_union_list = []

    with torch.set_grad_enabled(train):
        with tqdm.tqdm(total=len(dataloader)) as pbar:
            for (_, (large_frame, small_frame, large_trace, small_trace)) in dataloader:
                # Count number of pixels in/out of human segmentation
                pos += (large_trace == 1).sum().item()
                pos += (small_trace == 1).sum().item()
                neg += (large_trace == 0).sum().item()
                neg += (small_trace == 0).sum().item()

                # Count number of pixels in/out of human segmentation at each
                # pixel location (for the pixel-wise baseline)
                pos_pix += (large_trace == 1).sum(0).to("cpu").detach().numpy()
                pos_pix += (small_trace == 1).sum(0).to("cpu").detach().numpy()
                neg_pix += (large_trace == 0).sum(0).to("cpu").detach().numpy()
                neg_pix += (small_trace == 0).sum(0).to("cpu").detach().numpy()

                # Run prediction for diastolic frames and compute loss
                large_frame = large_frame.to(device)
                large_trace = large_trace.to(device)
                y_large = model(large_frame)["out"]
                loss_large = torch.nn.functional.binary_cross_entropy_with_logits(y_large[:, 0, :, :], large_trace, reduction="sum")
                # Compute pixel intersection and union between human and computer segmentations
                large_inter += np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                large_union += np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                large_inter_list.extend(np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
                large_union_list.extend(np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))

                # Run prediction for systolic frames and compute loss
                small_frame = small_frame.to(device)
                small_trace = small_trace.to(device)
                y_small = model(small_frame)["out"]
                loss_small = torch.nn.functional.binary_cross_entropy_with_logits(y_small[:, 0, :, :], small_trace, reduction="sum")
                # Compute pixel intersection and union between human and computer segmentations
                small_inter += np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                small_union += np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                small_inter_list.extend(np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
                small_union_list.extend(np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))

                # Take gradient step if training
                loss = (loss_large + loss_small) / 2
                if train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

                # Accumulate losses and compute baselines
                total += loss.item()
                n += large_trace.size(0)
                p = pos / (pos + neg)
                p_pix = (pos_pix + 1) / (pos_pix + neg_pix + 2)
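
                # Baselines for the displayed loss: a constant predictor that
                # outputs the marginal foreground probability p for every pixel
                # has expected per-pixel BCE equal to the binary entropy
                # -p*log(p) - (1-p)*log(1-p); p_pix gives the same baseline per
                # pixel location, with add-one (Laplace) smoothing to avoid
                # log(0).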
                # Show info on progress bar
                pbar.set_postfix_str("{:.4f} ({:.4f}) / {:.4f} {:.4f}, {:.4f}, {:.4f}".format(total / n / 112 / 112, loss.item() / large_trace.size(0) / 112 / 112, -p * math.log(p) - (1 - p) * math.log(1 - p), (-p_pix * np.log(p_pix) - (1 - p_pix) * np.log(1 - p_pix)).mean(), 2 * large_inter / (large_union + large_inter), 2 * small_inter / (small_union + small_inter)))
                pbar.update()

    large_inter_list = np.array(large_inter_list)
    large_union_list = np.array(large_union_list)
    small_inter_list = np.array(small_inter_list)
    small_union_list = np.array(small_union_list)

    return (total / n / 112 / 112,
            large_inter_list,
            large_union_list,
            small_inter_list,
            small_union_list,
            )


def _video_collate_fn(x):
| """Collate function for Pytorch dataloader to merge multiple videos. | |
| This function should be used in a dataloader for a dataset that returns | |
| a video as the first element, along with some (non-zero) tuple of | |
| targets. Then, the input x is a list of tuples: | |
| - x[i][0] is the i-th video in the batch | |
| - x[i][1] are the targets for the i-th video | |
| This function returns a 3-tuple: | |
| - The first element is the videos concatenated along the frames | |
| dimension. This is done so that videos of different lengths can be | |
| processed together (tensors cannot be "jagged", so we cannot have | |
| a dimension for video, and another for frames). | |
| - The second element is contains the targets with no modification. | |
| - The third element is a list of the lengths of the videos in frames. | |
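
    Example:
        A minimal sketch (shapes are illustrative):

        >>> a = torch.zeros(3, 30, 112, 112)  # 30-frame video
        >>> b = torch.zeros(3, 50, 112, 112)  # 50-frame video
        >>> video, target, length = _video_collate_fn([(a, ("a.avi",)), (b, ("b.avi",))])
        >>> video.shape  # 80 total frames, frames dimension first
        torch.Size([80, 3, 112, 112])
        >>> length
        [30, 50]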
| """ | |
    video, target = zip(*x)  # Extract the videos and targets

    # ``video'' is a tuple of length ``batch_size''
    # Each element has shape (channels=3, frames, height, width)
    # height and width are expected to be the same across videos, but
    # frames can be different.

    # ``target'' is also a tuple of length ``batch_size''
    # Each element is a tuple of the targets for the item.

    i = list(map(lambda t: t.shape[1], video))  # Extract lengths of videos in frames

    # This concatenates the videos along the frames dimension (basically
    # playing the videos one after another). The frames dimension is then
    # moved to be first.
    # Resulting shape is (total frames, channels=3, height, width)
    video = torch.as_tensor(np.swapaxes(np.concatenate(video, 1), 0, 1))

    # Swap dimensions (approximately a transpose)
    # Before: target[i][j] is the j-th target of element i
    # After:  target[i][j] is the i-th target of element j
    target = zip(*target)

    return video, target, i