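"""Streamlit demo: visual saliency detection with TranSalNet (ResNet variant).

Upload an image, run TranSalNet on CPU, blend the predicted saliency map over
the photo, and number each detected salient region.
"""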
import streamlit as st
import cv2
import numpy as np
import torch
from PIL import Image
from TranSalNet_Res import TranSalNet
device = torch.device('cpu')
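# A minimal GPU-aware alternative, if CUDA happens to be available:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')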
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
model.to(device)
model.eval()
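# Optional: wrapping the model setup above in a function decorated with
# st.cache_resource would stop Streamlit from reloading the weights on every
# rerun; kept inline here for simplicity.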
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_image:
    # Force 3-channel RGB so the transpose and blending below always see the
    # same layout (uploads may be RGBA PNGs or grayscale).
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)
    if st.button('Detect Saliency'):
        # TranSalNet_Res expects a 384x288 (width x height) RGB input scaled
        # to [0, 1], laid out as a single NCHW float batch.
        img = image.resize((384, 288))
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)  # HWC -> NCHW
        img = torch.from_numpy(img).type(torch.FloatTensor).to(device)
        with torch.no_grad():
            pred_saliency = model(img)
        # The network returns a 1x1xHxW map in [0, 1]; drop the batch and
        # channel dims and convert to an 8-bit grayscale image.
        saliency = pred_saliency.squeeze().cpu().numpy()
        saliency_8bit = np.uint8(saliency * 255)

        # Upsample the map to the original resolution so contour coordinates
        # computed below line up with the full-size blended image.
        original_img = np.array(image)
        saliency_8bit = cv2.resize(saliency_8bit, (original_img.shape[1], original_img.shape[0]))

        # Colorize the map; OpenCV colormaps are BGR, so convert to RGB to
        # match the PIL-derived original before blending.
        colorized_img = cv2.applyColorMap(saliency_8bit, cv2.COLORMAP_OCEAN)
        colorized_img = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2RGB)
        alpha = 0.7
        blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)

        # Binarize (the 128 cutoff is a tunable choice) and dilate so nearby
        # salient blobs merge, then find one external contour per region.
        _, binary = cv2.threshold(saliency_8bit, 128, 255, cv2.THRESH_BINARY)
        kernel = np.ones((5, 5), np.uint8)
        dilated = cv2.dilate(binary, kernel, iterations=1)
        contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        font = cv2.FONT_HERSHEY_SIMPLEX
        for label, contour in enumerate(contours, start=1):
            # Anchor each number at the contour point nearest the center of
            # the region's bounding box, so labels sit on the region outline.
            x, y, w, h = cv2.boundingRect(contour)
            center_x = x + w // 2
            center_y = y + h // 2
            distances = np.sqrt((contour[:, 0, 0] - center_x) ** 2 + (contour[:, 0, 1] - center_y) ** 2)
            closest_point = tuple(contour[np.argmin(distances)][0])
            # Red in RGB order, since blended_img is RGB at this point.
            cv2.putText(blended_img, str(label), closest_point, font, 1, (255, 0, 0), 3, cv2.LINE_AA)
        st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)
        # cv2.imwrite expects BGR channel order; convert back before saving.
        cv2.imwrite('example/result15.png', cv2.cvtColor(blended_img, cv2.COLOR_RGB2BGR))
        st.success('Saliency detection complete. Result saved as "example/result15.png".')
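# Launch with `streamlit run <this_script>.py` (the script name is not fixed
# by this file; use whatever it is saved as).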