Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from typing import Callable, Optional | |
| import os | |
| import torch | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
class Preprocessed_fastMRI(torch.utils.data.Dataset):
    """FastMRI dataset backed by preprocessed ``.pt`` files for faster loading.

    Every file ending in ``.pt`` found under ``root`` (searched recursively)
    is one sample. Each file is read with ``torch.load(..., weights_only=True)``
    and must contain a ``'data'`` entry. That entry is cast to ``float`` and
    optionally passed through ``transform``.

    Args:
        root: Directory containing the preprocessed ``.pt`` files.
        transform: Optional callable applied to the loaded float tensor.
        preprocess: If True, ``__getitem__`` returns ``(img, name)``, where
            ``name`` is the sample's file name without directory or
            extension. Otherwise it returns only ``img``.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        preprocess: bool = False,
    ) -> None:
        self.root = root
        self.transform = transform
        self.preprocess = preprocess
        # Paths of the sample files, relative to ``self.root``. The relative
        # path (not just the basename) is stored so that files os.walk finds
        # in subdirectories can be loaded again in __getitem__. Sorting makes
        # the index -> sample mapping independent of filesystem listing order.
        self.sample_identifiers = sorted(
            os.path.relpath(os.path.join(dirpath, file), self.root)
            for dirpath, _, files in os.walk(self.root)
            for file in files
            if file.endswith(".pt")
        )

    def __len__(self) -> int:
        return len(self.sample_identifiers)

    def __getitem__(self, idx: int):
        fname = self.sample_identifiers[idx]
        sample = torch.load(os.path.join(self.root, fname), weights_only=True)
        img = sample["data"].float()
        if self.transform is not None:
            img = self.transform(img)
        if not self.preprocess:
            return img
        # Path.stem drops both the directory part and the ".pt" extension.
        return img, Path(fname).stem
class Preprocessed_LIDCIDRI(torch.utils.data.Dataset):
    """LIDC-IDRI dataset backed by preprocessed ``.pt`` files for faster loading.

    Every file ending in ``.pt`` found under ``root`` (searched recursively)
    is one sample. Each file is read with ``torch.load(..., weights_only=True)``
    and must contain a ``'data'`` entry. That entry is cast to ``float``,
    optionally passed through ``transform``, and then given a leading
    dimension of size 1 via ``unsqueeze(0)``.

    Args:
        root: Directory containing the preprocessed ``.pt`` files.
        transform: Optional callable applied to the loaded float tensor
            before the leading dimension is added.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
    ) -> None:
        self.root = root
        self.transform = transform
        # Paths of the sample files, relative to ``self.root``. The relative
        # path (not just the basename) is stored so that files os.walk finds
        # in subdirectories can be loaded again in __getitem__. Sorting makes
        # the index -> sample mapping independent of filesystem listing order.
        self.sample_identifiers = sorted(
            os.path.relpath(os.path.join(dirpath, file), self.root)
            for dirpath, _, files in os.walk(self.root)
            for file in files
            if file.endswith(".pt")
        )

    def __len__(self) -> int:
        return len(self.sample_identifiers)

    def __getitem__(self, idx: int):
        fname = self.sample_identifiers[idx]
        sample = torch.load(os.path.join(self.root, fname), weights_only=True)
        img = sample["data"].float()
        if self.transform is not None:
            img = self.transform(img)
        # Add a channel dim; this happens after the transform, so the
        # transform sees the tensor exactly as stored.
        return img.unsqueeze(0)
class LsdirMiniDataset(torch.utils.data.Dataset):
    """Image dataset over the PNG/JPEG files directly inside one directory.

    Only the top level of ``root`` is scanned (no recursion). Each image is
    opened with PIL, converted to 3-channel RGB, and optionally passed through
    ``transform``.

    Args:
        root: Directory containing the image files.
        transform: Optional callable applied to each RGB PIL image.
        extensions: File suffixes to include, matched case-insensitively.
            Defaults to ``.png`` plus both common JPEG suffixes
            (``.jpg`` and ``.jpeg``).
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        extensions: tuple = (".png", ".jpg", ".jpeg"),
    ) -> None:
        self.root = root
        # str.endswith needs a tuple; lower-case the suffixes so the match
        # is case-insensitive on both sides.
        suffixes = tuple(ext.lower() for ext in extensions)
        # Sorted for a deterministic index -> file mapping. The isfile check
        # skips subdirectories whose names happen to end in an image suffix.
        self.image_files = sorted(
            f
            for f in os.listdir(self.root)
            if f.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(self.root, f))
        )
        self.transform = transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, idx: int):
        img_path = os.path.join(self.root, self.image_files[idx])
        # The context manager closes the underlying file. convert() returns a
        # new, fully loaded image, so `img` stays usable after the file closes.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")  # Ensure consistent 3-channel format
        if self.transform is not None:
            img = self.transform(img)
        return img