# detecting_dress / app.py
# Author: gaur3009 · commit 0c481b6 "Update app.py" (verified) · 4.75 kB
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET
import matplotlib.colors as mcolors
# Load U²-Net
model_path = "cloth_segmentation/networks/u2net.pth"


def _load_u2net(weights_path):
    """Build a U2NET(3, 1) on CPU, load the weights at *weights_path*, set eval mode.

    Loading happens inside a function so the checkpoint dict can be
    garbage-collected once its values have been copied into the model's own
    parameters (``load_state_dict`` copies by default). A module-level
    ``state_dict`` would keep a duplicate of every weight alive for the
    lifetime of the app.
    """
    net = U2NET(3, 1)
    checkpoint = torch.load(weights_path, map_location=torch.device("cpu"))
    # Keys saved from an nn.DataParallel/DDP-wrapped model carry a leading
    # "module." prefix; strip only that prefix, not every occurrence of the
    # substring inside a key.
    prefix = "module."
    checkpoint = {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in checkpoint.items()
    }
    net.load_state_dict(checkpoint)
    net.eval()
    return net


model = _load_u2net(model_path)
# Util to get BGR color from name
def get_bgr_from_color_name(color_name):
    """Convert a free-text color name into an OpenCV-style ``(B, G, R)`` tuple.

    The name is resolved with ``matplotlib.colors.to_rgb`` (named colors,
    hex strings, grayscale strings such as "0.5", ...). Matching is
    case-insensitive and ignores surrounding whitespace. Multi-word names
    typed with separators ("light green", "royal_blue", "sky-blue") are also
    accepted by retrying with the separators removed.

    Args:
        color_name: user-typed color description.

    Returns:
        Tuple of three ints in 0..255 in BGR order. Falls back to red
        ``(0, 0, 255)`` if the name cannot be resolved or is not a string.
    """
    default_bgr = (0, 0, 255)  # red
    if not isinstance(color_name, str):
        return default_bgr
    name = color_name.strip().lower()
    # matplotlib's named colors are single tokens ("lightgreen", "royalblue"),
    # while the UI suggests spaced names. Try the text as typed first so names
    # that legitimately contain spaces (e.g. "xkcd:sky blue") still work.
    candidates = (name, name.replace(" ", "").replace("_", "").replace("-", ""))
    for candidate in candidates:
        try:
            rgb = mcolors.to_rgb(candidate)
        except ValueError:  # to_rgb's error for an unrecognised color spec
            continue
        # round() instead of int(): int() truncates, so e.g. 0.5 * 255 = 127.5
        # would become 127 and values just below an integer lose a unit.
        return tuple(round(255 * c) for c in reversed(rgb))
    return default_bgr
# Mask refinement
def refine_mask(mask):
    """Clean up a segmentation mask and soften its boundary.

    Callers in this file pass uint8 masks with values 0/255. The mask is:
    closed with a 5x5 kernel (fills small holes), eroded once with a 3x3
    kernel (shaves a one-pixel rim), closed again with the 5x5 kernel
    (re-joins gaps the erosion opened), and finally Gaussian-blurred
    (5x5, sigma 1.5) so the edge fades instead of stepping.
    """
    large_kernel = np.ones((5, 5), dtype=np.uint8)
    small_kernel = np.ones((3, 3), dtype=np.uint8)

    closed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, large_kernel)
    shrunk = cv2.erode(closed, small_kernel, iterations=1)
    reclosed = cv2.morphologyEx(shrunk, cv2.MORPH_CLOSE, large_kernel)
    return cv2.GaussianBlur(reclosed, (5, 5), 1.5)
# U²-Net segmentation
def segment_dress(image_np):
    """Predict a dress mask for *image_np* with the module-level U²-Net ``model``.

    Steps:
    1. Convert the image to RGB, make it a tensor and resize it to 320x320.
    2. Run the model without gradients and take ``output[0][0]``, squeezed to
       2-D. This presumably is the primary side output for the single batch
       item; confirm against the U2NET implementation.
    3. Min-max normalise the prediction to [0, 1] and threshold it at
       (mean + 0.2).
    4. Resize the 0/255 mask back to the input's height and width with
       nearest-neighbour interpolation and pass it through ``refine_mask``.
    """
    preprocess = transforms.Compose([transforms.ToTensor(), transforms.Resize((320, 320))])
    rgb_image = Image.fromarray(image_np).convert("RGB")
    batch = preprocess(rgb_image).unsqueeze(0)

    with torch.no_grad():
        saliency = model(batch)[0][0].squeeze().cpu().numpy()

    lo, hi = saliency.min(), saliency.max()
    saliency = (saliency - lo) / (hi - lo + 1e-8)
    cutoff = saliency.mean() + 0.2
    binary = np.where(saliency > cutoff, 255, 0).astype(np.uint8)

    height, width = image_np.shape[:2]
    full_size = cv2.resize(binary, (width, height), interpolation=cv2.INTER_NEAREST)
    return refine_mask(full_size)
