|
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import numpy as np
from PIL import Image
import cv2
|
|
|
class TransformNet(nn.Module):
    """Shared 1x1-conv MLP followed by global max pooling.

    Takes a batched ``(B, input_dim, H, W)`` tensor and returns a
    ``(B, 512)`` global feature vector.  Point clouds are presumably fed as
    ``(B, input_dim, N, 1)`` (TODO confirm with the data pipeline); image
    tensors with ``W > 1`` are also accepted.
    """

    def __init__(self, input_dim=6):
        super().__init__()
        # 1x1 convolutions apply the same MLP independently at every
        # point / pixel location.
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_dim, 64, kernel_size=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(128, 1024, kernel_size=1),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
        )
        self.fc = nn.Linear(1024, 512)

    def forward(self, x):
        """Encode ``x`` into a ``(B, 512)`` global descriptor."""
        if x.dim() == 5:
            # Tolerate an extra trailing singleton axis, e.g. (B, C, N, 1, 1).
            x = x.squeeze(-1)

        x = self.conv2(self.conv1(x))

        # Conv1d needs (B, 128, L).  Flatten the spatial axes instead of
        # squeezing: squeeze(-1) is a no-op whenever W > 1 (e.g. real
        # images), which left a 4-D tensor and made conv3 raise.  For
        # (B, C, N, 1) input both forms give the same (B, C, N) result.
        if x.dim() == 4:
            x = x.flatten(start_dim=2)
        else:
            x = x.squeeze(-1)
        x = self.conv3(x)

        # Max over all points/pixels -> one 1024-d vector per sample.
        return self.fc(x.max(dim=-1)[0])
|
|
|
class PointCloudEncoder(nn.Module):
    """Point cloud encoder (``pc_enc``).

    Runs ``TransformNet`` to obtain a ``(B, 512)`` global feature, refines
    it through a 512 -> 256 -> 128 -> 64 stack of 1x1 convolutions and
    projects the result to a ``(B, 128)`` embedding.
    """

    def __init__(self):
        super().__init__()
        self.transform_net = TransformNet()
        # Registration order (convs.0 / convs.1 / convs.2) defines the
        # state_dict keys, so it must stay as is for checkpoint loading.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(512, 256, kernel_size=(1, 1)),
                nn.BatchNorm2d(256),
                nn.ReLU(),
            ),
            nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=(1, 1)),
                nn.BatchNorm2d(128),
                nn.ReLU(),
            ),
            nn.Sequential(
                nn.Conv1d(128, 64, kernel_size=1),
                nn.BatchNorm1d(64),
                nn.ReLU(),
            ),
        ])
        self.lin_global = nn.Linear(64, 128)

    def forward(self, x):
        """Encode ``x`` into a ``(B, 128)`` feature vector."""
        x = self.transform_net(x)  # pooled global feature, (B, 512)

        # Conv2d requires (B, C, H, W) and rejects the 2-D pooled vector, so
        # view it as a 1x1 feature map.
        # NOTE(review): if per-point features were intended here,
        # TransformNet would have to return them un-pooled instead — confirm
        # against how the checkpoint was trained.
        x = x.unsqueeze(-1).unsqueeze(-1)  # (B, 512, 1, 1)

        conv2d_a, conv2d_b, conv1d = self.convs
        x = conv2d_b(conv2d_a(x))      # (B, 128, 1, 1)
        x = conv1d(x.squeeze(-1))      # (B, 64, 1)

        return self.lin_global(x.max(dim=-1)[0])
|
|
|
|
|
class GarmentEncoder(nn.Module):
    """Garment feature encoder (``garm_enc``).

    Keeps one learnable embedding row per garment class.  In ``forward`` the
    input features act as the attention query and the embedding rows picked
    by ``clothing_classes`` act as keys and values.  The attended result
    goes through a small feed-forward net and a LayerNorm, giving 64
    features per query.
    """

    def __init__(self, num_classes=18, feature_dim=64):
        super().__init__()
        # learnable table: one feature_dim-sized row per garment class
        self.garm_embedding = nn.Parameter(torch.randn(num_classes, feature_dim))
        # feature_dim must be divisible by num_heads (4)
        self.attn = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=4)
        self.ff = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
        )
        self.norm = nn.LayerNorm(64)

    def forward(self, x, clothing_classes):
        """Attend from ``x`` over the selected garment embeddings.

        ``x``'s last dimension must equal ``feature_dim``.  A 2-D ``x`` is
        handled as an unbatched sequence; 3-D inputs follow
        MultiheadAttention's default ``(seq, batch, embed)`` layout.
        """
        keys = self.garm_embedding[clothing_classes]
        attended = self.attn(x, keys, keys)[0]
        projected = self.ff(attended)
        return self.norm(projected)
