# ClothQuill/parser/segformer_parser.py

import torch
import numpy as np
from PIL import Image
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
import torch.nn.functional as F
import logging
import time
from typing import Tuple, Optional

logger = logging.getLogger('looks.studio.segformer')


class SegformerParser:
    def __init__(self, model_path="mattmdjaga/segformer_b2_clothes"):
        self.start_time = time.time()
        logger.info(f"Initializing SegformerParser with model: {model_path}")
        try:
            self.processor = SegformerImageProcessor.from_pretrained(model_path)
            self.model = AutoModelForSemanticSegmentation.from_pretrained(model_path)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")
            self.model.to(self.device)

            # Define clothing-related labels
            self.clothing_labels = {
                4: "upper-clothes",
                5: "skirt",
                6: "pants",
                7: "dress",
                8: "belt",
                9: "left-shoe",
                10: "right-shoe",
                14: "left-arm",
                15: "right-arm",
                17: "scarf"
            }
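            # Note: these IDs follow the id2label mapping published for
            # mattmdjaga/segformer_b2_clothes; other checkpoints may use a
            # different label schema.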
logger.info(f"SegformerParser initialized in {time.time() - self.start_time:.2f} seconds")
except Exception as e:
logger.error(f"Failed to initialize SegformerParser: {str(e)}")
raise

    def _resize_image(self, image: Image.Image, max_size: int = 1024) -> Tuple[Image.Image, float]:
        """Resize image while maintaining aspect ratio if it exceeds max_size"""
        width, height = image.size
        scale = 1.0
        if width > max_size or height > max_size:
            scale = max_size / max(width, height)
            new_width = int(width * scale)
            new_height = int(height * scale)
            logger.info(f"Resizing image from {width}x{height} to {new_width}x{new_height}")
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        return image, scale

    def _validate_image(self, image: Image.Image) -> bool:
        """Validate input image"""
        if not isinstance(image, Image.Image):
            logger.error("Input is not a PIL Image")
            return False
        if image.mode not in ['RGB', 'RGBA']:
            logger.error(f"Unsupported image mode: {image.mode}")
            return False
        width, height = image.size
        if width < 64 or height < 64:
            logger.error(f"Image too small: {width}x{height}")
            return False
        if width > 4096 or height > 4096:
            logger.error(f"Image too large: {width}x{height}")
            return False
        return True

    def get_image_mask(self, image: Image.Image) -> Optional[Image.Image]:
        """Generate a single binary mask covering all clothing-related labels"""
        start_time = time.time()
        logger.info(f"Starting segmentation for image size: {image.size}")
        try:
            # Validate input image
            if not self._validate_image(image):
                return None

            # Convert RGBA to RGB if necessary
            if image.mode == 'RGBA':
                logger.info("Converting RGBA to RGB")
                image = image.convert('RGB')

            # Resize image if too large, remembering the original size so the
            # mask can be mapped back to it exactly
            original_size = image.size
            image, scale = self._resize_image(image)

            # Process the image
            logger.info("Processing image with Segformer")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)

            # Get predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
            logits = outputs.logits.cpu()

            # Upsample logits to the (possibly resized) image size
            upsampled_logits = F.interpolate(
                logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )

            # Get the segmentation mask
            pred_seg = upsampled_logits.argmax(dim=1)[0]

            # Create a binary mask for clothing
            mask = torch.zeros_like(pred_seg)
            for label_id in self.clothing_labels.keys():
                mask[pred_seg == label_id] = 255

            # Convert to PIL Image
            mask = Image.fromarray(mask.numpy().astype(np.uint8))

            # Resize mask back to the original size if needed
            if scale != 1.0:
                logger.info(f"Resizing mask back to original size: {original_size}")
                mask = mask.resize(original_size, Image.Resampling.NEAREST)

            logger.info(f"Segmentation completed in {time.time() - start_time:.2f} seconds")
            return mask
        except Exception as e:
            logger.error(f"Error during segmentation: {str(e)}")
            return None

    def get_all_masks(self, image: Image.Image) -> dict:
        """Return a dict of binary masks, one per clothing part label."""
        start_time = time.time()
        logger.info(f"Starting per-part segmentation for image size: {image.size}")
        masks = {}
        try:
            # Validate input image
            if not self._validate_image(image):
                return masks

            # Convert RGBA to RGB if necessary
            if image.mode == 'RGBA':
                logger.info("Converting RGBA to RGB")
                image = image.convert('RGB')

            # Resize image if too large, remembering the original size so the
            # masks can be mapped back to it exactly
            original_size = image.size
            image, scale = self._resize_image(image)

            # Process the image
            logger.info("Processing image with Segformer for all masks")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)

            # Get predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
            logits = outputs.logits.cpu()

            upsampled_logits = F.interpolate(
                logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )
            pred_seg = upsampled_logits.argmax(dim=1)[0]

            # For each clothing label, create a binary mask
            for label_id, part_name in self.clothing_labels.items():
                mask = (pred_seg == label_id).numpy().astype(np.uint8) * 255
                mask_img = Image.fromarray(mask)
                # Resize mask back to the original size if needed
                if scale != 1.0:
                    mask_img = mask_img.resize(original_size, Image.Resampling.NEAREST)
                masks[part_name] = mask_img

            logger.info(f"Per-part segmentation completed in {time.time() - start_time:.2f} seconds")
            return masks
        except Exception as e:
            logger.error(f"Error during per-part segmentation: {str(e)}")
            return masks
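

# Example usage (a minimal sketch; the input path and output filenames below
# are hypothetical placeholders, not part of the ClothQuill pipeline).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    parser = SegformerParser()
    input_image = Image.open("example_photo.jpg")  # hypothetical input path

    # Combined clothing mask
    combined_mask = parser.get_image_mask(input_image)
    if combined_mask is not None:
        combined_mask.save("clothing_mask.png")  # hypothetical output path

    # One mask per clothing part
    for part_name, part_mask in parser.get_all_masks(input_image).items():
        part_mask.save(f"mask_{part_name}.png")  # hypothetical output paths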