import os

import cv2
import numpy as np
from torch.utils.data import Dataset as BaseDataset


class SegmentationDataset(BaseDataset):
    """Dataset class for the semantic segmentation task.

    Expects `data_dir` to contain an 'Images' folder with the input images and
    a 'Masks' folder with one PNG mask per image, sharing the image's base name.
    """

    def __init__(self, data_dir, classes=['background', 'object'],
                 augmentation=None, preprocessing=None):
        self.image_dir = os.path.join(data_dir, 'Images')
        self.mask_dir = os.path.join(data_dir, 'Masks')

        for dir_path in [self.image_dir, self.mask_dir]:
            if not os.path.exists(dir_path):
                raise FileNotFoundError(f"Directory not found: {dir_path}")

        self.filenames = self._get_filenames()
        self.image_paths = [os.path.join(self.image_dir, fname) for fname in self.filenames]
        self.mask_paths = self._get_mask_paths()

        # 'background' is excluded, so each foreground class gets its own mask channel.
        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        image = self._load_image(self.image_paths[index])
        mask = self._load_mask(self.mask_paths[index])

        # Both hooks are callables that take keyword arguments and return a dict
        # with 'image' and 'mask' keys (albumentations-style transforms fit this).
        if self.augmentation:
            processed = self.augmentation(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']

        if self.preprocessing:
            processed = self.preprocessing(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']

        return image, mask

    def __len__(self):
        return len(self.filenames)

    def _get_filenames(self):
        """Returns a sorted list of image filenames, excluding directories."""
        files = sorted(os.listdir(self.image_dir))
        return [f for f in files if not os.path.isdir(os.path.join(self.image_dir, f))]

    def _get_mask_paths(self):
        """Builds the corresponding mask path for each image (same base name, .png)."""
        mask_paths = []
        for image_file in self.filenames:
            name, _ = os.path.splitext(image_file)
            mask_paths.append(os.path.join(self.mask_dir, f"{name}.png"))
        return mask_paths

    def _load_image(self, path):
        """Loads an image and converts it from BGR to RGB."""
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load_mask(self, path):
        """Loads a label mask and converts it to one binary channel per class."""
        mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {path}")
        masks = [(mask == value) for value in self.class_values]
        return np.stack(masks, axis=-1).astype('float')
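

# Usage sketch (illustrative, not part of the original module). It assumes the
# layout used by SegmentationDataset above: data_dir/Images/ for inputs and
# data_dir/Masks/<image name>.png for labels. Batch size and class names are
# placeholder values.
def build_train_loader(data_dir, batch_size=4):
    """Builds a shuffled training DataLoader over a SegmentationDataset."""
    from torch.utils.data import DataLoader

    dataset = SegmentationDataset(
        data_dir=data_dir,
        classes=['background', 'object'],
        augmentation=None,    # e.g. a transform pipeline that resizes/crops
        preprocessing=None,   # e.g. normalization and tensor conversion
    )
    # Without preprocessing, samples are numpy arrays; the default collate_fn
    # converts them to tensors as long as every image has the same shape.
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)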


class InferenceDataset(BaseDataset):
    """Dataset class for inference on images without ground-truth masks."""

    def __init__(self, data_dir, classes=['background', 'object'],
                 augmentation=None, preprocessing=None):
        self.filenames = sorted([
            f for f in os.listdir(data_dir)
            if not os.path.isdir(os.path.join(data_dir, f))
        ])
        self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]

        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        image = cv2.imread(self.image_paths[index])
        if image is None:
            raise FileNotFoundError(f"Could not read image: {self.image_paths[index]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Keep the original size so predictions can be resized back afterwards.
        original_height, original_width = image.shape[:2]

        if self.augmentation:
            image = self.augmentation(image=image)['image']
        if self.preprocessing:
            image = self.preprocessing(image=image)['image']

        return image, original_height, original_width

    def __len__(self):
        return len(self.filenames)


class StreamingDataset(BaseDataset):
    """Dataset class optimized for processing extracted video frames."""

    def __init__(self, data_dir, classes=['background', 'object'],
                 augmentation=None, preprocessing=None):
        self.filenames = self._get_frame_filenames(data_dir)
        self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]

        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def _get_frame_filenames(self, directory):
        """Returns a sorted list of JPEG frame filenames ('frame' or 'Image' in the name)."""
        files = sorted(os.listdir(directory))
        return [f for f in files if (('frame' in f or 'Image' in f) and
                                     f.lower().endswith('.jpg') and
                                     not os.path.isdir(os.path.join(directory, f)))]

    def __getitem__(self, index):
        # Frames are loaded exactly like inference images, so reuse that logic;
        # this class exposes the same attributes InferenceDataset expects.
        return InferenceDataset.__getitem__(self, index)

    def __len__(self):
        return len(self.filenames)
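

# Minimal end-to-end sketch (illustrative): iterates an InferenceDataset and
# reports each image's original size. The path is a placeholder; a real pipeline
# would pass the batches to a trained segmentation model.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    infer_data = InferenceDataset(data_dir='data/test', classes=['background', 'object'])
    # batch_size=1 avoids collation issues when images have different sizes.
    infer_loader = DataLoader(infer_data, batch_size=1, shuffle=False)

    for image, original_height, original_width in infer_loader:
        # 'image' is a batch of one; height/width arrive as length-1 tensors.
        print(image.shape, int(original_height[0]), int(original_width[0]))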