File size: 7,502 Bytes
475e066
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import torch
import numpy as np
from PIL import Image
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
import torch.nn.functional as F
import logging
import time
from typing import Tuple, Optional

logger = logging.getLogger('looks.studio.segformer')

class SegformerParser:
    """Clothing segmentation backed by a Hugging Face Segformer checkpoint.

    The image processor and model are loaded from ``model_path`` at
    construction time and moved to CUDA when available, otherwise CPU.
    The public methods turn a PIL image into single-channel (mode ``"L"``)
    masks whose pixels are 255 where the model predicted one of the ids in
    ``self.clothing_labels`` and 0 elsewhere. Returned masks always have the
    same ``(width, height)`` as the input image.
    """

    def __init__(self, model_path="mattmdjaga/segformer_b2_clothes"):
        """Load the processor and model from ``model_path``.

        Raises:
            Exception: anything raised while loading or moving the model is
                logged (with traceback) and re-raised unchanged.
        """
        self.start_time = time.time()
        logger.info(f"Initializing SegformerParser with model: {model_path}")

        try:
            self.processor = SegformerImageProcessor.from_pretrained(model_path)
            self.model = AutoModelForSemanticSegmentation.from_pretrained(model_path)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")
            self.model.to(self.device)

            # Predicted label id -> part name. get_image_mask ORs all of these ids
            # into one mask; get_all_masks uses the names as its dict keys.
            # NOTE(review): the ids assume the default checkpoint's label map
            # (arms 14/15 included); a different model_path may number classes
            # differently — confirm before swapping models.
            self.clothing_labels = {
                4: "upper-clothes",
                5: "skirt",
                6: "pants",
                7: "dress",
                8: "belt",
                9: "left-shoe",
                10: "right-shoe",
                14: "left-arm",
                15: "right-arm",
                17: "scarf"
            }

            logger.info(f"SegformerParser initialized in {time.time() - self.start_time:.2f} seconds")
        except Exception as e:
            # logger.exception keeps the traceback that logger.error would drop.
            logger.exception(f"Failed to initialize SegformerParser: {e}")
            raise

    def _resize_image(self, image: Image.Image, max_size: int = 1024) -> Tuple[Image.Image, float]:
        """Downscale ``image`` so neither side exceeds ``max_size``.

        Aspect ratio is preserved. Images already within the limit are
        returned unchanged.

        Returns:
            ``(image, scale)`` where ``scale`` is new_size / old_size
            (1.0 when no resize happened).
        """
        width, height = image.size
        scale = 1.0

        if width > max_size or height > max_size:
            scale = max_size / max(width, height)
            # round() rather than int(): float error in width * scale must not
            # shave a pixel off; max(1, ...) guards extreme aspect ratios.
            new_width = max(1, round(width * scale))
            new_height = max(1, round(height * scale))
            logger.info(f"Resizing image from {width}x{height} to {new_width}x{new_height}")
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        return image, scale

    def _validate_image(self, image: Image.Image) -> bool:
        """Return True if ``image`` is a PIL RGB/RGBA image with both sides in [64, 4096].

        Each rejection reason is logged at ERROR level.
        """
        if not isinstance(image, Image.Image):
            logger.error("Input is not a PIL Image")
            return False

        if image.mode not in ['RGB', 'RGBA']:
            logger.error(f"Unsupported image mode: {image.mode}")
            return False

        width, height = image.size
        if width < 64 or height < 64:
            logger.error(f"Image too small: {width}x{height}")
            return False

        if width > 4096 or height > 4096:
            logger.error(f"Image too large: {width}x{height}")
            return False

        return True

    def _predict_labels(
        self, image: Image.Image, log_message: str
    ) -> Optional[Tuple[np.ndarray, Tuple[int, int]]]:
        """Validate, preprocess and run the model on ``image``.

        Args:
            image: input image; must pass ``_validate_image``.
            log_message: INFO line logged right before preprocessing, so each
                public caller keeps its own log wording.

        Returns:
            ``None`` if validation fails. Otherwise ``(labels, original_size)``:
            ``labels`` is a 2-D (height, width) array of predicted label ids
            at the working (possibly downscaled) resolution, and
            ``original_size`` is the ``(width, height)`` of ``image`` before
            any resizing.

        Exceptions from preprocessing or inference propagate to the caller.
        """
        if not self._validate_image(image):
            return None

        # Remember the true input size. Reconstructing it later from the resize
        # scale (int(size / scale)) can be off by a pixel.
        original_size = image.size

        if image.mode == 'RGBA':
            logger.info("Converting RGBA to RGB")
            image = image.convert('RGB')

        image, _ = self._resize_image(image)

        logger.info(log_message)
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits.cpu()
            # Bring the logits to the working image size before argmax;
            # F.interpolate expects (height, width), PIL gives (width, height).
            upsampled_logits = F.interpolate(
                logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )
            labels = upsampled_logits.argmax(dim=1)[0]

        return labels.numpy(), original_size

    @staticmethod
    def _bool_to_mask(selected: np.ndarray, size: Tuple[int, int]) -> Image.Image:
        """Convert a boolean (height, width) array to a 0/255 mode-"L" mask of ``size``."""
        mask = Image.fromarray(selected.astype(np.uint8) * 255)
        if mask.size != size:
            # NEAREST keeps the mask strictly binary (no interpolated grey values).
            mask = mask.resize(size, Image.Resampling.NEAREST)
        return mask

    def get_image_mask(self, image: Image.Image) -> Optional[Image.Image]:
        """Return one combined clothing mask for ``image``.

        Pixels predicted as any id in ``self.clothing_labels`` are 255 and all
        others are 0. The mask has exactly the same size as ``image``.

        Returns:
            The mask, or ``None`` if the image fails validation or any error
            occurs during segmentation (the error is logged with traceback).
        """
        start_time = time.time()
        # getattr: the input is not validated yet; a non-image must reach
        # _validate_image (-> None) instead of raising AttributeError here.
        logger.info(f"Starting segmentation for image size: {getattr(image, 'size', None)}")

        try:
            prediction = self._predict_labels(image, "Processing image with Segformer")
            if prediction is None:
                return None
            labels, original_size = prediction

            # A pixel counts as clothing if its label is any of the clothing ids.
            clothing = np.isin(labels, list(self.clothing_labels))
            if labels.shape[::-1] != original_size:
                logger.info(f"Resizing mask back to original size: {original_size}")
            mask = self._bool_to_mask(clothing, original_size)

            logger.info(f"Segmentation completed in {time.time() - start_time:.2f} seconds")
            return mask
        except Exception as e:
            logger.exception(f"Error during segmentation: {e}")
            return None

    def get_all_masks(self, image: Image.Image) -> dict:
        """Return a dict of binary masks for each clothing part label.

        Returns:
            ``{part_name: mask}`` for every entry in ``self.clothing_labels``.
            Each mask is 255 where that part was predicted, 0 elsewhere, and
            has the same size as ``image``. The dict is empty if validation
            fails. If an error occurs, it is logged with a traceback and
            whatever masks were built so far are returned.
        """
        start_time = time.time()
        # getattr: see get_image_mask — input is not validated yet.
        logger.info(f"Starting per-part segmentation for image size: {getattr(image, 'size', None)}")
        masks = {}
        try:
            prediction = self._predict_labels(image, "Processing image with Segformer for all masks")
            if prediction is None:
                return masks
            labels, original_size = prediction

            for label_id, part_name in self.clothing_labels.items():
                masks[part_name] = self._bool_to_mask(labels == label_id, original_size)

            logger.info(f"Per-part segmentation completed in {time.time() - start_time:.2f} seconds")
            return masks
        except Exception as e:
            logger.exception(f"Error during per-part segmentation: {e}")
            return masks