import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET
import matplotlib.colors as mcolors

# Load U²-Net (3 input channels, 1 output mask channel) for cloth segmentation
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
# Strip the "module." prefix left over from DataParallel checkpoints
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()

# Util to get a BGR color tuple from a color name
def get_bgr_from_color_name(color_name):
    try:
        # Drop spaces so multi-word names like "royal blue" resolve to
        # matplotlib's CSS4 entries ("royalblue", "skyblue", ...)
        rgb = mcolors.to_rgb(color_name.lower().replace(" ", ""))
        return tuple(int(255 * c) for c in rgb[::-1])  # RGB -> BGR
    except ValueError:
        return (0, 0, 255)  # Unknown color name: default to red
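# Illustrative sanity checks (values from matplotlib's CSS4 table):
#   get_bgr_from_color_name("crimson")   -> (60, 20, 220)   # #DC143C
#   get_bgr_from_color_name("sky blue")  -> (235, 206, 135) # #87CEEB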

# Mask refinement: close small holes, shave ragged edges, then soften
def refine_mask(mask):
    close_kernel = np.ones((5, 5), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel)
    erode_kernel = np.ones((3, 3), np.uint8)
    mask = cv2.erode(mask, erode_kernel, iterations=1)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel)
    return cv2.GaussianBlur(mask, (5, 5), 1.5)

# U²-Net segmentation: predict a dress mask at 320x320, then upsample
def segment_dress(image_np):
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320))
    ])
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = transform_pipeline(image).unsqueeze(0)
    with torch.no_grad():
        # The first element of the U²-Net output tuple is the fused prediction
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()
    # Min-max normalize, then threshold slightly above the mean activation
    output = (output - output.min()) / (output.max() - output.min() + 1e-8)
    adaptive_thresh = np.mean(output) + 0.2
    dress_mask = (output > adaptive_thresh).astype(np.uint8) * 255
    # Resize back to the input resolution before refining
    dress_mask = cv2.resize(dress_mask, (image_np.shape[1], image_np.shape[0]),
                            interpolation=cv2.INTER_NEAREST)
    return refine_mask(dress_mask)
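# A minimal standalone check (hypothetical file names, illustrative only):
#   img = np.array(Image.open("sample.jpg").convert("RGB"))
#   mask = segment_dress(img)   # uint8 array, 0 (background) / 255 (dress)
#   Image.fromarray(mask).save("mask_preview.png")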

# Optional GrabCut refinement around the U²-Net mask
def apply_grabcut(image_np, dress_mask):
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    # Seed GrabCut: masked pixels are probable foreground, the rest background
    mask = np.where(dress_mask > 0, cv2.GC_PR_FGD, cv2.GC_BGD).astype("uint8")
    coords = cv2.findNonZero(dress_mask)
    if coords is None:
        return dress_mask  # Empty mask: nothing to refine
    rect = cv2.boundingRect(coords)
    cv2.grabCut(image_np, mask, rect, bgd_model, fgd_model, 3, cv2.GC_INIT_WITH_MASK)
    refined = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype("uint8")
    return refine_mask(refined)

# LAB recoloring: shift the a/b (chroma) channels toward the target color
# while leaving L (lightness) alone, so folds and shadows are preserved
def recolor_dress(image_np, dress_mask, target_color):
    target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB).astype(np.float32)
    dress_pixels = img_lab[dress_mask > 0]
    if len(dress_pixels) == 0:
        return image_np
    mean_L, mean_A, mean_B = np.mean(dress_pixels, axis=0)
    # Shift needed to move the mean dress chroma onto the target chroma
    a_shift = target_color_lab[1] - mean_A
    b_shift = target_color_lab[2] - mean_B
    img_lab[..., 1] = np.clip(img_lab[..., 1] + (dress_mask / 255.0) * a_shift, 0, 255)
    img_lab[..., 2] = np.clip(img_lab[..., 2] + (dress_mask / 255.0) * b_shift, 0, 255)
    img_recolored = cv2.cvtColor(img_lab.astype(np.uint8), cv2.COLOR_LAB2RGB)
    # Feather the mask and weight it by lightness so dark regions blend softly
    feathered_mask = cv2.GaussianBlur(dress_mask, (21, 21), 7)
    lightness_mask = (img_lab[..., 0] / 255.0) ** 0.7
    adaptive_feather = (feathered_mask * lightness_mask).astype(np.uint8)
    alpha = adaptive_feather[..., None] / 255.0
    return (image_np * (1 - alpha) + img_recolored * alpha).astype(np.uint8)
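# Worked example of the chroma shift (illustrative numbers): if the mean dress
# chroma is (a=140, b=160) and the target color maps to (a=180, b=120), every
# masked pixel is shifted by (+40, -40); L is untouched, so shading survives.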

# Main function: segment, optionally refine with GrabCut, then recolor
def change_dress_color(img, color_prompt):
    if img is None or not color_prompt:
        return img
    img_np = np.array(img)
    target_bgr = get_bgr_from_color_name(color_prompt)
    try:
        dress_mask = segment_dress(img_np)
        if np.sum(dress_mask) < 1000:  # Mask too small: probably no dress found
            return img
        dress_mask = apply_grabcut(img_np, dress_mask)
        img_recolored = recolor_dress(img_np, dress_mask, target_bgr)
        return Image.fromarray(img_recolored)
    except Exception as e:
        print(f"Error: {e}")
        return img
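# Headless usage without the Gradio UI (hypothetical file names):
#   out = change_dress_color(Image.open("dress.jpg"), "lavender")
#   out.save("dress_lavender.jpg")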

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 AI Dress Recolorer - Prompt Based")
    gr.Markdown("Upload an image and type a color (e.g., 'lavender', 'light green', 'royal blue').")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            color_input = gr.Textbox(label="Enter Dress Color", placeholder="e.g. crimson, lavender, sky blue")
            recolor_btn = gr.Button("Apply New Color")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Recolored Result")
    recolor_btn.click(fn=change_dress_color, inputs=[input_image, color_input], outputs=output_image)

if __name__ == "__main__":
    demo.launch()
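# Note: launch(share=True) would also expose a temporary public URL when run
# locally; the plain launch() above serves the app locally (and on the Space).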