Rihem02 commited on
Commit
d728dec
·
1 Parent(s): d31fc04

Add application file

Browse files
Files changed (1) hide show
  1. wrinkles.py +78 -0
wrinkles.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import math
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import segmentation_models_pytorch as smp
8
+
9
+ def pad_to_divisible(img, div=32):
10
+ h, w, _ = img.shape
11
+ new_h = math.ceil(h / div) * div
12
+ new_w = math.ceil(w / div) * div
13
+ pad_bottom = new_h - h
14
+ pad_right = new_w - w
15
+ padded = cv2.copyMakeBorder(img, 0, pad_bottom, 0, pad_right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
16
+ return padded
17
+
18
+
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ print("Using device:", device)
21
+
22
+ model_path = "best_unet_model_complete.pth"
23
+ if os.path.exists(model_path):
24
+ loaded_model = torch.load(model_path, map_location=device)
25
+ loaded_model.eval()
26
+ print("Loaded complete model from", model_path)
27
+ else:
28
+ raise FileNotFoundError(f"Model file not found at {model_path}")
29
+
30
+ def predict(image):
31
+ """
32
+ Takes an input image (as a NumPy array), pads it, performs inference,
33
+ and returns:
34
+ - the padded input image,
35
+ - the predicted mask (grayscale), and
36
+ - the original image with a red overlay on the predicted regions.
37
+ """
38
+
39
+ if image.shape[2] == 4:
40
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
41
+
42
+
43
+ external_image_rgb = image.copy()
44
+
45
+ external_image_padded = pad_to_divisible(external_image_rgb, div=32)
46
+
47
+ external_image_norm = external_image_padded.astype(np.float32) / 255.0
48
+ external_tensor = torch.from_numpy(external_image_norm).permute(2, 0, 1)
49
+ external_tensor = external_tensor.unsqueeze(0).to(device)
50
+
51
+ with torch.no_grad():
52
+ output = loaded_model(external_tensor)
53
+ pred_mask = (torch.sigmoid(output) > 0.3).float()
54
+ pred_mask_np = pred_mask.cpu().squeeze().numpy()
55
+
56
+
57
+ overlay_image = external_image_padded.astype(np.float32)
58
+ mask_bool = pred_mask_np > 0.5
59
+ red_color = np.array([255, 0, 0], dtype=np.float32)
60
+ alpha = 0.5
61
+ overlay_image[mask_bool] = (1 - alpha) * overlay_image[mask_bool] + alpha * red_color
62
+ overlay_image = np.clip(overlay_image, 0, 255).astype(np.uint8)
63
+
64
+ return external_image_padded, pred_mask_np, overlay_image
65
+
66
+ demo = gr.Interface(
67
+ fn=predict,
68
+ inputs=gr.Image(type="numpy", label="Upload Image"),
69
+ outputs=[
70
+ gr.Image(label="Padded Image"),
71
+ gr.Image(label="Predicted Mask"),
72
+ gr.Image(label="Overlay (Predicted Regions)")
73
+ ],
74
+ title="Wrinkle Segmentation",
75
+ description="Upload an image to see wrinkle segmentation. The app displays the padded image, the predicted mask, and an overlay of the predicted regions in red."
76
+ )
77
+
78
+ demo.launch()