"""Streamlit app: upload an image and display a colorized TranSalNet saliency map."""

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms

from TranSalNet_Res import TranSalNet  # Make sure TranSalNet is accessible from your Streamlit app

# Checkpoint loaded into the network at start-up.
MODEL_WEIGHTS_PATH = 'pretrained_models/TranSalNet_Res.pth'

# (width, height) handed to PIL's resize before the image is fed to the model.
INPUT_SIZE = (384, 288)

# Inference runs on the CPU.
device = torch.device('cpu')

# Streamlit re-executes this script on every widget interaction. Where the
# installed Streamlit provides `cache_resource`, use it so the weights are read
# from disk only once; otherwise fall back to loading on every run.
_cache_resource = getattr(st, "cache_resource", lambda fn: fn)


@_cache_resource
def _load_model():
    """Build TranSalNet, load the checkpoint and return it in eval mode on `device`."""
    net = TranSalNet()
    # NOTE: torch.load unpickles the file; only point this at a trusted checkpoint.
    net.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=device))
    net.eval()  # Set the model to evaluation mode
    return net.to(device)


def _predict_saliency(image):
    """Run the model on a PIL image and return the prediction as a PIL image.

    The image is forced to 3-channel RGB first: uploads may be greyscale,
    palette-based or carry an alpha channel (PNG), and the channel-first
    transpose below requires an (H, W, 3) array.
    """
    rgb = image.convert('RGB').resize(INPUT_SIZE)
    scaled = np.asarray(rgb) / 255.
    channels_first = np.transpose(scaled, (2, 0, 1))
    batch = torch.from_numpy(channels_first).unsqueeze(0).float().to(device)

    with torch.no_grad():
        prediction = model(batch)

    # Presumably the squeezed prediction is a 2-D map, which ToPILImage turns
    # into a single-channel image -- confirm against TranSalNet's output shape.
    return transforms.ToPILImage()(prediction.squeeze())


def _colorize(saliency, size):
    """Apply the JET colormap to `saliency` and resize it to `size` (width, height).

    OpenCV returns the colormapped image in BGR channel order, whereas
    `st.image` interprets arrays as RGB by default, so the channels are
    swapped here; otherwise red and blue would appear exchanged.
    """
    heat_bgr = cv2.applyColorMap(np.uint8(saliency), cv2.COLORMAP_JET)
    heat_rgb = cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)
    return cv2.resize(heat_rgb, size)


# Load the model and set the device
model = _load_model()

# Define Streamlit app
st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Check if the user clicks a button
    if st.button('Detect Saliency'):
        pic = _predict_saliency(image)

        # Ensure the colorized image has the same dimensions as the original
        # image; PIL's `size` is already (width, height), as cv2.resize expects.
        colorized_img = _colorize(pic, image.size)

        # You can add more post-processing here if needed

        # Display the final result
        st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)
        st.write('Finished!')