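"""Streamlit app: predict a visual-saliency map for an uploaded image with TranSalNet,
overlay the colorized map on the original photo, and number the detected salient regions."""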
import streamlit as st
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
from TranSalNet_Res import TranSalNet

# Load the pretrained TranSalNet (ResNet backbone) weights and put the model in inference mode on CPU
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=device))
model.to(device)
model.eval()

# Streamlit UI: title, prompt, and image uploader
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image is not None:
    # Ensure a 3-channel RGB image (uploaded PNGs may carry an alpha channel or be grayscale)
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        # Resize to the 384x288 input expected by TranSalNet and normalize pixel values to [0, 1]
        img = image.resize((384, 288))
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img).type(torch.FloatTensor).to(device)

        # Run inference without tracking gradients
        with torch.no_grad():
            pred_saliency = model(img)

        # Convert the predicted saliency map to an 8-bit image and apply a colormap.
        # applyColorMap returns BGR, so convert to RGB to match the RGB original before blending.
        toPIL = transforms.ToPILImage()
        pic = toPIL(pred_saliency.squeeze())
        colorized_img = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_OCEAN)
        colorized_img = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2RGB)

        # Upscale the colorized saliency map to the original resolution and alpha-blend it over the photo
        original_img = np.array(image)
        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))

        alpha = 0.7
        blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)

        # Binarize the saliency map, resize it to the original resolution so contour
        # coordinates line up with the blended image, and dilate to merge nearby blobs
        saliency_8bit = np.uint8(pred_saliency.squeeze().cpu().numpy() * 255)
        saliency_8bit = cv2.resize(saliency_8bit, (original_img.shape[1], original_img.shape[0]))
        _, saliency_bin = cv2.threshold(saliency_8bit, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        kernel = np.ones((5, 5), np.uint8)
        dilated = cv2.dilate(saliency_bin, kernel, iterations=1)

        # Find the outer contours of the dilated salient regions
        contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
        # Number each salient region, placing its label on the contour point nearest the bounding-box centre
        font = cv2.FONT_HERSHEY_SIMPLEX
        label = 1
        for contour in contours:
            # Get bounding box for contour
            x, y, w, h = cv2.boundingRect(contour)
    
            # Calculate center of bounding box
            center_x = x + w // 2
            center_y = y + h // 2
    
            # Find point on contour closest to center of bounding box
            distances = np.sqrt((contour[:,0,0] - center_x)**2 + (contour[:,0,1] - center_y)**2)
            min_index = np.argmin(distances)
            closest_point = tuple(contour[min_index][0])
    
            # Place label at closest point on contour
            cv2.putText(blended_img, str(label), closest_point, font, 1, (255, 0, 0), 3, cv2.LINE_AA)  # red label (RGB)
    
            label += 1


        st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)

        # Save the result; convert RGB back to BGR for OpenCV before writing the PNG
        cv2.imwrite('example/result15.png', cv2.cvtColor(blended_img, cv2.COLOR_RGB2BGR))
        st.success('Saliency detection complete. Result saved as "example/result15.png".')