import json
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image


class SemanticSegmentationDataset(Dataset):
    def __init__(
        self,
        root_dir: Path,
        split_json: Path,
        split: str,
        mode: str,
        resize_size: Tuple[int, int],
        crop_size: Tuple[int, int],
        class_map: Dict[Tuple[int, ...], int],
        transform: Optional[transforms.Compose] = None
    ):
        """
        Dataset for semantic segmentation, zones structured as:
          root_dir/zone/images/*.jpg
          root_dir/zone/masks/*.png
        """
        self.root_dir    = Path(root_dir)
        self.mode        = mode
        self.resize_size = resize_size
        self.crop_size   = crop_size
        self.transform   = transform
        self.pairs       = self._gather_pairs(split_json, split)

        # Build a 256-entry lookup table mapping raw mask pixel values to
        # class indices; values not covered by class_map keep the ignore
        # index 255.
        self.lut = np.full(256, fill_value=255, dtype=np.uint8)
        for keys, v in class_map.items():
            self.lut[list(keys)] = v

    def _gather_pairs(self, split_json: Path, split: str) -> List[Tuple[Path, Optional[Path]]]:
        """Collect (image, mask) path pairs for the requested split; the mask is None in test modes."""
        with open(split_json, 'r') as f:
            split_data = json.load(f)
        dirs = split_data.get(split, [])
        pairs = []
        for zone in sorted(dirs):
            images_dir = self.root_dir / zone / "images"
            masks_dir  = self.root_dir / zone / "masks"
            if not images_dir.is_dir():
                continue
            # Masks are only required when ground truth is used.
            if self.mode in ("train", "val") and not masks_dir.is_dir():
                continue
            for img_path in images_dir.glob("*.JPG"):
                if self.mode in ("test", "test3d", "val3d"):
                    # No ground truth at test time: pair the image with None.
                    pairs.append((img_path, None))
                else:
                    mask_path = masks_dir / img_path.name.replace(".JPG", ".png")
                    if mask_path.exists():
                        pairs.append((img_path, mask_path))
        return pairs

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Union[torch.Tensor, str]]:
        img_path, mask_path = self.pairs[idx]
        # Note: PIL's resize expects the size as (width, height).
        image = Image.open(img_path).convert("RGB").resize(self.resize_size, resample=Image.BILINEAR)
        image_np = np.array(image)

        if mask_path is not None:
            # Nearest-neighbour resampling preserves discrete label values.
            mask = Image.open(mask_path).resize(self.resize_size, resample=Image.NEAREST)
            # Map raw mask pixel values to class indices via the lookup table.
            mask_np = self.lut[np.array(mask)]
        else:
            mask_np = None

        if self.mode == 'train':
            # Take the same random crop from the image and its mask.
            h, w = image_np.shape[:2]
            ch, cw = self.crop_size
            if h < ch or w < cw:
                raise RuntimeError(f"Image {img_path} size ({h},{w}) < crop {self.crop_size}")
            top = random.randint(0, h - ch)
            left = random.randint(0, w - cw)
            image_np = image_np[top:top+ch, left:left+cw]
            mask_np = mask_np[top:top+ch, left:left+cw]

        if self.transform:
            image_tensor = self.transform(Image.fromarray(image_np))
        else:
            # Default conversion: HWC uint8 -> CHW float in [0, 1].
            image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).float() / 255.0

        if mask_np is None:
            # Test modes: return the relative path under which the prediction
            # should be written, instead of a mask tensor.
            zone = img_path.parent.parent.name
            filename = img_path.name.replace('.JPG', '.png')
            path_pred = f"{zone}/{filename}"
            return image_tensor, path_pred
        else:
            mask_tensor = torch.from_numpy(mask_np).long()
            return image_tensor, mask_tensor
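

# Minimal usage sketch. The root directory, split file, and the grayscale
# values in CLASS_MAP below are hypothetical placeholders, not values
# prescribed by this dataset.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    # Hypothetical mapping: raw mask values 100 and 101 -> class 0, 200 -> class 1.
    CLASS_MAP = {(100, 101): 0, (200,): 1}

    dataset = SemanticSegmentationDataset(
        root_dir=Path("data"),                # hypothetical path
        split_json=Path("data/splits.json"),  # hypothetical path
        split="train",
        mode="train",
        resize_size=(1024, 768),              # (width, height), PIL convention
        crop_size=(512, 512),                 # (height, width) of the random crop
        class_map=CLASS_MAP,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    images, masks = next(iter(loader))
    print(images.shape, masks.shape)  # -> [4, 3, 512, 512], [4, 512, 512]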