import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET
# Load U²-Net model
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
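# Checkpoints saved through nn.DataParallel prefix every key with "module.";
# strip that prefix so the weights load into a bare U2NET instance.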
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()
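# Inference runs on the CPU here; if a GPU is available, moving the model and
# input tensors with .to("cuda") should speed things up (an optional tweak,
# not something this script does by default).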
def refine_mask(mask):
    """Refine a binary mask with morphological close/open, then soften its edges."""
    kernel = np.ones((5, 5), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)  # fill small holes
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)   # remove small specks
    return cv2.GaussianBlur(mask, (5, 5), 0)
def segment_dress(image_np):
    """Segment the dress with U²-Net and return a 0/255 mask at the input resolution."""
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320))
    ])
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = transform_pipeline(image).unsqueeze(0)
    with torch.no_grad():
        # U²-Net returns several side outputs; the first (d0) is the fused prediction.
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()
    dress_mask = (output > 0.5).astype(np.uint8) * 255
    # Resize the 320x320 mask back to the original image size.
    dress_mask = cv2.resize(dress_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_LINEAR)
    return refine_mask(dress_mask)
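# Note: the 0.5 threshold above assumes the U²-Net forward pass already applies
# a sigmoid to its outputs (as in the reference implementation). If your
# checkpoint's model returns raw logits, apply torch.sigmoid(output) first.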
def apply_grabcut(image_np, dress_mask):
    """Refine the mask with GrabCut to reduce color bleeding at the edges."""
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    # Seed GrabCut with *probable* labels so it can actually reassign pixels;
    # an all-definite (GC_FGD/GC_BGD) mask would leave it nothing to optimize.
    mask = np.where(dress_mask > 0, cv2.GC_PR_FGD, cv2.GC_PR_BGD).astype('uint8')
    # The rect argument is ignored in GC_INIT_WITH_MASK mode but must still be supplied.
    rect = (10, 10, image_np.shape[1] - 10, image_np.shape[0] - 10)
    cv2.grabCut(image_np, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_MASK)
    refined_mask = np.where((mask == cv2.GC_BGD) | (mask == cv2.GC_PR_BGD), 0, 255).astype("uint8")
    return refine_mask(refined_mask)
def recolor_dress(image_np, dress_mask, target_color):
    """Change the dress color while keeping texture and lighting intact."""
    # Convert the target BGR color to LAB.
    target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0].astype(np.float32)
    # Convert the image to LAB; work in float to avoid uint8 wrap-around.
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB).astype(np.float32)
    # Compute the mean LAB values over the dress pixels.
    dress_pixels = img_lab[dress_mask > 0]
    if len(dress_pixels) == 0:
        return image_np  # No dress detected
    _, mean_A, mean_B = np.mean(dress_pixels, axis=0)
    # Shift the chroma channels (a, b) so their mean matches the target color,
    # leaving L untouched so shading and fabric texture survive.
    img_lab[..., 1] = np.where(dress_mask > 128, img_lab[..., 1] - mean_A + target_color_lab[1], img_lab[..., 1])
    img_lab[..., 2] = np.where(dress_mask > 128, img_lab[..., 2] - mean_B + target_color_lab[2], img_lab[..., 2])
    # Clip back into the valid range and convert to RGB.
    img_lab = np.clip(img_lab, 0, 255).astype(np.uint8)
    img_recolored = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)
    # Feather the mask for smooth blending at the dress boundary.
    feathered_mask = cv2.GaussianBlur(dress_mask, (15, 15), 10)
    alpha = feathered_mask[..., None] / 255.0
    # Blend the recolored dress with the original image.
    img_final = (image_np * (1 - alpha) + img_recolored * alpha).astype(np.uint8)
    return img_final
def change_dress_color(image_path, color):
    """Main entry point: segment the dress, refine the mask, and recolor it naturally."""
    if image_path is None:
        return None
    img = Image.open(image_path).convert("RGB")
    img_np = np.array(img)
    # Get the dress segmentation mask.
    dress_mask = segment_dress(img_np)
    if not dress_mask.any():
        return img  # No dress detected
    # Further refine the mask with GrabCut.
    dress_mask = apply_grabcut(img_np, dress_mask)
    # Map the selected color name to BGR (OpenCV's channel order).
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0)
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)
    # Apply the recoloring with feathered blending.
    img_recolored = recolor_dress(img_np, dress_mask, new_color_bgr)
    return Image.fromarray(img_recolored)
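# Example of driving the pipeline directly, without the Gradio UI (the file
# name "dress.jpg" is illustrative only):
#
#     result = change_dress_color("dress.jpg", "Blue")
#     if result is not None:
#         result.save("dress_blue.jpg")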
# Gradio interface
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Radio(["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Cyan", "Magenta", "White", "Black"],
                 label="Choose New Dress Color")
    ],
    outputs=gr.Image(type="pil", label="Color-Changed Dress"),
    title="AI-Powered Dress Color Changer",
    description="Upload an image of a dress and select a new color. The app recolors the dress naturally while keeping the fabric texture."
)
if __name__ == "__main__":
    demo.launch()