import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import numpy as np
from PIL import Image
import cv2

class TransformNet(nn.Module):
    """Transformation Network for PointCloud Encoding.

    Expects input of shape [batch, 6, num_points, 1] (six channels per point).
    """
    def __init__(self, input_dim=6):  # input is expected to have 6 channels
        super(TransformNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_dim, 64, kernel_size=(1, 1)),  # Conv2d (matches checkpoint)
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(1, 1)),  # Conv2d (matches checkpoint)
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(128, 1024, kernel_size=1),  # Conv1d to match checkpoint weights [1024, 128, 1]
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )
        self.fc = nn.Linear(1024, 512)

    def forward(self, x):
        if x.dim() == 5:
            x = x.squeeze(-1)  # drop a trailing singleton dimension if present
        x = self.conv1(x)      # [B, 64, N, 1]
        x = self.conv2(x)      # [B, 128, N, 1]
        x = x.squeeze(-1)      # [B, 128, N] so Conv1d sees [batch, channels, length]
        x = self.conv3(x)      # [B, 1024, N]
        # NOTE: features are kept per point (fc is applied to every point) instead
        # of globally max-pooled, so the decoder can later emit one label per
        # point/pixel. This is an assumption made to keep the demo functional.
        x = self.fc(x.transpose(1, 2))          # [B, N, 512]
        return x.transpose(1, 2).unsqueeze(-1)  # [B, 512, N, 1] for the Conv2d stages below
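
# Illustrative shape check (a minimal sketch; assumes a toy batch of 2 clouds
# with 1024 points each):
#   net = TransformNet()
#   net(torch.randn(2, 6, 1024, 1)).shape  # torch.Size([2, 512, 1024, 1])
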
class PointCloudEncoder(nn.Module):
    """Point Cloud Encoder (pc_enc)"""
    def __init__(self):
        super(PointCloudEncoder, self).__init__()
        self.transform_net = TransformNet()
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(512, 256, kernel_size=(1, 1)),  # Conv2d (matches checkpoint)
                nn.BatchNorm2d(256),
                nn.ReLU()
            ),
            nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=(1, 1)),  # Conv2d (matches checkpoint)
                nn.BatchNorm2d(128),
                nn.ReLU()
            ),
            nn.Sequential(
                nn.Conv1d(128, 64, kernel_size=1),  # Conv1d to match checkpoint weights [64, 128, 1]
                nn.BatchNorm1d(64),
                nn.ReLU()
            )
        ])
        self.lin_global = nn.Linear(64, 128)

    def forward(self, x):
        x = self.transform_net(x)  # [B, 512, N, 1]
        for i, conv in enumerate(self.convs):
            if i < 2:
                x = conv(x)        # Conv2d stages keep the input 4D
            else:
                x = x.squeeze(-1)  # [B, 128, N] so Conv1d sees [batch, channels, length]
                x = conv(x)        # [B, 64, N]
        # Per-point embedding [B, N, 128]; left unpooled so segmentation stays per point.
        return self.lin_global(x.transpose(1, 2))
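
# Illustrative shape check for the whole encoder (same assumed toy input):
#   enc = PointCloudEncoder()
#   enc(torch.randn(2, 6, 1024, 1)).shape  # torch.Size([2, 1024, 128])
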
class GarmentEncoder(nn.Module):
    """Garment Feature Encoder (garm_enc)"""
    def __init__(self, num_classes=18, feature_dim=64):
        super(GarmentEncoder, self).__init__()
        self.garm_embedding = nn.Parameter(torch.randn(num_classes, feature_dim))
        # Assumed fix: the 128-d point features must be projected down to the
        # 64-d embedding size before they can serve as attention queries.
        self.query_proj = nn.Linear(128, feature_dim)
        self.attn = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=4)
        self.ff = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )
        self.norm = nn.LayerNorm(64)

    def forward(self, x, clothing_classes):
        # x: per-point features [B, N, 128]; the points attend over the learned
        # garment embeddings selected by clothing_classes.
        q = self.query_proj(x).transpose(0, 1)          # [N, B, 64] (attn is seq-first by default)
        kv = self.garm_embedding[clothing_classes]      # [num_selected, 64]
        kv = kv.unsqueeze(1).expand(-1, x.size(0), -1)  # [num_selected, B, 64]
        attn_output, _ = self.attn(q, kv, kv)           # [N, B, 64]
        return self.norm(self.ff(attn_output.transpose(0, 1)))  # [B, N, 64]
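
# Illustrative call (toy shapes; all 18 garment classes used as keys/values):
#   genc = GarmentEncoder()
#   feats = torch.randn(2, 1024, 128)    # per-point features from pc_enc
#   genc(feats, torch.arange(18)).shape  # torch.Size([2, 1024, 64])
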
class SegmentationDecoder(nn.Module):
    """Segmentation Decoder (segm_dec)"""
    def __init__(self, input_dim=192, num_classes=18):
        super(SegmentationDecoder, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        return self.layers(x)
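
# The decoder is a plain per-point MLP: 192-d concatenated features in,
# 18 class logits out. Illustrative call:
#   dec = SegmentationDecoder()
#   dec(torch.randn(2, 1024, 192)).shape  # torch.Size([2, 1024, 18])
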
class CloseNet(nn.Module):
    """Complete CloSe-Net Model"""
    def __init__(self):
        super(CloseNet, self).__init__()
        self.pc_enc = PointCloudEncoder()
        self.garm_enc = GarmentEncoder()
        self.segm_dec = SegmentationDecoder()

    def forward(self, point_cloud, clothing_classes):
        pc_features = self.pc_enc(point_cloud)                        # [B, N, 128]
        garm_features = self.garm_enc(pc_features, clothing_classes)  # [B, N, 64]
        # Concatenate along the feature dimension: 128 + 64 = 192 (decoder input_dim).
        features = torch.cat((pc_features, garm_features), dim=-1)
        return self.segm_dec(features)                                # [B, N, 18]
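
# End-to-end shape flow (toy input, 18-way label set assumed throughout):
#   net = CloseNet()
#   net(torch.randn(2, 6, 1024, 1), torch.arange(18)).shape  # torch.Size([2, 1024, 18])
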
# Load Pretrained Model
model_path = "model_arch/closenet.pth"
model = CloseNet()
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")), strict=False)
model.eval()
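
# strict=False silently skips mismatched keys. load_state_dict returns the
# missing/unexpected key lists, which is a quick way to see what actually loaded:
#   missing, unexpected = model.load_state_dict(
#       torch.load(model_path, map_location="cpu"), strict=False)
#   print("missing:", missing, "unexpected:", unexpected)
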
def segment_dress(image):
    """Detect and segment the dress in the image."""
    img = Image.fromarray(image).convert("RGB")
    img = np.array(img).transpose(2, 0, 1)  # [H, W, C] -> [3, H, W]
    img = torch.tensor(img, dtype=torch.float32).unsqueeze(0) / 255.0  # normalize to [1, 3, H, W]
    # Duplicate the channels to match the expected 6-channel input: [1, 3, H, W] -> [1, 6, H, W]
    img = torch.cat((img, img), dim=1)
    h, w = img.shape[2], img.shape[3]
    # Treat every pixel as a "point" so the point-cloud model can process the
    # image: [1, 6, H, W] -> [1, 6, H*W, 1]. This is an assumption of the demo.
    points = img.reshape(1, 6, -1, 1)
    with torch.no_grad():
        logits = model(points, clothing_classes=torch.arange(18))  # [1, H*W, 18]
    # DRESS_CLASS is an assumed label index; adjust it to the checkpoint's label map.
    DRESS_CLASS = 5
    probs = logits.softmax(dim=-1)[0, :, DRESS_CLASS].reshape(h, w).numpy()
    mask = (probs > 0.5).astype(np.uint8) * 255  # binary mask, 255 = dress
    return mask
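
# Illustrative usage ("dress.jpg" is a placeholder path):
#   img = np.array(Image.open("dress.jpg").convert("RGB"))
#   mask = segment_dress(img)  # uint8 [H, W], 255 where the dress class wins
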
def change_color(image, color):
    """Change the dress color based on the segmentation mask."""
    mask = segment_dress(image)
    color_bgr = tuple(int(color.lstrip("#")[i : i + 2], 16) for i in (4, 2, 0))  # HEX -> BGR
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    image_bgr[mask == 255] = color_bgr  # apply the new color where the mask is set
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    return image_rgb
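
# The (4, 2, 0) index order reverses the HEX byte order into BGR, e.g.
# "#ff0000" -> ("00", "00", "ff") -> (0, 0, 255), pure red in OpenCV's BGR layout.
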
# Gradio Interface
interface = gr.Interface(
    fn=change_color,
    inputs=[
        gr.Image(type="numpy", label="Upload a dress image"),
        gr.ColorPicker(label="Choose color")
    ],
    outputs=gr.Image(type="numpy", label="Color-changed image"),
    title="AI Dress Color Changer",
    description="Upload an image of a dress and change its color using AI segmentation."
)

if __name__ == "__main__":
    interface.launch()