import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import numpy as np
from PIL import Image
import cv2


class TransformNet(nn.Module):
    """Transformation network for the point-cloud encoder."""

    def __init__(self, input_dim=6):  # the checkpoint expects a 6-channel input
        super(TransformNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_dim, 64, kernel_size=(1, 1)),  # Conv2d matches the checkpoint weights
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(1, 1)),  # Conv2d matches the checkpoint weights
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(128, 1024, kernel_size=1),  # Conv1d matches the checkpoint shape [1024, 128, 1]
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )
        self.fc = nn.Linear(1024, 512)

    def forward(self, x):
        # Expected input: [batch, 6, num_points, 1]
        if x.dim() == 5:
            x = x.squeeze(-1)  # drop a stray trailing singleton dimension
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.squeeze(-1)  # [batch, channels, length], as Conv1d requires
        x = self.conv3(x)
        return self.fc(x.max(dim=-1)[0])  # global max pooling over the length axis


class PointCloudEncoder(nn.Module):
    """Point-cloud encoder (pc_enc)."""

    def __init__(self):
        super(PointCloudEncoder, self).__init__()
        self.transform_net = TransformNet()
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(512, 256, kernel_size=(1, 1)),  # Conv2d matches the checkpoint weights
                nn.BatchNorm2d(256),
                nn.ReLU()
            ),
            nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=(1, 1)),  # Conv2d matches the checkpoint weights
                nn.BatchNorm2d(128),
                nn.ReLU()
            ),
            nn.Sequential(
                nn.Conv1d(128, 64, kernel_size=1),  # Conv1d matches the checkpoint shape [64, 128, 1]
                nn.BatchNorm1d(64),
                nn.ReLU()
            )
        ])
        self.lin_global = nn.Linear(64, 128)

    def forward(self, x):
        x = self.transform_net(x)          # [batch, 512]
        x = x.unsqueeze(-1).unsqueeze(-1)  # [batch, 512, 1, 1] so the Conv2d layers accept it
        for i, conv in enumerate(self.convs):
            if i < 2:
                x = conv(x)          # Conv2d keeps the tensor 4D
            else:
                x = x.squeeze(-1)    # [batch, channels, length] before Conv1d
                x = conv(x)
        return self.lin_global(x.max(dim=-1)[0])  # [batch, 128] global feature


class GarmentEncoder(nn.Module):
    """Garment feature encoder (garm_enc)."""

    def __init__(self, num_classes=18, feature_dim=64):
        super(GarmentEncoder, self).__init__()
        self.garm_embedding = nn.Parameter(torch.randn(num_classes, feature_dim))
        self.attn = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=4)
        # Projects the 128-d point-cloud feature into the 64-d attention space.
        # This layer is an assumption: the original checkpoint may name or
        # shape it differently, but without some projection the query dimension
        # (128) clashes with embed_dim (64) and the attention call fails.
        self.query_proj = nn.Linear(128, 64)
        self.ff = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )
        self.norm = nn.LayerNorm(64)

    def forward(self, x, clothing_classes):
        batch = x.shape[0]
        garment_features = self.garm_embedding[clothing_classes]           # [num_selected, 64]
        query = self.query_proj(x).unsqueeze(0)                            # [1, batch, 64]
        key = value = garment_features.unsqueeze(1).expand(-1, batch, -1)  # [num_selected, batch, 64]
        attn_output, _ = self.attn(query, key, value)                      # [1, batch, 64]
        return self.norm(self.ff(attn_output.squeeze(0)))                  # [batch, 64]


class SegmentationDecoder(nn.Module):
    """Segmentation decoder (segm_dec)."""

    def __init__(self, input_dim=192, num_classes=18):
        super(SegmentationDecoder, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        return self.layers(x)


class CloseNet(nn.Module):
    """Complete CloSe-Net model."""

    def __init__(self):
        super(CloseNet, self).__init__()
        self.pc_enc = PointCloudEncoder()
        self.garm_enc = GarmentEncoder()
        self.segm_dec = SegmentationDecoder()

    def forward(self, point_cloud, clothing_classes):
        pc_features = self.pc_enc(point_cloud)                        # [batch, 128]
        garm_features = self.garm_enc(pc_features, clothing_classes)  # [batch, 64]
        features = torch.cat((pc_features, garm_features), dim=1)     # [batch, 192]
        return self.segm_dec(features)                                # [batch, num_classes]


# Load the pretrained model
model_path = "model_arch/closenet.pth"
model = CloseNet()
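
# Optional sanity check (not part of the original script): with strict=False
# below, mismatched weights are skipped silently, so it is easy to load a
# checkpoint that only partially matches the model. This sketch reports the
# discrepancies; run it once before trusting the loaded weights.
def report_checkpoint_mismatches(checkpoint_path, module):
    state = torch.load(checkpoint_path, map_location="cpu")
    own = module.state_dict()
    missing = [k for k in own if k not in state]     # in the model, absent from the file
    unexpected = [k for k in state if k not in own]  # in the file, absent from the model
    shape_clash = [k for k in own
                   if k in state and own[k].shape != state[k].shape]
    print("missing keys:", missing)
    print("unexpected keys:", unexpected)
    print("shape mismatches:", shape_clash)

# Example: report_checkpoint_mismatches(model_path, model)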
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")), strict=False)
model.eval()

# Side of the square the image is resized to before segmentation; keeps the
# per-pixel batch small enough for CPU inference.
SEG_SIZE = 128


def segment_dress(image):
    """Segment clothing in the image and return a binary mask at the original
    resolution.

    Each pixel is treated as one "point": its RGB triple is duplicated to six
    channels to match the network's expected input, so the decoder produces
    per-pixel class logits. This per-pixel framing is an assumption about how
    the checkpoint was trained; without it the model returns a single logit
    vector for the whole image and no spatial mask can be recovered.
    """
    h, w = image.shape[:2]
    img = Image.fromarray(image).convert("RGB").resize((SEG_SIZE, SEG_SIZE))
    img = np.array(img).transpose(2, 0, 1)                # [3, H, W]
    img = torch.tensor(img, dtype=torch.float32) / 255.0  # normalize to [0, 1]
    img = torch.cat((img, img), dim=0)                    # [6, H, W]: duplicate channels to reach 6
    points = img.reshape(6, -1).T.reshape(-1, 6, 1, 1)    # [H*W, 6, 1, 1]: one point per pixel

    with torch.no_grad():
        logits = model(points, clothing_classes=torch.arange(18))  # [H*W, 18]
        probs = torch.sigmoid(logits)  # logits → probabilities before thresholding

    # A pixel counts as clothing if any garment class clears 0.5.
    mask = (probs.max(dim=1).values > 0.5).numpy().astype(np.uint8) * 255
    mask = mask.reshape(SEG_SIZE, SEG_SIZE)
    return cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)


def change_color(image, color):
    """Recolor the segmented clothing region."""
    mask = segment_dress(image)
    color_bgr = tuple(int(color.lstrip("#")[i:i + 2], 16) for i in (4, 2, 0))  # "#RRGGBB" hex → BGR
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    image_bgr[mask == 255] = color_bgr  # apply the new color wherever the mask is set
    return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


# Gradio interface
interface = gr.Interface(
    fn=change_color,
    inputs=[
        gr.Image(type="numpy", label="Upload a dress image"),
        gr.ColorPicker(label="Choose color")
    ],
    outputs=gr.Image(type="numpy", label="Color-changed image"),
    title="AI Dress Color Changer",
    description="Upload an image of a dress and change its color using AI segmentation."
)

if __name__ == "__main__":
    interface.launch()
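

# Optional smoke test (illustrative only; not part of the original app). The
# function name and the dummy-input size are assumptions. Call it manually,
# e.g. `python -c "import app; app.smoke_test()"`, assuming this file is
# saved as app.py; it verifies that tensors flow end to end without errors.
def smoke_test():
    dummy = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)  # tiny random RGB image
    recolored = change_color(dummy, "#ff0000")
    print("smoke test output shape:", recolored.shape)  # expected: (64, 64, 3)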