import torch
import numpy as np
from PIL import Image
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
import torch.nn.functional as F
import logging
import time
from typing import Tuple, Optional
logger = logging.getLogger('looks.studio.segformer')


class SegformerParser:
    def __init__(self, model_path="mattmdjaga/segformer_b2_clothes"):
        self.start_time = time.time()
        logger.info(f"Initializing SegformerParser with model: {model_path}")
        try:
            self.processor = SegformerImageProcessor.from_pretrained(model_path)
            self.model = AutoModelForSemanticSegmentation.from_pretrained(model_path)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")
            self.model.to(self.device)
            # Define clothing-related labels
            self.clothing_labels = {
                4: "upper-clothes",
                5: "skirt",
                6: "pants",
                7: "dress",
                8: "belt",
                9: "left-shoe",
                10: "right-shoe",
                14: "left-arm",
                15: "right-arm",
                17: "scarf",
            }
            logger.info(f"SegformerParser initialized in {time.time() - self.start_time:.2f} seconds")
        except Exception as e:
            logger.error(f"Failed to initialize SegformerParser: {str(e)}")
            raise

    def _resize_image(self, image: Image.Image, max_size: int = 1024) -> Tuple[Image.Image, float]:
        """Resize image while maintaining aspect ratio if it exceeds max_size."""
        width, height = image.size
        scale = 1.0
        if width > max_size or height > max_size:
            scale = max_size / max(width, height)
            new_width = int(width * scale)
            new_height = int(height * scale)
            logger.info(f"Resizing image from {width}x{height} to {new_width}x{new_height}")
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        return image, scale

    def _validate_image(self, image: Image.Image) -> bool:
        """Validate the input image (type, mode, and dimensions)."""
        if not isinstance(image, Image.Image):
            logger.error("Input is not a PIL Image")
            return False
        if image.mode not in ['RGB', 'RGBA']:
            logger.error(f"Unsupported image mode: {image.mode}")
            return False
        width, height = image.size
        if width < 64 or height < 64:
            logger.error(f"Image too small: {width}x{height}")
            return False
        if width > 4096 or height > 4096:
            logger.error(f"Image too large: {width}x{height}")
            return False
        return True

    def get_image_mask(self, image: Image.Image) -> Optional[Image.Image]:
        """Generate a combined binary segmentation mask for all clothing labels."""
        start_time = time.time()
        logger.info(f"Starting segmentation for image size: {image.size}")
        try:
            # Validate input image
            if not self._validate_image(image):
                return None
            # Convert RGBA to RGB if necessary
            if image.mode == 'RGBA':
                logger.info("Converting RGBA to RGB")
                image = image.convert('RGB')
            # Remember the original size so the mask can be restored exactly
            # (recomputing it from the scale factor can be off by a pixel
            # due to integer truncation), then resize the image if too large
            original_size = image.size
            image, scale = self._resize_image(image)
            # Process the image
            logger.info("Processing image with Segformer")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            # Get predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
            logits = outputs.logits.cpu()
            # Upsample logits to the (resized) image size; PIL reports
            # (width, height), so reverse it for torch's (height, width)
            upsampled_logits = F.interpolate(
                logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )
            # Get the segmentation mask
            pred_seg = upsampled_logits.argmax(dim=1)[0]
            # Create a binary mask covering every clothing label
            mask = torch.zeros_like(pred_seg)
            for label_id in self.clothing_labels.keys():
                mask[pred_seg == label_id] = 255
            # Convert to PIL Image
            mask = Image.fromarray(mask.numpy().astype(np.uint8))
            # Resize mask back to the original size if needed
            if scale != 1.0:
                logger.info(f"Resizing mask back to original size: {original_size}")
                mask = mask.resize(original_size, Image.Resampling.NEAREST)
            logger.info(f"Segmentation completed in {time.time() - start_time:.2f} seconds")
            return mask
        except Exception as e:
            logger.error(f"Error during segmentation: {str(e)}")
            return None

    def get_all_masks(self, image: Image.Image) -> dict:
        """Return a dict of binary masks, one per clothing part label."""
        start_time = time.time()
        logger.info(f"Starting per-part segmentation for image size: {image.size}")
        masks = {}
        try:
            # Validate input image
            if not self._validate_image(image):
                return masks
            # Convert RGBA to RGB if necessary
            if image.mode == 'RGBA':
                logger.info("Converting RGBA to RGB")
                image = image.convert('RGB')
            # Remember the original size, then resize the image if too large
            original_size = image.size
            image, scale = self._resize_image(image)
            # Process the image
            logger.info("Processing image with Segformer for all masks")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            # Get predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
            logits = outputs.logits.cpu()
            upsampled_logits = F.interpolate(
                logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )
            pred_seg = upsampled_logits.argmax(dim=1)[0]
            # For each clothing label, create a binary mask
            for label_id, part_name in self.clothing_labels.items():
                mask = (pred_seg == label_id).numpy().astype(np.uint8) * 255
                mask_img = Image.fromarray(mask)
                # Resize mask back to the original size if needed
                if scale != 1.0:
                    mask_img = mask_img.resize(original_size, Image.Resampling.NEAREST)
                masks[part_name] = mask_img
            logger.info(f"Per-part segmentation completed in {time.time() - start_time:.2f} seconds")
            return masks
        except Exception as e:
            logger.error(f"Error during per-part segmentation: {str(e)}")
            return masks
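

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example of how this parser might be driven end to end.
# The input path "person.jpg" and the output filenames are hypothetical.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    segparser = SegformerParser()
    img = Image.open("person.jpg")

    # Combined clothing mask (white = any clothing label, black = rest)
    combined = segparser.get_image_mask(img)
    if combined is not None:
        combined.save("clothing_mask.png")

    # One binary mask per clothing part, keyed by part name
    for part_name, part_mask in segparser.get_all_masks(img).items():
        part_mask.save(f"mask_{part_name}.png")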