import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET # Import U²-Net
# Load U²-Net model
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} # Remove 'module.' prefix
model.load_state_dict(state_dict)
model.eval()
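
# Optional (an assumption, not part of the original Space): move inference to
# GPU when one is available. segment_dress would then also need to send
# input_tensor to the same device before the forward pass, so this is left
# commented out as a sketch.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)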

def detect_design(image_np):
    """Detects the design on the dress using edge detection and adaptive thresholding."""
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)

    # Use adaptive thresholding to segment the design
    adaptive_thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                            cv2.THRESH_BINARY_INV, 11, 2)

    # Detect edges using Canny
    edges = cv2.Canny(gray, 50, 150)

    # Combine both masks
    design_mask = cv2.bitwise_or(adaptive_thresh, edges)

    # Morphological operations to remove noise
    kernel = np.ones((3, 3), np.uint8)
    design_mask = cv2.morphologyEx(design_mask, cv2.MORPH_CLOSE, kernel)

    return design_mask
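
# A quick sanity check for detect_design (hypothetical file name, for
# illustration only):
#   sample = np.array(Image.open("dress.jpg").convert("RGB"))
#   mask = detect_design(sample)  # uint8 mask, 255 where a print/edge was found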

def segment_dress(image_np):
    """Segment the dress using U²-Net."""
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320))
    ])

    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = transform_pipeline(image).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()

    # Convert output to mask
    dress_mask = (output > 0.5).astype(np.uint8) * 255
    dress_mask = cv2.resize(dress_mask, (image_np.shape[1], image_np.shape[0]),
                            interpolation=cv2.INTER_NEAREST)

    # Morphological operations for smoothness
    kernel = np.ones((5, 5), np.uint8)
    dress_mask = cv2.morphologyEx(dress_mask, cv2.MORPH_CLOSE, kernel)

    return dress_mask
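
# segment_dress yields a binary uint8 mask (0 or 255) upscaled back to the
# original image resolution; a hypothetical check, for illustration only:
#   dress_mask = segment_dress(np.array(Image.open("dress.jpg").convert("RGB")))
#   print(np.unique(dress_mask))  # expected: [0 255] (or a single value)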

def recolor_dress(image_np, dress_mask, design_mask, target_color):
    """Change dress color while preserving designs."""
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
    target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]

    # Ensure the design areas are NOT recolored
    recolor_mask = cv2.bitwise_and(dress_mask, cv2.bitwise_not(design_mask))

    # Apply color change only to the non-design dress areas
    blend_factor = 0.8
    img_lab[..., 1] = np.where(recolor_mask > 128,
                               img_lab[..., 1] * (1 - blend_factor) + target_color_lab[1] * blend_factor,
                               img_lab[..., 1])
    img_lab[..., 2] = np.where(recolor_mask > 128,
                               img_lab[..., 2] * (1 - blend_factor) + target_color_lab[2] * blend_factor,
                               img_lab[..., 2])

    img_recolored = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)
    return img_recolored
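
# Worked example of the blend above: with blend_factor = 0.8, a dress pixel
# whose a-channel value is 120, blended toward a target a of 170, becomes
#   0.2 * 120 + 0.8 * 170 = 24 + 136 = 160,
# while its L (lightness) value is untouched, so folds and shadows survive the
# recoloring.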

def change_dress_color(image_path, color):
    """Change the dress color naturally while keeping designs intact."""
    if image_path is None:
        return None

    img = Image.open(image_path).convert("RGB")
    img_np = np.array(img)

    # Get dress segmentation mask
    dress_mask = segment_dress(img_np)
    if dress_mask is None:
        return img  # No dress detected

    # Detect design on the dress
    design_mask = detect_design(img_np)

    # Convert the selected color to BGR
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0)
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)  # Default to Red

    # Apply recoloring logic
    img_recolored = recolor_dress(img_np, dress_mask, design_mask, new_color_bgr)

    return Image.fromarray(img_recolored)
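
# The pipeline can also be exercised without the UI (hypothetical file names,
# for illustration only):
#   out = change_dress_color("dress.jpg", "Blue")
#   out.save("dress_blue.jpg")
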
# Gradio Interface
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Radio(["Red", "Blue", "Green", "Yellow", "Purple", "Orange",
                  "Cyan", "Magenta", "White", "Black"],
                 label="Choose New Dress Color")
    ],
    outputs=gr.Image(type="pil", label="Color Changed Dress"),
    title="Dress Color Changer",
    description="Upload an image of a dress and select a new color to change its appearance naturally while preserving designs."
)
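
# launch() serves the app locally; demo.launch(share=True) would additionally
# create a temporary public link (a standard Gradio option).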
if __name__ == "__main__":
    demo.launch()