|
|
|
class SegmentationDecoder(nn.Module):
    """Segmentation decoder (``segm_dec``).

    A three-layer MLP ``input_dim -> 128 -> 64 -> num_classes`` with ReLU
    between layers and no activation on the output, so it returns raw
    per-class scores.
    """

    def __init__(self, input_dim=192, num_classes=18):
        super().__init__()
        widths = (input_dim, 128, 64)
        stages = []
        for width_in, width_out in zip(widths, widths[1:]):
            stages.append(nn.Linear(width_in, width_out))
            stages.append(nn.ReLU())
        stages.append(nn.Linear(widths[-1], num_classes))
        # indices 0/2/4 hold the Linear layers -> state_dict keys layers.{0,2,4}.*
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Map features ``x`` (last dim ``input_dim``) to class scores."""
        scores = self.layers(x)
        return scores
|
|
|
class CloseNet(nn.Module):
    """Complete CloSe-Net model.

    Pipeline: ``pc_enc`` encodes the input, ``garm_enc`` attends from those
    features over the requested garment-class embeddings, the two feature
    sets are concatenated along dim 1, and ``segm_dec`` decodes the result.

    NOTE(review): ``pc_enc`` ends in ``nn.Linear(64, 128)`` (128 features),
    but ``garm_enc`` uses them as the attention query with ``embed_dim=64``.
    MultiheadAttention requires the query's last dimension to equal
    ``embed_dim``, so these shapes do not line up.  128 + 64 does match
    ``segm_dec``'s ``input_dim=192``, so a 128 -> 64 query projection may be
    what is missing — confirm against the trained checkpoint before changing
    the architecture.
    """

    def __init__(self):
        super().__init__()
        self.pc_enc = PointCloudEncoder()
        self.garm_enc = GarmentEncoder()
        self.segm_dec = SegmentationDecoder()

    def forward(self, point_cloud, clothing_classes):
        """Return ``segm_dec`` scores for ``point_cloud`` and ``clothing_classes``."""
        point_feats = self.pc_enc(point_cloud)
        garment_feats = self.garm_enc(point_feats, clothing_classes)
        fused = torch.cat([point_feats, garment_feats], dim=1)
        return self.segm_dec(fused)
|
|
|
|
|
model_path = "model_arch/closenet.pth"

model = CloseNet()
# torch.load unpickles the file, so model_path must point at a trusted
# checkpoint.
# strict=False makes load_state_dict skip keys that do not match this
# architecture; every skipped parameter keeps its random initialization.
# Report any mismatch instead of silently serving an untrained model.
_incompatible_keys = model.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu")), strict=False
)
if _incompatible_keys.missing_keys or _incompatible_keys.unexpected_keys:
    warnings.warn(
        f"Checkpoint {model_path!r} does not match CloseNet "
        f"(missing keys: {_incompatible_keys.missing_keys}; "
        f"unexpected keys: {_incompatible_keys.unexpected_keys}). "
        "Unmatched layers keep their random initialization.",
        RuntimeWarning,
        stacklevel=1,
    )
model.eval()
|
|
|
def segment_dress(image):
    """Run the segmentation model on ``image`` and return a binary mask.

    Args:
        image: numpy image array as delivered by ``gr.Image(type="numpy")``;
            it is converted to 3-channel RGB before inference.

    Returns:
        ``uint8`` array of shape ``(H, W)`` (the image's height and width)
        holding 255 where the model output exceeds 0.5 and 0 elsewhere.

    Raises:
        ValueError: if ``image`` is None, or the model output does not have
            one value per pixel and therefore cannot be used as a mask.
    """
    if image is None:
        raise ValueError("No image was provided.")

    rgb = np.array(Image.fromarray(image).convert("RGB"))
    height, width = rgb.shape[:2]

    # HWC uint8 -> 1 x C x H x W float in [0, 1]
    tensor = torch.tensor(rgb.transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0) / 255.0
    # TransformNet is built with input_dim=6, so the 3 RGB channels are
    # duplicated to reach the expected channel count.
    tensor = torch.cat((tensor, tensor), dim=1)

    with torch.no_grad():
        # query all 18 garment classes (GarmentEncoder's default num_classes)
        output = model(tensor, clothing_classes=torch.arange(18))

    scores = output.numpy()
    # Callers index the image with this mask, so it needs one value per pixel.
    if scores.size != height * width:
        raise ValueError(
            f"Model output of shape {tuple(output.shape)} cannot be used as a "
            f"mask for a {height}x{width} image."
        )
    scores = scores.reshape(height, width)

    # NOTE(review): SegmentationDecoder ends in a plain nn.Linear, so these are
    # unnormalized scores; a 0.5 threshold treats them as probabilities —
    # confirm whether a sigmoid (or threshold 0) was intended.
    return (scores > 0.5).astype(np.uint8) * 255
|
|
|
def _parse_color(color):
    """Convert a colour string to an ``(r, g, b)`` tuple of ints in 0-255.

    Accepts ``"#rrggbb"``, ``"#rgb"`` and CSS-style ``"rgb(r, g, b)"`` /
    ``"rgba(r, g, b, a)"`` strings; any alpha component is ignored.

    Raises:
        ValueError: if ``color`` is empty or in none of those forms.
    """
    if not color:
        raise ValueError("Please choose a color.")
    text = color.strip()

    if text.lower().startswith("rgb"):
        # NOTE(review): some Gradio versions may emit rgba(...) rather than hex
        # from ColorPicker — handled here to be safe.
        start, end = text.find("("), text.rfind(")")
        if start == -1 or end < start:
            raise ValueError(f"Unrecognized color: {color!r}")
        parts = text[start + 1 : end].split(",")
        if len(parts) < 3:
            raise ValueError(f"Unrecognized color: {color!r}")
        rgb = tuple(int(round(float(part))) for part in parts[:3])
    else:
        digits = text.lstrip("#")
        if len(digits) == 3:  # shorthand "#abc" -> "#aabbcc"
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) != 6:
            raise ValueError(f"Unrecognized color: {color!r}")
        rgb = tuple(int(digits[i : i + 2], 16) for i in (0, 2, 4))

    if not all(0 <= channel <= 255 for channel in rgb):
        raise ValueError(f"Color channels must be between 0 and 255: {color!r}")
    return rgb


def change_color(image, color):
    """Paint the pixels that ``segment_dress`` marks as dress with ``color``.

    Args:
        image: numpy image from ``gr.Image(type="numpy")``.
        color: colour string from ``gr.ColorPicker`` (hex or ``rgb()/rgba()``).

    Returns:
        A new ``uint8`` RGB array of the same height and width; the input
        array is not modified.

    Raises:
        gr.Error: if no image was uploaded, the colour cannot be parsed, or
            segmentation rejects the input — shown to the user by Gradio.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    try:
        # parse first so a bad colour fails before running the model
        rgb_color = _parse_color(color)
        mask = segment_dress(image)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    # Normalise to 3-channel RGB (also handles grayscale / RGBA uploads) so the
    # colour tuple fits every pixel; np.array makes a fresh copy.
    result = np.array(Image.fromarray(image).convert("RGB"))
    result[mask == 255] = rgb_color
    return result
|
|
|
|
|
def _build_interface():
    """Assemble the Gradio UI: image + colour picker in, recoloured image out."""
    upload = gr.Image(type="numpy", label="Upload a dress image")
    picker = gr.ColorPicker(label="Choose color")
    recolored = gr.Image(type="numpy", label="Color-changed image")
    return gr.Interface(
        fn=change_color,
        inputs=[upload, picker],
        outputs=recolored,
        title="AI Dress Color Changer",
        description=(
            "Upload an image of a dress and change its color using AI segmentation."
        ),
    )


interface = _build_interface()

if __name__ == "__main__":
    # start the web UI only when run as a script, not on import
    interface.launch()