|
import math |
|
import os |
|
from typing import Any, Callable, Optional, Tuple |
|
from monai import data, transforms as med |
|
from monai.data import load_decathlon_datalist |
|
import PIL.Image as PImage |
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS |
|
from torchvision.transforms import transforms |
|
from torch.utils.data import Dataset |
|
import torch |
|
import numpy as np |
|
import cv2 |
|
try: |
|
from torchvision.transforms import InterpolationMode |
|
interpolation = InterpolationMode.BICUBIC |
|
except: |
|
import PIL |
|
interpolation = PIL.Image.BICUBIC |
|
from monai.transforms.transform import LazyTransform, MapTransform, RandomizableTransform |
|
import random |
|
|
|
|
|
def pil_loader(path):
    """Open the image file at ``path`` and return it as an RGB PIL image.

    The file is opened in binary mode inside a ``with`` block; the RGB
    conversion happens while the handle is still open, and the converted
    image is returned after the handle is closed.
    """
    with open(path, 'rb') as handle:
        rgb_image: PImage.Image = PImage.open(handle).convert('RGB')
    return rgb_image
|
|
|
|
|
class ImageNetDataset(DatasetFolder):
    """Folder-structured image dataset that yields only transformed images.

    Reads ``<imagenet_folder>/train`` (or ``/val`` when ``train`` is False)
    through torchvision's ``DatasetFolder`` and then discards the class
    indices, keeping just the file paths.
    """

    def __init__(
        self,
        imagenet_folder: str,
        train: bool,
        transform: Callable,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        """
        :param imagenet_folder: dataset root containing ``train``/``val`` sub-folders
        :param train: pick the ``train`` split when True, otherwise ``val``
        :param transform: callable applied to every loaded PIL image
        :param is_valid_file: optional file filter; when given, the default
            image-extension filter is not passed to ``DatasetFolder``
        """
        split_dir = os.path.join(imagenet_folder, 'train' if train else 'val')
        # Only one of `extensions` / `is_valid_file` is handed to DatasetFolder.
        extensions = None if is_valid_file is not None else IMG_EXTENSIONS
        super().__init__(
            split_dir,
            loader=pil_loader,
            extensions=extensions,
            transform=transform,
            target_transform=None,
            is_valid_file=is_valid_file,
        )
        # Keep only the file paths; the labels are not used.
        self.samples = tuple(path for path, _ in self.samples)
        self.targets = None

    def __getitem__(self, index: int) -> Any:
        """Load the image at position ``index`` and return ``self.transform`` of it."""
        path = self.samples[index]
        image = self.loader(path)
        return self.transform(image)
|
|
|
|
|
def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset:
    """
    Build the unlabeled ImageNet-style dataset used for pre-training.

    You may need to modify this function to return your own dataset.
    Define a new class, a subclass of `Dataset`, to replace our ImageNetDataset.
    Use dataset_path to build your image file path list.
    Use input_size to create the transformation function for your images; you can
    refer to `trans_train` below.

    :param dataset_path: the dataset root folder (the one containing ``train``);
        a path that already points at the ``train`` or ``val`` split folder is
        accepted and mapped back to its parent
    :param input_size: the input size (image resolution)
    :return: the dataset used for pretraining
    """
    trans_train = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    dataset_path = os.path.abspath(dataset_path)
    # Accept ".../train" or ".../val" as well as the root itself. Compare the
    # whole last path component: a plain `endswith` check would wrongly turn
    # e.g. ".../interval" into ".../inter".
    if os.path.basename(dataset_path) in ('train', 'val'):
        dataset_path = os.path.dirname(dataset_path)

    dataset_train = ImageNetDataset(imagenet_folder=dataset_path, transform=trans_train, train=True)
    print_transform(trans_train, '[pre-train]')
    return dataset_train
|
|
|
|
|
def build_meddataset_to_pretrain(dataset_path, input_size) -> Dataset:
    """
    Build the flat-folder medical-image dataset used for pre-training.

    The images under ``dataset_path`` are served by :class:`MedicalDataSets`
    and passed through the training augmentation built here: random resized
    crop, random horizontal flip, tensor conversion and ImageNet-statistics
    normalisation.

    :param dataset_path: the folder holding the image files
    :param input_size: the input size (image resolution) of the random crop
    :return: the dataset used for pretraining
    """
    augmentations = [
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ]
    trans_train = transforms.Compose(augmentations)

    absolute_root = os.path.abspath(dataset_path)
    dataset_train = MedicalDataSets(base_dir=absolute_root, transform=trans_train)
    print_transform(trans_train, '[pre-train]')
    return dataset_train
|
|
|
|
|
|
|
class MedicalDataSets(Dataset):
    """Flat-folder image dataset: each regular, non-hidden file in ``base_dir`` is one sample.

    Images are opened with PIL, converted to RGB and passed through
    ``transform`` when one is given. No labels are returned.
    """

    def __init__(
        self,
        base_dir=None,
        transform=None,
    ):
        """
        :param base_dir: folder containing the image files (required despite
            the ``None`` default, which is kept for signature compatibility)
        :param transform: optional callable applied to each RGB PIL image
        :raises ValueError: if ``base_dir`` is None
        """
        if base_dir is None:
            raise ValueError("MedicalDataSets requires a base_dir")
        self._base_dir = base_dir
        # Sort so the index -> file mapping is deterministic (os.listdir order
        # is arbitrary), and skip sub-directories and hidden files such as
        # .DS_Store, which PIL cannot open.
        self.sample_list = sorted(
            name
            for name in os.listdir(self._base_dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(self._base_dir, name))
        )
        self.transform = transform
        print("total {}".format(len(self.sample_list)))

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        """Return the RGB image for sample ``idx``, transformed if a transform was set."""
        path = os.path.join(self._base_dir, self.sample_list[idx])
        # Context manager guarantees the underlying file handle is released.
        with PImage.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform is None:
            return img
        return self.transform(img)
|
|
|
def print_transform(transform, s):
    """Print a header labelled ``s``, then every step in ``transform.transforms``, then a separator."""
    report = [f'Transform {s} = ']
    report.extend(str(step) for step in transform.transforms)
    report.append('---------------------------\n')
    print('\n'.join(report))
|
|
|
|
|
class Sampler(torch.utils.data.Sampler):
    """Split dataset indices across ``num_replicas`` ranks.

    Rank ``rank`` receives every ``num_replicas``-th index starting at
    position ``rank``. With ``shuffle=True`` the order is a permutation seeded
    by ``self.epoch`` (see :meth:`set_epoch`), so ranks that use the same epoch
    agree on it. With ``make_even=True`` the index list is padded up to
    ``num_replicas * ceil(len(dataset) / num_replicas)`` so that every rank
    yields the same number of indices.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
        """
        :param dataset: object supporting ``len()``
        :param num_replicas: number of ranks; defaults to the distributed world size
        :param rank: this process's rank; defaults to the distributed rank
        :param shuffle: shuffle indices with an epoch-seeded generator
        :param make_even: pad indices so every rank gets the same count
        :raises RuntimeError: if a default is needed but torch.distributed is unavailable
        """
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.shuffle = shuffle
        self.make_even = make_even
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Number of real (un-padded) indices that fall to this rank.
        self.valid_length = len(range(len(self.dataset))[self.rank : self.total_size : self.num_replicas])
        if not self.make_even:
            # Without padding this rank only ever yields `valid_length` indices;
            # report that from the start so len(sampler) / len(DataLoader) are
            # right before the first __iter__, not only after it.
            self.num_samples = self.valid_length

    def __iter__(self):
        # One epoch-seeded generator drives both shuffling and padding so the
        # produced index list is reproducible for a given epoch.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        if self.make_even:
            shortfall = self.total_size - len(indices)
            if shortfall > 0:
                if shortfall < len(indices):
                    # Enough items: pad by repeating the head of the list.
                    indices += indices[:shortfall]
                else:
                    # Too few items to pad from the head alone: draw with replacement.
                    extra_ids = torch.randint(len(indices), (shortfall,), generator=g).tolist()
                    indices += [indices[i] for i in extra_ids]
            assert len(indices) == self.total_size
        indices = indices[self.rank : self.total_size : self.num_replicas]
        self.num_samples = len(indices)
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the seed used for shuffling/padding in the next ``__iter__``."""
        self.epoch = epoch
|
|
|
class RandScaleCropdPlusScaleByMidDimSampled(MapTransform):
    """Draw ``num_samples`` random scaled crops from one input dict and rescale each.

    For every sample a crop ratio is drawn uniformly from
    ``[min_radio, max_radio]`` (rounded to 2 decimals). A single
    ``RandScaleCropd`` with that ``roi_scale`` is applied jointly to all
    ``keys``, so the keys share the same crop window. Each cropped array is
    then resized with ``Resized`` so that its middle-sized spatial dimension
    becomes ``max_size``, preserving the aspect ratio.

    ``__call__`` returns a list of ``num_samples`` dictionaries.
    """

    def __init__(self, keys, mode='area', max_size=128, allow_missing_keys=False, num_samples=4, max_radio=0.8, min_radio=0.5):
        """
        :param keys: key or keys of the arrays to crop and resize
        :param mode: interpolation mode forwarded to ``Resized``
        :param max_size: target length of the middle-sized spatial dimension
        :param allow_missing_keys: skip keys absent from the input instead of raising
        :param num_samples: number of crops produced per call
        :param max_radio: upper bound of the random crop ratio
        :param min_radio: lower bound of the random crop ratio
        """
        # MapTransform normalises `keys` to a tuple (so a single string key is
        # not iterated character by character) and stores allow_missing_keys.
        super().__init__(keys, allow_missing_keys)
        self.mode = mode
        self.max_size = max_size
        self.num_samples = num_samples
        self.max_radio = max_radio
        self.min_radio = min_radio

    def __call__(self, data):
        outputs = []
        for _ in range(self.num_samples):
            ratio = round(random.uniform(self.min_radio, self.max_radio), 2)
            sample = dict(data)
            # One cropper for all keys -> image/label crops stay aligned.
            cropper = med.RandScaleCropd(keys=self.keys, roi_scale=ratio,
                                         allow_missing_keys=self.allow_missing_keys)
            cropped = cropper(sample)
            for key in self.key_iterator(sample):
                patch = cropped[key]
                spatial_shape = patch.shape[1:]
                # Scale so the middle-sized spatial dimension becomes max_size.
                scale_factor = self.max_size / sorted(spatial_shape)[1]
                # Clamp to 1 so tiny dimensions never produce a zero-size target.
                new_size = [max(1, int(d * scale_factor)) for d in spatial_shape]
                sample[key] = patch
                resizer = med.Resized(keys=[key],
                                      spatial_size=new_size,
                                      mode=self.mode,
                                      allow_missing_keys=self.allow_missing_keys)
                sample[key] = resizer(sample)[key]
            outputs.append(sample)
        return outputs
|
|
|
|
|
|
|
|
|
def get_loader(data_dir, size, cache_dir="/fenghetang/3d/pretrain/MM/cache_dataset", cache_n_trans=6, num_samples=4):
    """Build the cached 3-D pre-training dataset listed in ``<data_dir>/dataset.json``.

    The "training" entries of the Decathlon-style datalist are loaded, given a
    channel dimension, reoriented to RAS and resampled to spacing 1.5 per axis
    (presumably mm). Intensities in [-175, 250] are then rescaled to [0, 1]
    with clipping. Each volume is cropped to its foreground, padded to at
    least ``size`` per axis, cut into ``num_samples`` random ``size``^3
    patches, and randomly flipped.

    :param data_dir: folder containing ``dataset.json``; also passed as
        ``base_dir`` to ``load_decathlon_datalist``
    :param size: edge length of the cubic training patches
    :param cache_dir: directory where ``CacheNTransDataset`` stores cached
        results (defaults to the previously hard-coded path)
    :param cache_n_trans: number of leading transforms whose output is cached;
        the default 6 covers LoadImaged through CropForegroundd
    :param num_samples: patches drawn per volume by ``RandCropByPosNegLabeld``
    :return: the training dataset
    """
    datalist_json = os.path.join(data_dir, "dataset.json")
    train_transform = med.Compose(
        [
            med.LoadImaged(keys=["image"], allow_missing_keys=True),
            med.AddChanneld(keys=["image"], allow_missing_keys=True),
            med.Orientationd(keys=["image"], axcodes="RAS", allow_missing_keys=True),
            med.Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear", allow_missing_keys=True),
            med.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            med.CropForegroundd(keys=["image"], source_key="image", allow_missing_keys=True),
            med.SpatialPadd(keys=["image"], spatial_size=(size, size, size), mode='constant'),
            # The image itself serves as the "label": crop centres are
            # presumably drawn from its non-zero (foreground) voxels only,
            # since pos=1 / neg=0.
            med.RandCropByPosNegLabeld(
                spatial_size=(size, size, size),
                keys=["image"],
                label_key="image",
                pos=1,
                neg=0,
                num_samples=num_samples,
            ),
            med.RandFlipd(keys=["image"], prob=0.2, spatial_axis=0),
            med.RandFlipd(keys=["image"], prob=0.2, spatial_axis=1),
            med.RandFlipd(keys=["image"], prob=0.1, spatial_axis=2),
            med.ToTensord(keys=["image"]),
        ])

    datalist = load_decathlon_datalist(datalist_json, True, "training", base_dir=data_dir)

    train_ds = data.CacheNTransDataset(data=datalist, transform=train_transform,
                                       cache_n_trans=cache_n_trans, cache_dir=cache_dir)
    return train_ds
|
|
|
|
|
|
|
|