# medsegfactory/dataset/dataset.py
import glob
import os
import random

import numpy as np
import torch
from torchvision import transforms


def map_to_classes(label_array, max_pixel):
    # Map a [0, 1] float label map back to integer class ids in [0, max_pixel].
    return np.clip(np.round(label_array * max_pixel), 0, max_pixel).astype(np.uint8)


def map_to_classes_isic(label_array):
    # Binarize a [0, 1] float mask at 0.5 and scale to {0, 255} for saving as an image.
    image = np.where(label_array >= 0.5, 1, 0)
    image = (image * 255.0).astype('uint8')
    return image


def map_to_classes2(label_array):
    # Binarize a [0, 1] float mask at 0.5, keeping values in {0, 1}.
    image = np.where(label_array >= 0.5, 1, 0).astype('uint8')
    return image
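
# For example (illustrative values, not from the original file):
#   map_to_classes(np.array([0.0, 0.33, 1.0]), 3)  -> array([0, 1, 3], dtype=uint8)
#   map_to_classes_isic(np.array([0.2, 0.7]))      -> array([0, 255], dtype=uint8)
#   map_to_classes2(np.array([0.2, 0.7]))          -> array([0, 1], dtype=uint8)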


def center_crop(image, crop_size):
    # Crop an (H, W, ...) array to crop_size=(crop_height, crop_width) around its center.
    height, width = image.shape[:2]
    crop_height, crop_width = crop_size
    start_y = (height - crop_height) // 2
    start_x = (width - crop_width) // 2
    return image[start_y:start_y + crop_height, start_x:start_x + crop_width]


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root, tokenizer, size=256, center_crop=True, t_drop_rate=0.05,
                 i_drop_rate=0.05, ti_drop_rate=0.05):
        super().__init__()
        self.tokenizer = tokenizer
        self.size = size
        self.center_crop = center_crop
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        # Each sample is an .npz file with 'image' and 'label' arrays, grouped
        # by dataset name: <root>/<dataset_name>/<case>.npz.
        self.data = glob.glob(os.path.join(root, '*', '*.npz'))
        # Images: HWC uint8 -> CHW float in [0, 1] -> normalized to [-1, 1].
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        # Masks are converted to tensors by hand in __getitem__, so only normalize here.
        self.mask_transform = transforms.Compose([
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        # Largest class id per dataset; used to rescale label maps into [0, 1].
        self.max_pixels = {
            'AMOS2022': 15,
            'ACDC': 3,
            'BUSI': 1,
            'CVC-ClinicDB': 1,
            'kvasir-seg': 1,
            'LiTS2017': 2,
            'KiTS2019': 2,
        }
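        # For example (illustrative): a LiTS2017 label with ids {0, 1, 2} is divided
        # by 2 into {0, 0.5, 1} in __getitem__, and Normalize then maps it to
        # {-1, 0, 1}; map_to_classes above inverts the [0, 1] scaling.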
        # Class-id -> organ-name lookups for the multi-class datasets.
        self.AMOS2022 = {
            1: 'liver', 2: 'right kidney', 3: 'spleen', 4: 'pancreas',
            5: 'aorta', 6: 'inferior vena cava', 7: 'right adrenal gland',
            8: 'left adrenal gland', 9: 'gall bladder', 10: 'esophagus',
            11: 'stomach', 12: 'duodenum', 13: 'left kidney', 14: 'bladder',
            15: 'prostate',
        }
        self.ACDC = {1: 'right ventricle', 2: 'myocardium', 3: 'left ventricle'}
        self.LiTS2017 = {1: 'liver', 2: 'liver tumor'}
        self.KiTS2019 = {1: 'kidney', 2: 'kidney tumor'}
        # Candidate (height_ratio, width_ratio) pairs for random-aspect training.
        self.aspect_ratios = [
            (16, 9), (4, 3), (3, 2), (1, 1), (2, 1),
            (9, 16), (5, 4), (3, 4), (2, 3),
        ]

    def get_target_size(self, aspect_ratio, max_size=512):
        # Fit the aspect ratio inside a max_size square: the longer side gets
        # max_size and the shorter side is scaled to preserve the ratio.
        h_ratio, w_ratio = aspect_ratio
        if h_ratio > w_ratio:
            height = max_size
            width = int(max_size * w_ratio / h_ratio)
        else:
            width = max_size
            height = int(max_size * h_ratio / w_ratio)
        return (height, width)
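    # For example (illustrative): get_target_size((16, 9), 512) -> (512, 288)
    # and get_target_size((9, 16), 512) -> (288, 512), since the tuples are
    # (height_ratio, width_ratio).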

    def convert_to_rgb(self, image):
        # Promote a 2-D grayscale array to three identical channels; pass RGB through.
        if len(image.shape) == 2:
            rgb_img = np.stack((image, image, image), axis=-1)
        elif len(image.shape) == 3 and image.shape[2] == 3:
            rgb_img = image
        else:
            raise ValueError("Unsupported image format")
        return rgb_img

    def __getitem__(self, idx):
        path = self.data[idx]
        # The parent directory name identifies the source dataset.
        name = os.path.basename(os.path.dirname(path))
        # Read the image/label pair, loading the .npz once.
        npz = np.load(path)
        raw_image, ori_raw_mask = npz['image'], npz['label']
        kinds = np.unique(ori_raw_mask)
        raw_image, raw_mask = self.convert_to_rgb(raw_image), self.convert_to_rgb(ori_raw_mask)
        # Resizing to a randomly chosen aspect ratio happens later, in collate_fn.
        image = self.img_transform(raw_image)
        # Rescale class ids to [0, 1], then normalize to [-1, 1].
        raw_mask = raw_mask / self.max_pixels[name]
        raw_mask = torch.from_numpy(raw_mask.transpose((2, 0, 1))).contiguous()
        mask = self.mask_transform(raw_mask)
        # Build the text prompt: multi-class datasets list every organ present
        # (in random order); binary datasets report lesion presence.
        multi_class = {
            'AMOS2022': ('abdomen CT scans', self.AMOS2022),
            'ACDC': ('cardiovascular ventricle mri', self.ACDC),
            'LiTS2017': ('abdomen CT scans', self.LiTS2017),
            'KiTS2019': ('abdomen CT scans', self.KiTS2019),
        }
        binary = {
            'BUSI': ('breast ultrasound', 'breast tumor'),
            'CVC-ClinicDB': ('polyp colonoscopy', 'polyp'),
            'kvasir-seg': ('polyp colonoscopy', 'polyp'),
        }
        organ, kind = '', ''
        if name in multi_class:
            organ, label_map = multi_class[name]
            tips = [label_map[k] for k in kinds if k != 0]
            random.shuffle(tips)
            kind = ','.join(tips)
        elif name in binary:
            organ, lesion = binary[name]
            kind = lesion if kinds.any() else 'normal'
        if kind == '':
            kind = 'no found'
        img_text = f'a photo of {organ} image, with {kind}.'
        mask_text = f'a photo of {organ} label, with {kind}.'
        # Randomly drop one or both prompts (conditioning dropout).
rand_num = random.random()
if rand_num < self.i_drop_rate:
img_text = ""
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
mask_text = ""
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
img_text = ""
mask_text = ""
        # Tokenize both prompts to fixed-length input ids.
img_text_input_ids = self.tokenizer(
img_text,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
mask_text_input_ids = self.tokenizer(
mask_text,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
return {
"image": image,
"mask": mask,
"img_text_input_ids": img_text_input_ids,
"mask_text_input_ids": mask_text_input_ids,
"raw_mask": ori_raw_mask,
"kind": kind
}
def __len__(self):
return len(self.data)


def collate_fn(data):
    # Pick one aspect ratio per batch and resize every sample to it, so all
    # tensors in the batch share a shape. Ratios are (height, width) pairs.
    aspect_ratios = [
        (16, 9), (4, 3), (3, 2), (1, 1), (2, 1),
        (9, 16), (5, 4), (3, 4), (2, 3),
    ]

    def get_target_size(aspect_ratio, max_size=256):
        # Same fitting rule as MyDataset.get_target_size: the longer side gets max_size.
        h_ratio, w_ratio = aspect_ratio
        if h_ratio > w_ratio:
            height = max_size
            width = int(max_size * w_ratio / h_ratio)
        else:
            width = max_size
            height = int(max_size * h_ratio / w_ratio)
        return (height, width)

    aspect = random.choice(aspect_ratios)
    shape = get_target_size(aspect, 512)
    # Note: Resize uses bilinear interpolation by default, for masks as well.
    images = torch.stack([transforms.Resize(size=shape)(example["image"]) for example in data])
    masks = torch.stack([transforms.Resize(size=shape)(example["mask"]) for example in data])
    img_text_input_ids = torch.cat([example["img_text_input_ids"] for example in data], dim=0)
    mask_text_input_ids = torch.cat([example["mask_text_input_ids"] for example in data], dim=0)
return {
"images": images,
"masks": masks,
"img_text_input_ids": img_text_input_ids,
"mask_text_input_ids": mask_text_input_ids,
}
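

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original file). It assumes a CLIP
    # tokenizer from Hugging Face `transformers`; the checkpoint name and the
    # data root below are placeholders.
    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset = MyDataset(root="/path/to/npz_root", tokenizer=tokenizer, size=256)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    batch = next(iter(loader))
    print(batch["images"].shape, batch["masks"].shape)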