File size: 2,020 Bytes
d728dec f386c64 f64a081 d728dec f386c64 54ba009 d728dec f64a081 d728dec f64a081 d728dec f64a081 d728dec f64a081 d728dec f64a081 d728dec f64a081 d728dec f64a081 d728dec f64a081 d728dec f64a081 d728dec dc53c6b f386c64 d728dec f64a081 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import gradio as gr
import math
import os
import cv2
import numpy as np
import torch
import segmentation_models_pytorch as smp
def pad_to_divisible(img, div=32):
    """Zero-pad an image on the bottom and right so both sides divide by ``div``.

    Args:
        img: Array whose first two axes are height and width. Any trailing
            axes (for example a channel axis) are left unpadded, so both
            (H, W) and (H, W, C) inputs are accepted.
        div: Positive integer that the padded height and width must be
            multiples of.

    Returns:
        A new array with the same dtype as ``img``. The original pixels sit
        in the top-left corner; the added rows and columns are zeros.

    Raises:
        ValueError: If ``div`` is not positive or ``img`` has fewer than two
            dimensions.
    """
    if div <= 0:
        raise ValueError(f"div must be a positive integer, got {div!r}")
    if img.ndim < 2:
        raise ValueError(
            f"img must have at least 2 dimensions, got shape {img.shape!r}"
        )

    h, w = img.shape[:2]
    # (-n) % div is the distance from n up to the next multiple of div
    # (zero when n is already aligned).
    pad_bottom = -h % div
    pad_right = -w % div

    # Pad only the two spatial axes; leave channel/extra axes untouched.
    pad_width = [(0, pad_bottom), (0, pad_right)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width, mode="constant", constant_values=0)
# Inference is pinned to the CPU; the checkpoint is remapped accordingly below.
device = torch.device("cpu")
print("Using device:", device)

# Checkpoint path, resolved relative to the current working directory.
model_path = "best_unet.pth"

# Fail fast with a clear message if the checkpoint is missing.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"❌ Model file not found at: {model_path}")

# U-Net with a ResNet-34 encoder, 3 input channels and a single output channel.
# encoder_weights=None: no pretrained encoder weights are requested here; all
# weights come from the state dict loaded on the next statement.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=1,
)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()  # switch to inference mode before serving predictions
print("✅ Model loaded successfully.")
def predict(image):
    """Run the segmentation model on an image and build a red overlay.

    Args:
        image: NumPy image from the Gradio input, shaped (H, W), (H, W, 3)
            or (H, W, 4). Pixel values are assumed to lie in 0-255, since
            they are scaled by 1/255 before inference.

    Returns:
        A tuple ``(padded, mask, overlay)``:
            padded: the RGB input zero-padded so both sides are multiples
                of 32.
            mask: float32 array of 0.0/1.0 values at the padded size.
            overlay: uint8 copy of ``padded`` with masked pixels blended
                50/50 with pure red.

    Raises:
        gr.Error: If no image was provided.
    """
    threshold = 0.3  # sigmoid probability above which a pixel is marked
    alpha = 0.5      # blend weight of the red highlight on marked pixels

    # Gradio passes None when the form is submitted without an upload.
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # The model takes 3 input channels: expand grayscale, drop alpha.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    # Presumably the network needs sides divisible by 32 — TODO confirm.
    padded = pad_to_divisible(image, div=32)

    # HWC in [0, 255] -> NCHW float32 in [0, 1].
    normalized = padded.astype(np.float32) / 255.0
    tensor = torch.from_numpy(normalized).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(tensor)

    # Threshold once; the boolean mask drives both the returned float mask
    # and the overlay blending.
    mask_bool = (torch.sigmoid(output) > threshold).cpu().squeeze().numpy()
    mask = mask_bool.astype(np.float32)

    red = np.array([255, 0, 0], dtype=np.float32)
    overlay = padded.astype(np.float32)
    overlay[mask_bool] = (1 - alpha) * overlay[mask_bool] + alpha * red
    overlay = np.clip(overlay, 0, 255).astype(np.uint8)

    return padded, mask, overlay
# Web UI: one image upload in, three image panels out (padded input,
# predicted mask, overlay), all produced by predict().
demo = gr.Interface(
    title="Wrinkle Segmentation",
    description="Upload a face image to see wrinkle regions detected",
    fn=predict,
    inputs=gr.Image(label="Upload Face Image", type="numpy"),
    outputs=[
        gr.Image(label=caption)
        for caption in ("Padded Input", "Predicted Mask", "Overlay on Image")
    ],
)

# Start the Gradio server.
demo.launch()
|