Spaces:
Running
Running
File size: 3,818 Bytes
984b1c3 9bbf95a 984b1c3 9bbf95a d3f9ca8 984b1c3 1c2f991 9bbf95a 1c2f991 984b1c3 9bbf95a 984b1c3 d3f9ca8 984b1c3 d3f9ca8 984b1c3 9bbf95a d3f9ca8 9bbf95a bb4e136 9bbf95a d3f9ca8 9bbf95a bb4e136 9bbf95a d3f9ca8 9bbf95a d3f9ca8 490bf43 d3f9ca8 9bbf95a d3f9ca8 9bbf95a d3f9ca8 95d0b08 490bf43 9bbf95a 984b1c3 9cebca9 d3f9ca8 bb4e136 490bf43 984b1c3 d3f9ca8 984b1c3 e49358d 490bf43 9bbf95a 95d0b08 1f7ad23 984b1c3 af8c4a2 984b1c3 490bf43 984b1c3 d3f9ca8 984b1c3 490bf43 984b1c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET
# Load U²-Net Model
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
# NOTE(review): torch.load unpickles the checkpoint, so only point this at a
# trusted file.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Presumably the checkpoint was saved from an nn.DataParallel wrapper, which
# prefixes every key with "module.". Strip only that *leading* prefix so keys
# that merely contain the substring elsewhere are left intact.
_DP_PREFIX = "module."
state_dict = {
    (key[len(_DP_PREFIX):] if key.startswith(_DP_PREFIX) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()  # inference mode: dropout / batch-norm behave deterministically
def segment_dress(image_np, threshold=0.5, grabcut_iters=5, band_px=10):
    """Segment the dress using U²-Net, then refine its edge with GrabCut.

    U²-Net produces a coarse foreground mask. GrabCut is seeded from it:
    pixels well inside the mask are definite foreground, pixels well outside
    are definite background, and a band of ``band_px`` pixels on both sides
    of the U²-Net boundary is marked "probable" so GrabCut can move the edge.
    (GrabCut never relabels definite pixels, so seeding with definite labels
    only would make the refinement a no-op.)

    Args:
        image_np: H x W image array; converted to 3-channel RGB internally.
        threshold: Cut-off applied to the U²-Net output map.
        grabcut_iters: Number of GrabCut iterations.
        band_px: Half-width, in pixels, of the uncertain boundary band.
            ``0`` skips GrabCut and returns the thresholded U²-Net mask.

    Returns:
        ``uint8`` H x W mask (255 = dress, 0 = elsewhere), or ``None`` when
        U²-Net finds no foreground at all.
    """
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320)),
    ])
    image = Image.fromarray(image_np).convert("RGB")
    # GrabCut requires an 8-bit 3-channel image; use the converted copy so
    # grayscale / RGBA inputs do not make cv2.grabCut throw.
    rgb = np.array(image)
    height, width = rgb.shape[:2]

    input_tensor = transform_pipeline(image).unsqueeze(0)
    with torch.no_grad():
        # assumes model(...)[0] is the (1, 1, 320, 320) fused map and the
        # second [0] drops the batch axis -- TODO confirm against U2NET.
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()
    u2net_mask = (output > threshold).astype(np.uint8) * 255
    u2net_mask = cv2.resize(u2net_mask, (width, height), interpolation=cv2.INTER_NEAREST)

    foreground = u2net_mask > 0
    if not foreground.any():
        # Nothing detected; the caller treats None as "no dress detected".
        return None
    if foreground.all() or band_px <= 0:
        # GrabCut needs both foreground and background samples.
        return u2net_mask

    kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (2 * band_px + 1, 2 * band_px + 1)
    )
    sure_foreground = cv2.erode(u2net_mask, kernel) > 0
    near_foreground = cv2.dilate(u2net_mask, kernel) > 0

    gc_mask = np.full((height, width), cv2.GC_BGD, dtype=np.uint8)
    gc_mask[near_foreground] = cv2.GC_PR_BGD  # outer band: probably background
    gc_mask[foreground] = cv2.GC_PR_FGD       # inner band: probably foreground
    gc_mask[sure_foreground] = cv2.GC_FGD     # deep interior: definite foreground

    bg_model = np.zeros((1, 65), np.float64)
    fg_model = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(rgb, gc_mask, None, bg_model, fg_model,
                    grabcut_iters, cv2.GC_INIT_WITH_MASK)
    except cv2.error:
        # e.g. too few samples in one class to fit the colour models;
        # fall back to the unrefined U²-Net mask rather than failing.
        return u2net_mask

    refined = (gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD)
    return np.where(refined, 255, 0).astype(np.uint8)
def recolor_dress(image_np, mask, target_color, blend_factor=0.8):
    """Recolor the masked region while keeping texture, shadows, and designs.

    Works in CIELAB: the lightness channel (L), which carries shading and
    texture, is left untouched. Only the chroma channels (a, b) inside the
    mask are blended toward the target color. Because L is preserved,
    neutral targets such as white or black desaturate the dress rather than
    making it white or black.

    Args:
        image_np: uint8 RGB image, H x W x 3.
        mask: H x W array; pixels with value > 128 are recolored.
        target_color: Target color as a BGR triple of uint8 values.
        blend_factor: Weight of the target chroma. 0 keeps the original
            colors and 1 replaces them entirely. Defaults to 0.8.

    Returns:
        A new uint8 RGB image; ``image_np`` is not modified.
    """
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
    target_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
    inside = mask > 128

    for channel in (1, 2):  # a and b; channel 0 (L) is preserved
        original = img_lab[..., channel].astype(np.float32)
        blended = original * (1.0 - blend_factor) + float(target_lab[channel]) * blend_factor
        # Round (not truncate) on the way back to uint8 to avoid a
        # systematic downward bias in the chroma channels.
        blended = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
        img_lab[..., channel] = np.where(inside, blended, img_lab[..., channel])

    return cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)
# Selectable colors as BGR triples (the order recolor_dress expects).
_COLOR_MAP_BGR = {
    "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
    "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
    "White": (255, 255, 255), "Black": (0, 0, 0),
}


def change_dress_color(image_path, color):
    """Change the dress color while preserving texture and design details.

    Args:
        image_path: Path to the uploaded image, or ``None`` if nothing was
            uploaded.
        color: A key of ``_COLOR_MAP_BGR``. Unknown values (including
            ``None``) fall back to red.

    Returns:
        The recolored ``PIL.Image``. Returns the unmodified RGB image when
        segmentation yields no mask, and ``None`` when no image was given.
    """
    if image_path is None:
        return None
    # The context manager closes the underlying file handle. convert()
    # returns a fully loaded copy that stays valid after the file is closed.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    img_np = np.array(img)

    # Get dress segmentation mask
    mask = segment_dress(img_np)
    if mask is None:
        return img  # No dress detected

    new_color_bgr = np.array(_COLOR_MAP_BGR.get(color, (0, 0, 255)), dtype=np.uint8)  # Default to Red
    img_recolored = recolor_dress(img_np, mask, new_color_bgr)
    return Image.fromarray(img_recolored)
# Gradio Interface
# The Radio defaults to "Red", which is also what change_dress_color falls
# back to. The UI therefore shows the color that will actually be applied
# instead of silently using red when nothing is selected.
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Radio(
            ["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Cyan", "Magenta", "White", "Black"],
            value="Red",
            label="Choose New Dress Color",
        ),
    ],
    outputs=gr.Image(type="pil", label="Color Changed Dress"),
    title="Dress Color Changer",
    description="Upload an image of a dress and select a new color to change its appearance naturally while preserving designs.",
)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()