import torch
import numpy as np
from PIL import Image
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
import torch.nn.functional as F
import logging
import time
from typing import Tuple, Optional

logger = logging.getLogger('looks.studio.segformer')
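
# SegformerParser wraps mattmdjaga/segformer_b2_clothes, a SegFormer-B2
# checkpoint fine-tuned for clothes segmentation; the label IDs in
# clothing_labels below follow that model's id2label map.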
class SegformerParser:
    def __init__(self, model_path="mattmdjaga/segformer_b2_clothes"):
        self.start_time = time.time()
        logger.info(f"Initializing SegformerParser with model: {model_path}")
        try:
            self.processor = SegformerImageProcessor.from_pretrained(model_path)
            self.model = AutoModelForSemanticSegmentation.from_pretrained(model_path)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")
            self.model.to(self.device)
            # Define clothing-related labels
            self.clothing_labels = {
                4: "upper-clothes",
                5: "skirt",
                6: "pants",
                7: "dress",
                8: "belt",
                9: "left-shoe",
                10: "right-shoe",
                14: "left-arm",
                15: "right-arm",
                17: "scarf"
            }
            logger.info(f"SegformerParser initialized in {time.time() - self.start_time:.2f} seconds")
        except Exception as e:
            logger.error(f"Failed to initialize SegformerParser: {str(e)}")
            raise

    def _resize_image(self, image: Image.Image, max_size: int = 1024) -> Tuple[Image.Image, float]:
        """Resize image while maintaining aspect ratio if it exceeds max_size"""
        width, height = image.size
        scale = 1.0
        if width > max_size or height > max_size:
            scale = max_size / max(width, height)
            new_width = int(width * scale)
            new_height = int(height * scale)
            logger.info(f"Resizing image from {width}x{height} to {new_width}x{new_height}")
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        return image, scale

    def _validate_image(self, image: Image.Image) -> bool:
        """Validate input image"""
        if not isinstance(image, Image.Image):
            logger.error("Input is not a PIL Image")
            return False
        if image.mode not in ['RGB', 'RGBA']:
            logger.error(f"Unsupported image mode: {image.mode}")
            return False
        width, height = image.size
        if width < 64 or height < 64:
            logger.error(f"Image too small: {width}x{height}")
            return False
        if width > 4096 or height > 4096:
            logger.error(f"Image too large: {width}x{height}")
            return False
        return True

    def get_image_mask(self, image: Image.Image) -> Optional[Image.Image]:
        """Generate a binary segmentation mask covering all clothing labels"""
        start_time = time.time()
        logger.info(f"Starting segmentation for image size: {image.size}")
        try:
            # Validate input image
            if not self._validate_image(image):
                return None
            # Convert RGBA to RGB if necessary
            if image.mode == 'RGBA':
                logger.info("Converting RGBA to RGB")
                image = image.convert('RGB')
            # Remember the exact original size, then downscale if too large
            original_size = image.size
            image, scale = self._resize_image(image)
            # Process the image
            logger.info("Processing image with Segformer")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            # Get predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
            logits = outputs.logits.cpu()
            # Upsample logits to the (possibly resized) image size;
            # PIL size is (W, H) while interpolate expects (H, W)
            upsampled_logits = F.interpolate(
                logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )
            # Get the per-pixel class prediction
            pred_seg = upsampled_logits.argmax(dim=1)[0]
            # Create a binary mask for clothing
            mask = torch.zeros_like(pred_seg)
            for label_id in self.clothing_labels.keys():
                mask[pred_seg == label_id] = 255
            # Convert to PIL Image
            mask = Image.fromarray(mask.numpy().astype(np.uint8))
            # Resize mask back to the original size if the image was downscaled
            # (use the stored size rather than dividing by scale, which can be
            # off by a pixel due to integer rounding)
            if scale != 1.0:
                logger.info(f"Resizing mask back to original size: {original_size}")
                mask = mask.resize(original_size, Image.Resampling.NEAREST)
            logger.info(f"Segmentation completed in {time.time() - start_time:.2f} seconds")
            return mask
        except Exception as e:
            logger.error(f"Error during segmentation: {str(e)}")
            return None

    def get_all_masks(self, image: Image.Image) -> dict:
        """Return a dict of binary masks for each clothing part label."""
        start_time = time.time()
        logger.info(f"Starting per-part segmentation for image size: {image.size}")
        masks = {}
        try:
            # Validate input image
            if not self._validate_image(image):
                return masks
            # Convert RGBA to RGB if necessary
            if image.mode == 'RGBA':
                logger.info("Converting RGBA to RGB")
                image = image.convert('RGB')
            # Remember the exact original size, then downscale if too large
            original_size = image.size
            image, scale = self._resize_image(image)
            # Process the image
            logger.info("Processing image with Segformer for all masks")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            # Get predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
            logits = outputs.logits.cpu()
            # Upsample logits to the (possibly resized) image size
            upsampled_logits = F.interpolate(
                logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )
            pred_seg = upsampled_logits.argmax(dim=1)[0]
            # For each clothing label, create a binary mask
            for label_id, part_name in self.clothing_labels.items():
                mask = (pred_seg == label_id).numpy().astype(np.uint8) * 255
                mask_img = Image.fromarray(mask)
                # Resize mask back to the original size if the image was downscaled
                if scale != 1.0:
                    mask_img = mask_img.resize(original_size, Image.Resampling.NEAREST)
                masks[part_name] = mask_img
            logger.info(f"Per-part segmentation completed in {time.time() - start_time:.2f} seconds")
            return masks
        except Exception as e:
            logger.error(f"Error during per-part segmentation: {str(e)}")
            return masks
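

# Minimal usage sketch: "outfit.jpg" and the output filenames below are
# placeholder paths, not part of the module. Assumes a photo of a clothed
# person is available on disk.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    parser = SegformerParser()
    source = Image.open("outfit.jpg")
    # Combined mask: white (255) wherever any clothing label was predicted
    combined = parser.get_image_mask(source)
    if combined is not None:
        combined.save("clothing_mask.png")
    # Per-part masks keyed by label name, e.g. "upper-clothes", "pants"
    for part_name, part_mask in parser.get_all_masks(source).items():
        part_mask.save(f"mask_{part_name}.png")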