# coloring / app.py
# gaur3009's picture
# Update app.py
# be3d161 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import numpy as np
from PIL import Image
import cv2
class TransformNet(nn.Module):
    """Transformation network used at the front of the point-cloud encoder.

    Two 1x1 ``Conv2d`` stages lift the input from ``input_dim`` to 64 and then
    128 channels, a 1x1 ``Conv1d`` stage lifts it to 1024 channels, the result
    is max-pooled over its last axis and projected to 512 features.

    Args:
        input_dim: Number of input channels (6 by default).
    """

    def __init__(self, input_dim=6):
        super().__init__()
        # Attribute names and layer types are kept exactly as before so the
        # state-dict keys and weight shapes do not change (the original
        # comments state that this Conv2d/Conv2d/Conv1d mix mirrors the
        # checkpoint, e.g. a `[1024, 128, 1]` weight for conv3).
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_dim, 64, kernel_size=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(128, 1024, kernel_size=1),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
        )
        self.fc = nn.Linear(1024, 512)

    def forward(self, x):
        """Encode ``x`` into one 512-dimensional vector per batch element.

        Args:
            x: Tensor of shape ``[batch, input_dim, H, W]``. A 5-D tensor is
                accepted as well; its last axis is squeezed first.

        Returns:
            Tensor of shape ``[batch, 512]``.
        """
        if x.dim() == 5:
            x = x.squeeze(-1)  # drop a trailing singleton axis if present
        x = self.conv1(x)
        x = self.conv2(x)
        if x.dim() == 4:
            # Conv1d needs [batch, channels, length]. The previous
            # `squeeze(-1)` only produced that when W == 1, so any real 2-D
            # grid (e.g. an image) crashed in conv3. Flattening the spatial
            # axes is identical for W == 1 and also handles W > 1.
            x = x.flatten(start_dim=2)
        else:
            # Other ranks keep the original behaviour.
            x = x.squeeze(-1)
        x = self.conv3(x)
        # Global max-pool over the length axis, then project to 512 features.
        return self.fc(x.max(dim=-1)[0])
class PointCloudEncoder(nn.Module):
    """Point cloud encoder (``pc_enc``).

    Runs ``TransformNet`` on the input, passes the result through two 1x1
    ``Conv2d`` stages (512 -> 256 -> 128) and one 1x1 ``Conv1d`` stage
    (128 -> 64), max-pools over the last axis and projects to 128 features.
    """

    def __init__(self):
        super().__init__()
        self.transform_net = TransformNet()
        # Layer types / attribute names are unchanged so state-dict keys and
        # weight shapes stay the same as before.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(512, 256, kernel_size=(1, 1)),
                nn.BatchNorm2d(256),
                nn.ReLU(),
            ),
            nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=(1, 1)),
                nn.BatchNorm2d(128),
                nn.ReLU(),
            ),
            nn.Sequential(
                nn.Conv1d(128, 64, kernel_size=1),
                nn.BatchNorm1d(64),
                nn.ReLU(),
            ),
        ])
        self.lin_global = nn.Linear(64, 128)

    def forward(self, x):
        """Encode ``x`` into a 128-dimensional feature vector per batch item.

        Args:
            x: Input accepted by ``TransformNet``.

        Returns:
            Tensor of shape ``[batch, 128]``.
        """
        x = self.transform_net(x)
        if x.dim() == 2:
            # TransformNet returns a pooled [batch, 512] vector, but Conv2d
            # only accepts 3-D/4-D input, so the original forward always
            # raised here. Treat the vector as a 1x1 feature map.
            # NOTE(review): this is a shape fix that makes the forward pass
            # runnable; confirm it matches the intended architecture.
            x = x[:, :, None, None]

        *conv2d_stages, conv1d_stage = self.convs
        for stage in conv2d_stages:
            x = stage(x)  # stays 4-D: [batch, channels, H, W]

        # Conv1d expects [batch, channels, length].
        x = x.flatten(start_dim=2) if x.dim() == 4 else x.squeeze(-1)
        x = conv1d_stage(x)

        # Global max-pool over the length axis, then the final projection.
        return self.lin_global(x.max(dim=-1)[0])
