|
import gradio as gr |
|
import math |
|
import os |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
import segmentation_models_pytorch as smp |
|
|
|
def pad_to_divisible(img, div=32):
    """Zero-pad an image on the bottom and right so its size is a multiple of *div*.

    Args:
        img: Image array of shape ``(H, W)`` or ``(H, W, C)``; any number of
            channels and any dtype are accepted.
        div: Positive integer that both the output height and width must be
            divisible by.

    Returns:
        A new array of shape ``(ceil(H/div)*div, ceil(W/div)*div[, C])`` with
        the same dtype as *img*. The original pixels occupy the top-left
        corner; the added rows/columns are zero.

    Raises:
        ValueError: If *div* is not positive or *img* has fewer than 2 dims.
    """
    if div <= 0:
        raise ValueError(f"div must be positive, got {div}")
    if img.ndim < 2:
        raise ValueError(f"expected an image with at least 2 dims, got shape {img.shape}")

    h, w = img.shape[:2]
    # `-h % div` is the distance up to the next multiple of div (0 when already aligned).
    pad_bottom = -h % div
    pad_right = -w % div

    # Pad only the two spatial axes; a trailing channel axis (if any) is untouched.
    pad_width = [(0, pad_bottom), (0, pad_right)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width, mode="constant", constant_values=0)
|
|
|
|
|
# Inference is pinned to the CPU; no GPU selection is attempted.
device = torch.device("cpu")
print("Using device:", device)

# Trained U-Net weights; the file holds a state_dict (it is passed to
# load_state_dict below).
model_path = "best_unet.pth"

# Fail fast, before building the network, if the checkpoint is missing.
# isfile (rather than exists) also rejects a directory with this name.
if not os.path.isfile(model_path):
    raise FileNotFoundError(f"❌ Model file not found at: {model_path}")

# The architecture must match the checkpoint exactly: load_state_dict is
# strict by default and raises on missing/unexpected keys.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,  # no pretrained init; the checkpoint supplies all weights
    in_channels=3,
    classes=1,
)
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()  # put dropout / batch-norm layers into inference mode
print("✅ Model loaded successfully.")
|
|
|
def predict(image):
    """Segment wrinkles in an uploaded image and visualise the result.

    Args:
        image: Array from ``gr.Image(type="numpy")``. 2-D grayscale,
            single-channel, 3-channel RGB and 4-channel RGBA inputs are
            accepted. Pixel values are assumed to be uint8 in [0, 255]
            (they are divided by 255 below) -- TODO confirm for other dtypes.

    Returns:
        Tuple ``(padded, mask, overlay)``, all at the padded resolution
        (height and width rounded up to a multiple of 32):

        * ``padded``  -- the RGB input, zero-padded on the bottom/right.
        * ``mask``    -- float32 array of 0.0/1.0; 1.0 where the sigmoid of
          the model output exceeds ``threshold``.
        * ``overlay`` -- uint8 RGB image with masked pixels blended with red.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message.
    """
    threshold = 0.3  # probability cut-off for the wrinkle class
    alpha = 0.5      # blend weight of the red highlight in the overlay
    red = np.array([255, 0, 0], dtype=np.float32)

    if image is None:
        raise gr.Error("Please upload an image first.")

    # The network is built with in_channels=3, so normalise to RGB.
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    # smp's U-Net needs spatial dims divisible by the encoder's total
    # downsampling factor (32). pad_to_divisible returns a new array, so the
    # caller's image is never modified.
    padded = pad_to_divisible(image, div=32)

    # Scale to [0, 1] and reorder HWC -> NCHW with a batch dimension of 1.
    normalized = padded.astype(np.float32) / 255.0
    tensor = torch.from_numpy(normalized).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(tensor)

    # Single-class output (1, 1, H, W) -> (H, W) probability map, thresholded once.
    probs = torch.sigmoid(logits)[0, 0].cpu().numpy()
    mask_bool = probs > threshold
    mask = mask_bool.astype(np.float32)

    # Alpha-blend red over the predicted wrinkle pixels, in float to avoid
    # uint8 overflow, then clip back to displayable uint8.
    overlay = padded.astype(np.float32)
    overlay[mask_bool] = (1 - alpha) * overlay[mask_bool] + alpha * red
    overlay = np.clip(overlay, 0, 255).astype(np.uint8)

    return padded, mask, overlay
|
|
|
# Web UI: one uploaded image in; padded input, binary mask and red overlay out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Face Image"),
    outputs=[
        gr.Image(label="Padded Input"),
        gr.Image(label="Predicted Mask"),
        gr.Image(label="Overlay on Image"),
    ],
    title="Wrinkle Segmentation",
    description="Upload a face image to see wrinkle regions detected",
)

# Start the server only when run as a script. Importing this module (e.g.
# Gradio's reload mode, which looks up `demo` itself) then does not launch
# a second server as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|