import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET # Import U²-Net
# Load U²-Net model
model_path = "cloth_segmentation/networks/u2net.pth" # Ensure this path is correct
model = U2NET(3, 1)
# Load the state dictionary
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} # Remove 'module.' prefix
model.load_state_dict(state_dict)
model.eval()
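# Note: the checkpoint is assumed to be a single-channel cloth-segmentation
# U²-Net matching U2NET(3, 1). Inference runs on CPU here; a minimal sketch
# for optional GPU use (only an assumption that CUDA is available) would be:
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model.to(device)
#   # ...and move input tensors to `device` in segment_dress()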
def segment_dress(image_np):
    """Segment the dress from the image using U²-Net and refine the mask."""
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320))
    ])
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = transform_pipeline(image).unsqueeze(0)
    with torch.no_grad():
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()
    mask = (output > 0.5).astype(np.uint8) * 255  # Binary mask
    # Resize mask back to the original image size
    mask = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
    # Apply morphological operations for better segmentation
    kernel = np.ones((5, 5), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)  # Close small gaps
    mask = cv2.GaussianBlur(mask, (21, 21), 10)  # Smooth edges for natural blending
    return mask
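# Optional sanity check (a hypothetical helper, not wired into the app): paint
# the predicted mask red over the input so segmentation quality is easy to eyeball.
def debug_overlay(image_np, mask, alpha=0.5):
    """Blend a red overlay onto masked pixels for visual inspection."""
    overlay = image_np.copy()
    overlay[mask > 128] = (255, 0, 0)  # masked pixels in red (image is RGB)
    return cv2.addWeighted(image_np, 1 - alpha, overlay, alpha, 0)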
def get_ambient_light(img_np):
    """Estimate ambient lighting from the whole image for realistic blending."""
    lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
    L_channel = lab[:, :, 0]  # Lightness channel
    return np.median(L_channel)  # Median light level in the image
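# In OpenCV's 8-bit LAB, L is rescaled to 0..255 (not 0..100), so 128 serves as
# a rough mid-gray reference below; the median is also more robust to small
# highlights and shadows than the mean.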
def change_dress_color(image_path, color):
    """Change the dress color naturally while keeping textures and adjusting to lighting."""
    if image_path is None:
        return None
    img = Image.open(image_path).convert("RGB")
    img_np = np.array(img)
    mask = segment_dress(img_np)
    if not np.any(mask):
        return img  # No dress detected
    # Convert the selected color to BGR
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0)
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)  # Default to Red
    # Convert image to LAB color space for better blending
    img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
    new_color_lab = cv2.cvtColor(np.uint8([[new_color_bgr]]), cv2.COLOR_BGR2LAB)[0][0]
    # Adjust lightness to match the estimated ambient light
    ambient_light = get_ambient_light(img_np)
    img_lab[..., 0] = np.clip(img_lab[..., 0] * (ambient_light / 128), 0, 255)  # Normalize lighting
    # Preserve texture by modifying only the A & B (chroma) channels
    blend_factor = 0.6  # Controls intensity of the color change
    img_lab[..., 1] = np.where(mask > 128, img_lab[..., 1] * (1 - blend_factor) + new_color_lab[1] * blend_factor, img_lab[..., 1])
    img_lab[..., 2] = np.where(mask > 128, img_lab[..., 2] * (1 - blend_factor) + new_color_lab[2] * blend_factor, img_lab[..., 2])
    # Convert back to RGB
    img_recolored = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)
    # Poisson blending for seamless integration with the environment. The clone
    # center must be the center of the mask's bounding box, not the image center,
    # or seamlessClone shifts the recolored region out of place.
    x, y, w, h = cv2.boundingRect(mask)
    center = (x + w // 2, y + h // 2)
    img_recolored = cv2.seamlessClone(img_recolored, img_np, mask, center, cv2.MIXED_CLONE)
    return Image.fromarray(img_recolored)
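# Example of driving the pipeline directly, outside Gradio (the file path is a
# placeholder, not part of this repo):
#
#   result = change_dress_color("samples/dress.jpg", "Blue")
#   if result is not None:
#       result.save("dress_blue.jpg")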
# Gradio Interface
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Radio(["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Cyan", "Magenta", "White", "Black"], label="Choose New Dress Color")
    ],
    outputs=gr.Image(type="pil", label="Color Changed Dress"),
    title="Dress Color Changer",
    description="Upload an image of a dress and select a new color to change its appearance naturally."
)
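# On Hugging Face Spaces the default launch() is enough; locally,
# launch(share=True) would expose a temporary public URL for quick testing.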
if __name__ == "__main__":
    demo.launch()