class GarmentEncoder(nn.Module):
    """Garment feature encoder (``garm_enc``).

    Holds one learnable embedding per clothing class and lets the incoming
    features attend over the selected class embeddings, followed by a small
    feed-forward block and a layer norm.

    Args:
        num_classes: Number of rows in the learnable embedding table.
        feature_dim: Width of each embedding and of the attention layer.
    """

    def __init__(self, num_classes=18, feature_dim=64):
        super().__init__()
        # Creation order matches the original so seeded initialisation is
        # reproducible: embedding, attention, feed-forward, norm.
        initial_embedding = torch.randn(num_classes, feature_dim)
        self.garm_embedding = nn.Parameter(initial_embedding)
        self.attn = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=4)
        expand = nn.Linear(feature_dim, 128)
        contract = nn.Linear(128, 64)
        self.ff = nn.Sequential(expand, nn.ReLU(), contract)
        self.norm = nn.LayerNorm(64)

    def forward(self, x, clothing_classes):
        """Attend from ``x`` over the embeddings of ``clothing_classes``.

        Args:
            x: Query features passed to the attention layer.
            clothing_classes: Index used to select rows of the embedding
                table; the selected rows serve as both keys and values.

        Returns:
            The layer-normalised output of the feed-forward block.
        """
        class_tokens = self.garm_embedding[clothing_classes]
        # MultiheadAttention returns (output, weights); only the output is used.
        attended = self.attn(x, class_tokens, class_tokens)[0]
        refined = self.ff(attended)
        return self.norm(refined)
