import torch
from torchvision import transforms
import numpy as np
import random
import os
import glob
import torch.nn.functional as F
from PIL import Image
from skimage.color import rgb2gray

def map_to_classes(label_array, max_pixel):
    # Map a [0, 1] label array back to integer class ids in [0, max_pixel].
    return np.clip(np.round(label_array * max_pixel), 0, max_pixel).astype(np.uint8)


def map_to_classes_isic(label_array):
    # Binarize an ISIC-style probability map and scale it to a 0/255 uint8 mask.
    image = np.where(label_array >= 0.5, 1, 0)
    image = (image * 255.0).astype('uint8')
    return image


def map_to_classes2(label_array):
    # Binarize a probability map to a 0/1 uint8 mask.
    image = np.where(label_array >= 0.5, 1, 0).astype('uint8')
    return image
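
# Illustrative values for the mappers above (a sketch, not taken from the training data);
# they can be used to recover integer class ids from masks normalized to [0, 1]:
#   map_to_classes(np.array([0.0, 0.34, 1.0]), 3)  -> array([0, 1, 3], dtype=uint8)
#   map_to_classes2(np.array([0.2, 0.7]))          -> array([0, 1], dtype=uint8)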

def center_crop(image, crop_size):
    # Crop an (H, W, ...) array to crop_size = (crop_height, crop_width) around its center.
    height, width = image.shape[:2]
    crop_height, crop_width = crop_size
    start_y = (height - crop_height) // 2
    start_x = (width - crop_width) // 2
    cropped_image = image[start_y:start_y + crop_height, start_x:start_x + crop_width]
    return cropped_image
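
# Example for center_crop (illustrative dummy array): cropping a 320x480 RGB image
# to 256x256 keeps the central region, so
#   center_crop(np.zeros((320, 480, 3), dtype=np.uint8), (256, 256)).shape == (256, 256, 3)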

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root, tokenizer, size=256, center_crop=True, t_drop_rate=0.05,
                 i_drop_rate=0.05, ti_drop_rate=0.05):
        super().__init__()
        self.tokenizer = tokenizer
        self.size = size
        self.center_crop = center_crop
        # Probabilities of dropping the image caption, the mask caption, or both captions.
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        # Each sample is an .npz file laid out as <root>/<dataset_name>/<file>.npz;
        # the dataset name is recovered from the parent directory in __getitem__.
        self.data = glob.glob(os.path.join(root, '*', '*.npz'))
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        # The mask is converted to a tensor manually in __getitem__, so only normalize here.
        self.mask_transform = transforms.Compose([
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        # Largest label value per dataset, used to scale the masks to [0, 1].
        self.max_pixels = {
            'AMOS2022': 15,
            'ACDC': 3,
            'BUSI': 1,
            'CVC-ClinicDB': 1,
            'kvasir-seg': 1,
            'LiTS2017': 2,
            'KiTS2019': 2,
        }
        # Class id -> organ/structure name for the multi-class datasets.
        self.AMOS2022 = {1: 'liver', 2: 'right kidney', 3: 'spleen', 4: 'pancreas', 5: 'aorta',
                         6: 'inferior vena cava', 7: 'right adrenal gland', 8: 'left adrenal gland',
                         9: 'gall bladder', 10: 'esophagus', 11: 'stomach', 12: 'duodenum',
                         13: 'left kidney', 14: 'bladder', 15: 'prostate'}
        self.ACDC = {1: 'right ventricle', 2: 'myocardium', 3: 'left ventricle'}
        self.LiTS2017 = {1: 'liver', 2: 'liver tumor'}
        self.KiTS2019 = {1: 'kidney', 2: 'kidney tumor'}
        # Candidate (height_ratio, width_ratio) pairs for multi-aspect training.
        self.aspect_ratios = [
            (16, 9),   # 16:9
            (4, 3),    # 4:3
            (3, 2),    # 3:2
            (1, 1),    # 1:1
            (2, 1),    # 2:1
            (9, 16),   # 9:16
            (5, 4),    # 5:4
            (3, 4),    # 3:4
            (2, 3),    # 2:3
        ]
    def get_target_size(self, aspect_ratio, max_size=512):
        # Fit the (h_ratio, w_ratio) aspect ratio into a max_size box,
        # keeping the longer side at max_size.
        h_ratio, w_ratio = aspect_ratio
        if h_ratio > w_ratio:
            height = max_size
            width = int(max_size * w_ratio / h_ratio)
        else:
            width = max_size
            height = int(max_size * h_ratio / w_ratio)
        return (height, width)
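    # Example for the method above (illustrative only): get_target_size((16, 9), 512)
    # returns (512, 288), i.e. height 512 and width 288, while (9, 16) returns (288, 512).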
    def convert_to_rgb(self, image):
        # Stack a single-channel image into 3 channels; pass 3-channel images through.
        if len(image.shape) == 2:
            rgb_img = np.stack((image, image, image), axis=-1)
        elif len(image.shape) == 3 and image.shape[2] == 3:
            rgb_img = image
        else:
            raise ValueError("Unsupported image format")
        return rgb_img
    def __getitem__(self, idx):
        path = self.data[idx]
        # The parent directory name identifies the source dataset.
        name = path.split('/')[-2]
        # Read the image and label once from the .npz file.
        npz = np.load(path)
        raw_image, ori_raw_mask = npz['image'], npz['label']
        kinds = np.unique(ori_raw_mask)
        raw_image, raw_mask = self.convert_to_rgb(raw_image), self.convert_to_rgb(ori_raw_mask)
        image_tensor = self.img_transform(raw_image)
        # Scale label ids to [0, 1], then normalize to [-1, 1] like the image.
        raw_mask = raw_mask / self.max_pixels[name]
        raw_mask = torch.from_numpy(raw_mask.transpose((2, 0, 1))).contiguous()
        mask_tensor = self.mask_transform(raw_mask)
        image = image_tensor.squeeze(dim=0)
        mask = mask_tensor.squeeze(dim=0)
        organ, kind = '', ''
        tips = []
        # Build a caption from the dataset name and the classes present in this sample.
        if name == 'AMOS2022':
            organ = 'abdomen CT scans'
            tips = [self.AMOS2022[k] for k in kinds if k != 0]
        elif name == 'ACDC':
            organ = 'cardiovascular ventricle mri'
            tips = [self.ACDC[k] for k in kinds if k != 0]
        elif name == 'BUSI':
            organ = 'breast ultrasound'
            kind = 'breast tumor' if kinds.any() else 'normal'
        elif name == 'CVC-ClinicDB':
            organ = 'polyp colonoscopy'
            kind = 'polyp' if kinds.any() else 'normal'
        elif name == 'kvasir-seg':
            organ = 'polyp colonoscopy'
            kind = 'polyp' if kinds.any() else 'normal'
        elif name == 'LiTS2017':
            organ = 'abdomen CT scans'
            tips = [self.LiTS2017[k] for k in kinds if k != 0]
        elif name == 'KiTS2019':
            organ = 'abdomen CT scans'
            tips = [self.KiTS2019[k] for k in kinds if k != 0]
        # For the multi-class datasets, shuffle the class names and join them into one string.
        if tips:
            random.shuffle(tips)
            kind = ','.join(tips)
        if kind == '':
            kind = 'no found'
        # e.g. "a photo of abdomen CT scans image, with liver,spleen."
        img_text = f'a photo of {organ} image, with {kind}.'
        mask_text = f'a photo of {organ} label, with {kind}.'
        # Randomly drop the image caption, the mask caption, or both captions.
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            img_text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            mask_text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            img_text = ""
            mask_text = ""
        # Tokenize both captions to fixed-length input ids.
        img_text_input_ids = self.tokenizer(
            img_text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids
        mask_text_input_ids = self.tokenizer(
            mask_text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids
        return {
            "image": image,
            "mask": mask,
            "img_text_input_ids": img_text_input_ids,
            "mask_text_input_ids": mask_text_input_ids,
            "raw_mask": ori_raw_mask,
            "kind": kind
        }

    def __len__(self):
        return len(self.data)

def collate_fn(data):
    # Candidate (height_ratio, width_ratio) pairs; one is sampled per batch so every
    # example in the batch is resized to the same shape before stacking.
    aspect_ratios = [
        (16, 9),   # 16:9
        (4, 3),    # 4:3
        (3, 2),    # 3:2
        (1, 1),    # 1:1
        (2, 1),    # 2:1
        (9, 16),   # 9:16
        (5, 4),    # 5:4
        (3, 4),    # 3:4
        (2, 3),    # 2:3
    ]

    def get_target_size(aspect_ratio, max_size=256):
        # Same logic as MyDataset.get_target_size: longer side fixed at max_size.
        h_ratio, w_ratio = aspect_ratio
        if h_ratio > w_ratio:
            height = max_size
            width = int(max_size * w_ratio / h_ratio)
        else:
            width = max_size
            height = int(max_size * h_ratio / w_ratio)
        return (height, width)

    aspect = aspect_ratios[random.randint(0, len(aspect_ratios) - 1)]
    shape = get_target_size(aspect, 512)
    images = torch.stack([transforms.Resize(size=shape)(example["image"]) for example in data])
    masks = torch.stack([transforms.Resize(size=shape)(example["mask"]) for example in data])
    img_text_input_ids = torch.cat([example["img_text_input_ids"] for example in data], dim=0)
    mask_text_input_ids = torch.cat([example["mask_text_input_ids"] for example in data], dim=0)
    return {
        "images": images,
        "masks": masks,
        "img_text_input_ids": img_text_input_ids,
        "mask_text_input_ids": mask_text_input_ids,
    }
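
# Minimal usage sketch (not part of the original pipeline). It assumes the tokenizer is a
# Hugging Face CLIP tokenizer and that "path/to/npz_root" is a placeholder for the directory
# holding the <dataset_name>/<file>.npz files; both are illustrative assumptions, not the
# authors' exact training setup.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset = MyDataset(root="path/to/npz_root", tokenizer=tokenizer, size=256)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    batch = next(iter(loader))
    # images/masks: (B, 3, H, W) in [-1, 1]; *_input_ids: (B, model_max_length)
    print(batch["images"].shape, batch["masks"].shape, batch["img_text_input_ids"].shape)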