# Optional GrabCut refinement
def apply_grabcut(image_np, dress_mask, iterations=3):
    """Tighten a coarse dress mask with OpenCV GrabCut.

    Pixels where ``dress_mask > 0`` are seeded as "probably foreground" and
    all other pixels as definite background. GrabCut can therefore only drop
    pixels from the candidate region, never add new ones.

    Args:
        image_np: 8-bit 3-channel image (required by ``cv2.grabCut``).
        dress_mask: mask whose non-zero pixels are the dress candidates.
        iterations: number of GrabCut iterations (3 by default, as before).

    Returns:
        uint8 mask (0/255 before softening) passed through ``refine_mask``.
        If GrabCut cannot run, the candidate region is returned refined but
        otherwise unchanged instead of raising. This covers three cases: the
        region is empty, it covers the whole image (GrabCut needs both
        foreground and background samples to initialise its models), or
        OpenCV rejects the input.
    """
    candidate = dress_mask > 0

    def _seed_mask():
        # Fresh GrabCut label map: candidate -> GC_PR_FGD, rest -> GC_BGD.
        return np.where(candidate, cv2.GC_PR_FGD, cv2.GC_BGD).astype("uint8")

    mask = _seed_mask()
    if candidate.any() and not candidate.all():
        bgd_model = np.zeros((1, 65), np.float64)
        fgd_model = np.zeros((1, 65), np.float64)
        # rect is only used in GC_INIT_WITH_RECT mode, but the API still
        # requires one; pass the candidate's bounding box.
        rect = cv2.boundingRect(cv2.findNonZero(candidate.astype("uint8")))
        try:
            cv2.grabCut(image_np, mask, rect, bgd_model, fgd_model,
                        iterations, cv2.GC_INIT_WITH_MASK)
        except cv2.error as err:
            print(f"GrabCut skipped: {err}")
            mask = _seed_mask()  # discard any partial in-place update
    refined = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype("uint8")
    return refine_mask(refined)
# LAB color recoloring
def recolor_dress(image_np, dress_mask, target_color):
    """Shift the chroma of the masked region toward *target_color*.

    Args:
        image_np: RGB uint8 image (converted with COLOR_RGB2LAB).
        dress_mask: uint8 mask; its value / 255 scales the shift per pixel.
        target_color: BGR triple (converted with COLOR_BGR2LAB).

    How it works:
    - The mean LAB of all pixels with ``dress_mask > 0`` is measured.
    - The a and b channels are moved by (target - mean), weighted by
      ``dress_mask / 255``. Lightness (L) is left untouched.
    - The shifted image is blended back over the original. The per-pixel
      weight is a Gaussian-blurred mask (21x21, sigma 7) times
      ``(L / 255) ** 0.7``, so darker pixels receive less of the new color.

    Returns:
        uint8 RGB image. *image_np* itself is returned if the mask is empty.
    """
    target_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
    lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)

    inside = lab[dress_mask > 0]
    if not len(inside):
        return image_np

    _, avg_a, avg_b = np.mean(inside, axis=0)
    weight = dress_mask / 255.0
    # Both shifts come from the means measured above, before either channel changes.
    for channel, shift in ((1, target_lab[1] - avg_a), (2, target_lab[2] - avg_b)):
        lab[..., channel] = np.clip(lab[..., channel] + weight * shift, 0, 255)

    recolored = cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2RGB)

    soft_edge = cv2.GaussianBlur(dress_mask, (21, 21), 7)
    brightness = (lab[..., 0] / 255.0) ** 0.7
    alpha = (soft_edge * brightness).astype(np.uint8)[..., None] / 255
    blended = image_np * (1 - alpha) + recolored * alpha
    return blended.astype(np.uint8)
# Main function
def change_dress_color(img, color_prompt):
    """Gradio callback: recolor the dress in *img* to the color named by *color_prompt*.

    Args:
        img: uploaded PIL image, or ``None`` when nothing is uploaded.
        color_prompt: free-text color name typed by the user.

    Returns:
        A new PIL image with the dress recolored. *img* is returned unchanged
        when there is no input, the predicted mask is (near-)empty, or any step
        of the pipeline raises.
    """
    if img is None or not color_prompt:
        return img
    # Segmentation, GrabCut and the RGB->LAB conversion all assume 3-channel
    # RGB. RGBA, grayscale or palette images would otherwise fail inside
    # cv2 and silently fall through to the except branch below.
    rgb = img.convert("RGB") if isinstance(img, Image.Image) else img
    img_np = np.array(rgb)
    target_bgr = get_bgr_from_color_name(color_prompt)
    try:
        dress_mask = segment_dress(img_np)
        # Mask values are 0..255, so this is an intensity sum. It rejects only
        # masks with fewer than ~4 fully-on pixels, i.e. (near-)empty ones.
        if np.sum(dress_mask) < 1000:
            return img
        dress_mask = apply_grabcut(img_np, dress_mask)
        img_recolored = recolor_dress(img_np, dress_mask, target_bgr)
        return Image.fromarray(img_recolored)
    except Exception as e:
        # UI boundary: report the failure and show the original image
        # instead of crashing the app.
        print(f"Error: {e}")
        return img
# Gradio UI
# Markdown shown above the widgets.
_TITLE_MD = "## 🎨 AI Dress Recolorer - Prompt Based"
_HELP_MD = "Upload an image and type a color (e.g., 'lavender', 'light green', 'royal blue')."

with gr.Blocks() as demo:
    gr.Markdown(_TITLE_MD)
    gr.Markdown(_HELP_MD)
    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="pil")
            color_input = gr.Textbox(
                label="Enter Dress Color",
                placeholder="e.g. crimson, lavender, sky blue",
            )
            recolor_btn = gr.Button("Apply New Color")
        # Right column: the recolored output.
        with gr.Column():
            output_image = gr.Image(label="Recolored Result", type="pil")
    # Event wiring must happen inside the Blocks context.
    recolor_btn.click(
        change_dress_color,
        inputs=[input_image, color_input],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch()