# MambaMIM / utils / med_dataset.py
# Uploaded by FengheTan9 using huggingface_hub (commit 6da2a44, verified).
import math
import os
from typing import Any, Callable, Optional, Tuple
from monai import data, transforms as med
from monai.data import load_decathlon_datalist
import PIL.Image as PImage
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
from torchvision.transforms import transforms
from torch.utils.data import Dataset
import torch
import numpy as np
import cv2
# Pick the bicubic interpolation constant used by the 2-D augmentation
# pipelines below. torchvision's InterpolationMode enum is preferred; when it
# cannot be imported (presumably on older torchvision releases — confirm) the
# equivalent PIL constant is used instead.
try:
    from torchvision.transforms import InterpolationMode
    interpolation = InterpolationMode.BICUBIC
except ImportError:
    # Import the submodule explicitly: a bare `import PIL` does not by itself
    # guarantee that `PIL.Image` is loaded.
    import PIL.Image
    interpolation = PIL.Image.BICUBIC
from monai.transforms.transform import LazyTransform, MapTransform, RandomizableTransform
import random
def pil_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    The file is opened explicitly and handed to PIL as a file object so the
    handle is released when the ``with`` block exits, which avoids a
    ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).

    :param path: filesystem path of the image file
    :return: the decoded image in 'RGB' mode
    """
    with open(path, 'rb') as stream:
        decoded = PImage.open(stream)
        # convert() must run while the stream is still open, because it is the
        # call that actually reads the pixel data.
        rgb: PImage.Image = decoded.convert('RGB')
    return rgb
class ImageNetDataset(DatasetFolder):
    """Label-free image-folder dataset for self-supervised pre-training.

    Images are read from ``<imagenet_folder>/train`` or ``<imagenet_folder>/val``
    depending on ``train``. Indexing returns only the transformed image; the
    class labels discovered by ``DatasetFolder`` are discarded.
    """

    def __init__(
            self,
            imagenet_folder: str,
            train: bool,
            transform: Callable,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        """
        :param imagenet_folder: root folder that contains the 'train' and 'val' sub-folders
        :param train: use the 'train' split when True, otherwise 'val'
        :param transform: callable applied to every loaded image
        :param is_valid_file: optional predicate on file paths; when given, the
            default extension filter is not passed to ``DatasetFolder``
        """
        split = 'train' if train else 'val'
        split_root = os.path.join(imagenet_folder, split)
        # The extension whitelist is only supplied when no custom predicate is given.
        allowed_extensions = None if is_valid_file is not None else IMG_EXTENSIONS
        super().__init__(
            split_root,
            loader=pil_loader,
            extensions=allowed_extensions,
            transform=transform,
            target_transform=None,
            is_valid_file=is_valid_file,
        )

        # DatasetFolder stores (path, class_index) pairs; keep the paths only.
        self.samples = tuple(path for path, _ in self.samples)
        # this is self-supervised learning so we don't need labels
        self.targets = None

    def __getitem__(self, index: int) -> Any:
        """Return the transformed image at position ``index`` (no label)."""
        image = self.loader(self.samples[index])
        return self.transform(image)
def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset:
    """
    Build the ImageNet-style training dataset used for pre-training.

    You may need to modify this function to return your own dataset.
    Define a new class, a subclass of `Dataset`, to replace our ImageNetDataset.
    Use dataset_path to build your image file path list.
    Use input_size to create the transformation function for your images, can refer to the `trans_train` below.

    :param dataset_path: the folder of dataset; it may also point directly at
        the dataset's 'train' or 'val' sub-folder, in which case the parent
        folder is used as the dataset root
    :param input_size: the input size (image resolution)
    :return: the dataset used for pretraining
    """
    trans_train = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    dataset_path = os.path.abspath(dataset_path)
    # Strip a trailing 'train'/'val' path *component* only. A plain string
    # suffix test would also truncate roots such as '/data/brain_train' or
    # '/data/trainval', silently pointing the dataset at the wrong folder.
    if os.path.basename(dataset_path) in ('train', 'val'):
        dataset_path = os.path.dirname(dataset_path)

    # ImageNetDataset re-appends 'train' itself because train=True.
    dataset_train = ImageNetDataset(imagenet_folder=dataset_path, transform=trans_train, train=True)
    print_transform(trans_train, '[pre-train]')
    return dataset_train
def build_meddataset_to_pretrain(dataset_path, input_size) -> Dataset:
    """
    Build the flat-folder medical image dataset used for pre-training.

    Every file directly inside ``dataset_path`` is served through
    `MedicalDataSets` with the training augmentation defined here: a random
    resized crop to ``input_size`` (area scale 0.67-1.0), a random horizontal
    flip, tensor conversion and normalisation with the ImageNet mean/std.
    Replace `MedicalDataSets` or the augmentation list to use your own data.

    :param dataset_path: the folder of dataset
    :param input_size: the input size (image resolution)
    :return: the dataset used for pretraining
    """
    augmentations = [
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ]
    pipeline = transforms.Compose(augmentations)

    root = os.path.abspath(dataset_path)
    dataset = MedicalDataSets(base_dir=root, transform=pipeline)

    print_transform(pipeline, '[pre-train]')
    return dataset
class MedicalDataSets(Dataset):
    """Unlabelled 2-D image dataset backed by a single flat directory.

    Every entry returned by ``os.listdir(base_dir)`` is treated as one image
    file. Indexing opens the file, converts it to RGB and applies ``transform``.
    """

    def __init__(
            self,
            base_dir=None,
            transform=None,
    ):
        """
        :param base_dir: directory whose entries are the image files
        :param transform: optional callable applied to each RGB image; when
            None the PIL image is returned unchanged
        """
        self._base_dir = base_dir
        # os.listdir() returns entries in arbitrary order. Sorting keeps the
        # index -> file mapping identical across runs and across processes,
        # which index-based (e.g. distributed) samplers rely on.
        self.sample_list = sorted(os.listdir(self._base_dir))
        self.transform = transform

        print("total {}".format(len(self.sample_list)))

    def __len__(self):
        """Return the number of files found in the directory."""
        return len(self.sample_list)

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image at position ``idx``."""
        case = self.sample_list[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by convert(), instead of waiting for garbage
        # collection.
        with PImage.open(os.path.join(self._base_dir, case)) as raw:
            img = raw.convert('RGB')
        if self.transform is None:
            return img
        return self.transform(img)
def print_transform(transform, s):
    """Print a readable listing of the steps of a composed transform.

    :param transform: object exposing a ``transforms`` iterable (e.g. a Compose)
    :param s: tag inserted into the header line, such as '[pre-train]'
    """
    header = f'Transform {s} = '
    footer = '---------------------------\n'
    # One line per step, framed by the header and the footer; print() appends
    # the final newline, so the footer is followed by a blank line.
    lines = [header, *(str(step) for step in transform.transforms), footer]
    print('\n'.join(lines))
class Sampler(torch.utils.data.Sampler):
    """Sampler that gives each distributed rank a strided share of the dataset.

    Rank ``r`` of ``num_replicas`` receives positions ``r, r + num_replicas, ...``
    of the (optionally shuffled) index list. With ``make_even`` the index list
    is first padded up to a multiple of ``num_replicas`` so that every rank
    yields the same number of samples.

    :param dataset: sized dataset to draw indices for
    :param num_replicas: number of participating processes; defaults to the
        torch.distributed world size
    :param rank: index of this process; defaults to the torch.distributed rank
    :param shuffle: shuffle with a generator seeded by the current epoch
    :param make_even: pad with repeated indices so all ranks get equal length
    :raises RuntimeError: if a default is needed but torch.distributed is unavailable
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.shuffle = shuffle
        self.make_even = make_even
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Number of real (un-padded) samples this rank owns. Slicing a range
        # gives the same length as slicing a materialised index list.
        self.valid_length = len(range(len(self.dataset))[self.rank: self.total_size: self.num_replicas])
        if not self.make_even:
            # Without padding this rank yields exactly `valid_length` items, so
            # report that from __len__ right away rather than only after the
            # first __iter__ call.
            self.num_samples = self.valid_length

    def __iter__(self):
        if self.shuffle:
            # Seeding with the epoch makes every rank draw the same permutation.
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        if self.make_even:
            missing = self.total_size - len(indices)
            if missing > 0:
                if missing < len(indices):
                    # Usual case: wrap around and repeat the leading indices.
                    indices += indices[:missing]
                else:
                    # More padding needed than there are samples: draw the
                    # extras at random (unseeded NumPy RNG) with replacement.
                    extra_ids = np.random.randint(low=0, high=len(indices), size=missing)
                    indices += [indices[ids] for ids in extra_ids]
            assert len(indices) == self.total_size
        indices = indices[self.rank: self.total_size: self.num_replicas]
        self.num_samples = len(indices)
        return iter(indices)

    def __len__(self):
        """Return the number of indices this rank yields per epoch."""
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the shuffle in the next ``__iter__``."""
        self.epoch = epoch
class RandScaleCropdPlusScaleByMidDimSampled(MapTransform):
    """Produce several randomly scale-cropped and then rescaled copies of a sample.

    For each of ``num_samples`` outputs a crop ratio is drawn uniformly from
    ``[min_radio, max_radio]`` (rounded to two decimals). Every selected key is
    cropped with ``RandScaleCropd`` at that ratio and then resized so that the
    second-smallest spatial dimension (the middle one for 3-D volumes) equals
    ``max_size``; the other dimensions are scaled by the same factor.

    :param keys: key or sequence of keys of the entries to transform
    :param mode: interpolation mode forwarded to ``Resized``
    :param max_size: target length of the second-smallest spatial dimension
    :param allow_missing_keys: skip keys that are absent from the input dict
        instead of raising
    :param num_samples: number of output dictionaries per input
    :param max_radio: upper bound of the crop ratio (the name is kept for
        backward compatibility; "ratio" is meant)
    :param min_radio: lower bound of the crop ratio
    """

    def __init__(self, keys, mode='area', max_size=128, allow_missing_keys=False, num_samples=4, max_radio=0.8, min_radio=0.5):
        # Let MapTransform normalise `keys` (a single string becomes a one-item
        # tuple) and store `allow_missing_keys`. Assigning `self.keys = keys`
        # directly made a plain string key iterate character by character.
        super().__init__(keys, allow_missing_keys)
        self.mode = mode
        self.max_size = max_size
        self.num_samples = num_samples
        self.max_radio = max_radio
        self.min_radio = min_radio

    def __call__(self, data):
        """Return a list of ``num_samples`` transformed shallow copies of ``data``."""
        outputs = []
        for _ in range(self.num_samples):
            # One ratio per output sample, shared by all keys of that sample.
            roi_scale = round(random.uniform(self.min_radio, self.max_radio), 2)
            _data = dict(data)
            # key_iterator honours allow_missing_keys: absent keys are skipped
            # when allowed and raise KeyError otherwise.
            for key in self.key_iterator(_data):
                # NOTE(review): each key gets its own cropper, so crop positions
                # are drawn independently per key — confirm this is intended
                # if more than one key is ever passed.
                cropper = med.RandScaleCropd(keys=[key], roi_scale=roi_scale)
                _data[key] = cropper(_data)[key]
                ct_tensor = _data[key]
                # assumes channel-first layout (dim 0 is the channel) — TODO confirm
                spatial_shape = ct_tensor.shape[1:]
                scale_factor = self.max_size / sorted(spatial_shape)[1]
                new_size = [int(d * scale_factor) for d in spatial_shape]
                resizer = med.Resized(keys=[key],
                                      spatial_size=new_size,
                                      mode=self.mode,
                                      allow_missing_keys=self.allow_missing_keys)
                _data[key] = resizer(_data)[key]
            outputs.append(_data)
        return outputs
def get_loader(data_dir, size, cache_dir="/fenghetang/3d/pretrain/MM/cache_dataset", num_samples=4):
    """Build the 3-D pre-training dataset described by ``<data_dir>/dataset.json``.

    Despite its name, this function returns a MONAI dataset, not a DataLoader.

    :param data_dir: folder that contains ``dataset.json``; also passed as the
        base directory for the file paths listed in it
    :param size: edge length of the cubic patches that are cropped
    :param cache_dir: directory handed to ``CacheNTransDataset`` for its
        on-disk cache; defaults to the path the original code hard-coded
    :param num_samples: number of patches cropped from each volume
    :return: the training dataset
    """
    datalist_json = os.path.join(data_dir, "dataset.json")
    patch_size = (size, size, size)

    train_transform = med.Compose(
        [
            med.LoadImaged(keys=["image"], allow_missing_keys=True),
            # NOTE(review): AddChanneld is deprecated/removed in recent MONAI
            # releases — confirm the MONAI version this project pins.
            med.AddChanneld(keys=["image"], allow_missing_keys=True),
            med.Orientationd(keys=["image"], axcodes="RAS", allow_missing_keys=True),
            med.Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear", allow_missing_keys=True),
            # Clip intensities to [-175, 250] and map them linearly to [0, 1].
            med.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            med.CropForegroundd(keys=["image"], source_key="image", allow_missing_keys=True),
            # Pad so that a full-size patch can always be cropped.
            med.SpatialPadd(keys=["image"], spatial_size=patch_size, mode='constant'),
            # There is no label here: the image itself is used as the label
            # map, with pos=1 / neg=0.
            med.RandCropByPosNegLabeld(
                spatial_size=patch_size,
                keys=["image"],
                label_key="image",
                pos=1,
                neg=0,
                num_samples=num_samples,
            ),
            med.RandFlipd(keys=["image"], prob=0.2, spatial_axis=0),
            med.RandFlipd(keys=["image"], prob=0.2, spatial_axis=1),
            med.RandFlipd(keys=["image"], prob=0.1, spatial_axis=2),
            med.ToTensord(keys=["image"]),
        ])

    # Reads the "training" list of the Decathlon-style json; the positional
    # True is forwarded unchanged (presumably `is_segmentation` — confirm).
    datalist = load_decathlon_datalist(datalist_json, True, "training", base_dir=data_dir)

    # cache_n_trans=6 is forwarded as-is; presumably it caches the output of
    # the first six steps (up to CropForegroundd) so that only the later steps
    # are re-run — confirm against the MONAI documentation.
    train_ds = data.CacheNTransDataset(data=datalist, transform=train_transform, cache_n_trans=6, cache_dir=cache_dir)

    return train_ds