class SegmentationDecoder(nn.Module):
    """Segmentation decoder (``segm_dec``).

    A three-layer MLP (``input_dim`` -> 128 -> 64 -> ``num_classes``) with
    ReLU activations between the linear layers.

    Args:
        input_dim: Width of the incoming feature vectors.
        num_classes: Number of output scores per feature vector.
    """

    def __init__(self, input_dim=192, num_classes=18):
        super().__init__()
        widths = (input_dim, 128, 64)
        stages = []
        # Hidden stages: Linear followed by ReLU, in the same order (and thus
        # the same Sequential indices 0..3) as the original definition.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output stage has no activation.
        stages.append(nn.Linear(64, num_classes))
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Return the per-class scores for ``x``."""
        scores = self.layers(x)
        return scores
class CloseNet(nn.Module):
    """Complete CloSe-Net model.

    Chains the point-cloud encoder, the garment encoder and the segmentation
    decoder: encoder features and garment features are concatenated along
    dim 1 and decoded into class scores.
    """

    def __init__(self):
        super().__init__()
        self.pc_enc = PointCloudEncoder()
        self.garm_enc = GarmentEncoder()
        self.segm_dec = SegmentationDecoder()

    def forward(self, point_cloud, clothing_classes):
        """Run the full pipeline.

        Args:
            point_cloud: Input forwarded to ``pc_enc``.
            clothing_classes: Class indices forwarded to ``garm_enc``.

        Returns:
            Output of ``segm_dec`` on the concatenated features.
        """
        geometry = self.pc_enc(point_cloud)
        # NOTE(review): PointCloudEncoder ends in Linear(64, 128) while
        # GarmentEncoder's attention defaults to embed_dim=64 -- the feature
        # widths look inconsistent; confirm against the reference model.
        garment = self.garm_enc(geometry, clothing_classes)
        fused = torch.cat([geometry, garment], dim=1)
        return self.segm_dec(fused)
# Load the pretrained weights onto the CPU and switch to inference mode.
model_path = "model_arch/closenet.pth"
model = CloseNet()
_checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
# strict=False: checkpoint keys that do not line up with the model are
# skipped instead of raising, so a partial load goes unnoticed here.
model.load_state_dict(_checkpoint, strict=False)
del _checkpoint  # the values have been copied into the model's parameters
model.eval()
def segment_dress(image):
    """Run the module-level ``model`` on ``image`` and return a 0/255 mask.

    The image is converted to RGB, scaled to ``[0, 1]``, laid out as
    ``[1, 3, H, W]`` and its channels are duplicated to obtain the 6-channel
    input the network is built for. The model output is thresholded at 0.5.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        ``uint8`` NumPy array holding 255 where the squeezed model output
        exceeds 0.5 and 0 elsewhere.
        NOTE(review): the output shape is whatever the model produces after
        ``squeeze()``; confirm that it is an ``[H, W]`` map as callers assume.
    """
    rgb = np.array(Image.fromarray(image).convert("RGB"))
    channels_first = rgb.transpose(2, 0, 1)  # [H, W, 3] -> [3, H, W]
    batch = torch.tensor(channels_first, dtype=torch.float32).unsqueeze(0) / 255.0

    # Repeat the RGB channels once: [1, 3, H, W] -> [1, 6, H, W].
    six_channel = batch.repeat(1, 2, 1, 1)

    with torch.no_grad():
        class_ids = torch.arange(18)
        scores = model(six_channel, clothing_classes=class_ids)
        scores = scores.squeeze().numpy()

    foreground = scores > 0.5
    return foreground.astype(np.uint8) * 255
def _parse_color_to_rgb(color):
    """Convert a colour-picker string into an ``(r, g, b)`` tuple of ints.

    Accepts ``#rgb``, ``#rgba``, ``#rrggbb`` and ``#rrggbbaa`` (the leading
    ``#`` is optional, alpha is ignored) as well as CSS-style
    ``rgb(r, g, b)`` / ``rgba(r, g, b, a)`` strings with numeric channels.

    Raises:
        ValueError: If ``color`` is not a string in one of those formats.
    """
    import re  # local import keeps this block free of new file-level deps

    if not isinstance(color, str) or not color.strip():
        raise ValueError(f"Invalid color value: {color!r}")
    text = color.strip()

    functional = re.fullmatch(r"rgba?\(([^)]*)\)", text, flags=re.IGNORECASE)
    if functional:
        parts = [part.strip() for part in functional.group(1).split(",")]
        if len(parts) not in (3, 4):
            raise ValueError(f"Invalid color value: {color!r}")
        try:
            # Channels may arrive as floats; round and clamp to 0..255.
            return tuple(
                min(255, max(0, int(round(float(part))))) for part in parts[:3]
            )
        except (ValueError, OverflowError) as err:
            raise ValueError(f"Invalid color value: {color!r}") from err

    digits = text.lstrip("#")
    if len(digits) in (3, 4):
        # Expand shorthand hex: "f80" -> "ff8800".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8) or re.fullmatch(r"[0-9a-fA-F]+", digits) is None:
        raise ValueError(f"Invalid color value: {color!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def change_color(image, color):
    """Recolour the segmented region of ``image`` with ``color``.

    Args:
        image: RGB image as a NumPy array.
        color: Colour string, either hex (``#rrggbb``, ``#rgb``) or
            ``rgb(...)`` / ``rgba(...)``.

    Returns:
        RGB NumPy array in which every pixel whose mask value is 255 has been
        replaced by ``color``. The input array is not modified.

    Raises:
        ValueError: If no image is given, the colour cannot be parsed, or the
            mask from ``segment_dress`` does not match the image size.
    """
    if image is None:
        raise ValueError("No image provided.")

    # Parse first so a bad colour fails before the model is run.
    red, green, blue = _parse_color_to_rgb(color)

    mask = segment_dress(image)
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Boolean indexing below needs one mask value per pixel; fail with a
    # readable message instead of an obscure indexing error.
    if np.shape(mask) != image_bgr.shape[:2]:
        raise ValueError(
            f"Segmentation mask shape {np.shape(mask)} does not match "
            f"image size {image_bgr.shape[:2]}."
        )

    image_bgr[mask == 255] = (blue, green, red)  # OpenCV channel order is BGR
    return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
# Gradio UI: an image upload plus a colour picker feed `change_color`, and the
# recoloured image is shown as the single output.
image_input = gr.Image(type="numpy", label="Upload a dress image")
color_input = gr.ColorPicker(label="Choose color")
result_output = gr.Image(type="numpy", label="Color-changed image")

interface = gr.Interface(
    fn=change_color,
    inputs=[image_input, color_input],
    outputs=result_output,
    title="AI Dress Color Changer",
    description="Upload an image of a dress and change its color using AI segmentation.",
)

# Only start the web server when the file is run as a script.
if __name__ == "__main__":
    interface.launch()