import os

import streamlit as st
import cv2
import numpy as np
import torch
from PIL import Image
from TranSalNet_Res import TranSalNet

# Load TranSalNet on CPU and switch to inference mode.
device = torch.device('cpu')
model = TranSalNet()
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth',
                                 map_location=device))
model.to(device)
model.eval()


def count_and_label_red_patches(heatmap, base_image, threshold=200):
    """Find high-saliency (red) regions in the BGR heatmap, then mark, number,
    and interconnect their centroids on a copy of the original image."""
    # Channel 2 of the BGR heatmap is red; threshold it to isolate hot spots.
    red_mask = heatmap[:, :, 2] > threshold
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(
        red_mask.astype(np.uint8), connectivity=8)
    num_red_patches = num_labels - 1  # label 0 is the background

    annotated = np.array(base_image)
    for i in range(1, num_labels):
        cx = int(stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] / 2)
        cy = int(stats[i, cv2.CC_STAT_TOP] + stats[i, cv2.CC_STAT_HEIGHT] / 2)

        radius = 20               # adjust to change the marker size
        circle_color = (0, 0, 0)  # black marker; adjust to change the color
        cv2.circle(annotated, (cx, cy), radius, circle_color, -1)

        # Draw a line from this patch to every later patch.
        for j in range(i + 1, num_labels):
            cx_j = int(stats[j, cv2.CC_STAT_LEFT] + stats[j, cv2.CC_STAT_WIDTH] / 2)
            cy_j = int(stats[j, cv2.CC_STAT_TOP] + stats[j, cv2.CC_STAT_HEIGHT] / 2)
            line_color = (0, 0, 0)  # adjust to change the line color
            cv2.line(annotated, (cx, cy), (cx_j, cy_j), line_color, 2)

        # Number the patch at its centroid.
        cv2.putText(annotated, str(i), (cx - 10, cy + 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

    return annotated, num_red_patches


st.title('Saliency Detection App')
st.write('Upload an image for saliency detection:')
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_image:
    image = Image.open(uploaded_image).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Detect Saliency'):
        # Preprocess: resize to the 384x288 input TranSalNet expects, convert
        # to BGR, scale to [0, 1], and shape as a 1x3xHxW float tensor.
        img = image.resize((384, 288))
        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        img = np.array(img) / 255.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        img = torch.from_numpy(img).float().to(device)

        # Predict saliency and render it as a JET-colormapped heatmap at the
        # original image resolution.
        with torch.no_grad():
            pred_saliency = model(img).squeeze().cpu().numpy()
        heatmap = (pred_saliency * 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap = cv2.resize(heatmap, (image.width, image.height))

        # Boost local contrast in the blue channel with CLAHE before blending.
        enhanced_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        b, g, r = cv2.split(enhanced_image)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        b_enhanced = clahe.apply(b)
        enhanced_image = cv2.merge((b_enhanced, g, r))

        # Overlay the heatmap on the contrast-enhanced image.
        alpha = 0.7
        blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)

        original_image, num_red_patches = count_and_label_red_patches(heatmap, image)
        st.image(original_image,
                 caption=f'Image with {num_red_patches} Red Patches',
                 use_column_width=True, channels='RGB')
        st.image(blended_img, caption='Blended Image',
                 use_column_width=True, channels='BGR')

        # Save the blended result, creating the output directory if needed.
        os.makedirs('example', exist_ok=True)
        cv2.imwrite('example/result15.png', blended_img)
        st.success('Saliency detection complete. Result saved as "example/result15.png".')
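
# --- How to run (a sketch, under assumptions) ---
# The commands below are an illustration, not part of the app itself. They
# assume this script is saved as app.py in the TranSalNet repository root,
# next to TranSalNet_Res.py and the pretrained weights at
# pretrained_models/TranSalNet_Res.pth; the name app.py is a hypothetical
# choice, and the package list is the usual PyPI set rather than pinned
# versions:
#
#   pip install streamlit opencv-python torch torchvision pillow
#   streamlit run app.py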