import streamlit as st
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
from TranSalNet_Res import TranSalNet

# Load the model and set the device
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
model.to(device)
model.eval()

# Define the Streamlit app
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    # Convert to RGB so images with an alpha channel (e.g. RGBA PNGs) don't break the pipeline
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Run detection when the user clicks the button
    if st.button('Detect Saliency'):
        # Create a blue background image with the same dimensions as the original image
        # (OpenCV works in BGR, so blue is (255, 0, 0))
        blue_background = np.zeros_like(np.array(image))
        blue_background[:] = (255, 0, 0)

        # Preprocess the image: resize to the model's input size, scale to [0, 1],
        # and reorder to a (1, C, H, W) float tensor
        img = image.resize((384, 288))
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img).type(torch.FloatTensor).to(device)

        # Get the saliency prediction
        with torch.no_grad():
            pred_saliency = model(img)

        # Convert the result back to a PIL image, then to a uint8 array
        toPIL = transforms.ToPILImage()
        pic = toPIL(pred_saliency.squeeze())
        saliency_map = np.array(pic, dtype=np.uint8)

        # Colorize the grayscale prediction
        colorized_img = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)

        # Ensure the colorized image and saliency map match the original image size
        original_img = np.array(image)
        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))
        saliency_map = cv2.resize(saliency_map, (original_img.shape[1], original_img.shape[0]))

        # Binarize the saliency map so salient regions can be found as contours
        # (the threshold of 127 is a simple default; tune it for your data)
        _, binary_map = cv2.threshold(saliency_map, 127, 255, cv2.THRESH_BINARY)

        # Create an empty label map for ranking the salient regions by area
        label_map = np.zeros_like(colorized_img)

        # Find salient regions, rank them by area, and draw a number at each centroid
        font = cv2.FONT_HERSHEY_SIMPLEX
        contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
        for i, contour in enumerate(contours):
            M = cv2.moments(contour)
            if M["m00"] == 0:
                continue
            center_x = int(M["m10"] / M["m00"])
            center_y = int(M["m01"] / M["m00"])
            cv2.putText(label_map, str(i + 1), (center_x, center_y), font, 1, (255, 0, 0), 2, cv2.LINE_AA)

        # Blend the colorized image with the blue background, then overlay the labels
        alpha = 0.7  # Adjust the alpha value to control blending strength
        blended_img = cv2.addWeighted(blue_background, 1 - alpha, colorized_img, alpha, 0)
        blended_img = cv2.add(blended_img, label_map)

        # Display the final result (the blended image is in BGR order)
        st.image(blended_img, caption='Blended Image with Labels', use_column_width=True, channels='BGR')

        # Save the final result
        cv2.imwrite('example/result15.png', blended_img)
        st.success('Saliency detection complete. Result saved as "example/result15.png"')
        st.write('Finished, check the result at: example/result15.png')
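
# A minimal way to launch the app, assuming this script is saved as app.py (the
# filename is an assumption) and run from the TranSalNet repository root so that
# TranSalNet_Res.py and pretrained_models/TranSalNet_Res.pth resolve:
#
#   streamlit run app.py
#
# The result is written to example/result15.png, so the example/ directory must exist.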