"""Streamlit app: upload a satellite image, segment buildings, show area overlay."""
import streamlit as st
import sys
import os
import shutil
import time
import traceback
from datetime import datetime
import csv
import cv2
import numpy as np
from PIL import Image
import torch

# Adjust import paths as needed
sys.path.append('Utils')
sys.path.append('model')

from model.CBAM.reunet_cbam import reunet_cbam
from model.transform import transforms
from model.unet import UNET
from Utils.area import pixel_to_sqft, process_and_overlay_image
from split_merge import split, merge
from Utils.convert import read_pansharpened_rgb

# Define base directory for Hugging Face Spaces
BASE_DIR = "/home/user"

# Define subdirectories
UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_images")
MASK_DIR = os.path.join(BASE_DIR, "generated_masks")
PATCHES_DIR = os.path.join(BASE_DIR, "patches")
PRED_PATCHES_DIR = os.path.join(BASE_DIR, "pred_patches")
CSV_LOG_PATH = os.path.join(BASE_DIR, "image_log.csv")

# Create directories. Streamlit re-executes this script on every
# interaction, so these are (re)created on each run.
for directory in [UPLOAD_DIR, MASK_DIR, PATCHES_DIR, PRED_PATCHES_DIR]:
    os.makedirs(directory, exist_ok=True)


@st.cache_resource
def load_model():
    """Build the ReUNet-CBAM model and load its weights from ``latest.pth``.

    Cached with ``st.cache_resource`` so the checkpoint is loaded once per
    server process instead of on every script rerun.

    Returns:
        The model in eval mode, with weights mapped to CPU.
    """
    model = reunet_cbam()
    checkpoint = torch.load('latest.pth', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


model = load_model()


def predict(image):
    """Run the model on a single transformed image.

    Args:
        image: The output of ``transforms`` for one image. It is presumably a
            (C, H, W) tensor; a batch dimension is added here.

    Returns:
        The model output with all size-1 dimensions squeezed, as a NumPy array.
    """
    with torch.no_grad():
        output = model(image.unsqueeze(0))
    return output.squeeze().cpu().numpy()


def _binarize(prediction):
    """Threshold model output at 0.5 into a uint8 mask with values {0, 255}."""
    return (prediction > 0.5).astype(np.uint8) * 255


def _reset_dir(path):
    """Delete ``path`` if it exists and recreate it empty.

    Used on the shared patch scratch directories, so that patches left over
    from a previous or failed run are never fed into the next merge.
    """
    shutil.rmtree(path, ignore_errors=True)
    os.makedirs(path, exist_ok=True)


def log_image_details(image_id, image_filename, mask_filename):
    """Append one entry to the CSV log at ``CSV_LOG_PATH``.

    A header row (S.No, Date, Time, Image ID, Image Filename, Mask Filename)
    is written first when the log is missing or empty. S.No starts at 1 and
    increases by one with each appended entry.
    """
    # Count existing rows *before* opening for append, so we never read a
    # file that has buffered, unflushed writes. The header occupies one row,
    # so a non-empty log's row count is exactly the next serial number.
    row_count = 0
    if os.path.exists(CSV_LOG_PATH):
        with open(CSV_LOG_PATH, mode='r', newline='') as f:
            row_count = sum(1 for _ in csv.reader(f))
    sno = max(row_count, 1)

    now = datetime.now()
    with open(CSV_LOG_PATH, mode='a', newline='') as file:
        writer = csv.writer(file)
        if row_count == 0:
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
        writer.writerow([
            sno,
            now.strftime('%Y-%m-%d'),
            now.strftime('%H:%M:%S'),
            image_id,
            image_filename,
            mask_filename,
        ])


def _init_upload_state():
    """Ensure the session keys shared by the upload and result pages exist."""
    defaults = (('file_uploaded', False), ('filename', None), ('mask_filename', None))
    for key, default in defaults:
        if key not in st.session_state:
            st.session_state[key] = default


def _predict_tiled(image_path, image_shape, timestamp):
    """Segment a large image patch-by-patch and merge the result into one mask.

    The image is split with ``split`` into 512px patches. The code reads them
    back from ``PATCHES_DIR``, predicts each one and writes the patch masks to
    ``PRED_PATCHES_DIR``. It then merges them into ``MASK_DIR``. Both scratch
    directories are emptied before splitting and again afterwards, even on
    failure.

    Args:
        image_path: Path of the (converted) image to split.
        image_shape: Shape of the full image array, forwarded to ``merge``.
        timestamp: Used to name the merged mask file.

    Returns:
        The merged mask's filename (relative to ``MASK_DIR``).
    """
    _reset_dir(PATCHES_DIR)
    _reset_dir(PRED_PATCHES_DIR)
    try:
        # Split image into patches
        split(image_path, patch_size=512)

        # Display buffer while analyzing
        with st.spinner('Analyzing...'):
            for patch_filename in sorted(os.listdir(PATCHES_DIR)):
                if not patch_filename.endswith(".png"):
                    continue
                patch_path = os.path.join(PATCHES_DIR, patch_filename)
                # Close the patch file promptly so it can be deleted below.
                with Image.open(patch_path) as patch_img:
                    patch_tr_img = transforms(patch_img)
                mask = _binarize(predict(patch_tr_img))
                mask_filepath = os.path.join(PRED_PATCHES_DIR, f"mask_{patch_filename}")
                Image.fromarray(mask).save(mask_filepath)

        # Merge predicted patches
        merged_mask_filename = f"mask_{timestamp}.png"
        merged_mask_path = os.path.join(MASK_DIR, merged_mask_filename)
        merge(PRED_PATCHES_DIR, merged_mask_path, image_shape)
    finally:
        # Clean up temporary patch files whether or not prediction succeeded.
        st.info('Cleaning up temporary files...')
        _reset_dir(PATCHES_DIR)
        _reset_dir(PRED_PATCHES_DIR)
    st.success('Temporary files cleaned up')
    return merged_mask_filename


def _predict_whole(img, timestamp):
    """Segment an image in a single forward pass and save the mask to ``MASK_DIR``.

    The transformed input is kept in ``st.session_state.tr_img``.

    Returns:
        The mask's filename (relative to ``MASK_DIR``).
    """
    st.session_state.tr_img = transforms(img)
    mask = _binarize(predict(st.session_state.tr_img))
    mask_filename = f"mask_{timestamp}.png"
    Image.fromarray(mask).save(os.path.join(MASK_DIR, mask_filename))
    return mask_filename


def _process_upload(image):
    """Save the uploaded file, convert GeoTIFFs, run segmentation and record results.

    On success, sets ``filename``, ``mask_filename`` and ``file_uploaded``
    in ``st.session_state``. Exceptions propagate to the caller.
    """
    bytes_data = image.getvalue()
    timestamp = int(time.time())
    file_extension = os.path.splitext(image.name)[1].lower()
    is_tiff = file_extension in ['.tiff', '.tif']

    if is_tiff:
        filename = f"image_{timestamp}.tif"
        converted_filename = f"image_{timestamp}_converted.png"
    else:
        # Non-TIFF uploads are used directly; the saved file is also the
        # "converted" file.
        filename = f"image_{timestamp}.png"
        converted_filename = filename

    filepath = os.path.join(UPLOAD_DIR, filename)
    converted_filepath = os.path.join(UPLOAD_DIR, converted_filename)

    with open(filepath, "wb") as f:
        f.write(bytes_data)
    st.success(f"Image saved to {filepath}")

    # Convert GeoTIFF to an 8-bit RGB image the rest of the pipeline can read.
    if is_tiff:
        st.info('Processing GeoTIFF image...')
        rgb_image = read_pansharpened_rgb(filepath)
        if not cv2.imwrite(converted_filepath, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)):
            raise IOError(f"Failed to write converted image to {converted_filepath}")
        st.success(f'GeoTIFF converted to 8-bit image and saved as {converted_filename}')

    img = Image.open(converted_filepath)
    st.image(img, caption='Uploaded Image', use_column_width=True)
    st.success(f'Image processed and saved as {converted_filename}')

    # Store the filename (relative to UPLOAD_DIR) of the converted image
    st.session_state.filename = converted_filename

    img_array = np.array(img)
    # Images larger than 650px on either side are tiled before inference.
    if img_array.shape[0] > 650 or img_array.shape[1] > 650:
        mask_filename = _predict_tiled(converted_filepath, img_array.shape, timestamp)
    else:
        mask_filename = _predict_whole(img, timestamp)

    st.session_state.mask_filename = mask_filename
    st.session_state.file_uploaded = True


def upload_page():
    """Render the upload page: accept an image, segment it once, offer the result view."""
    _init_upload_state()

    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])

    # Process only once per upload; file_uploaded is reset from the result page.
    if image is not None and not st.session_state.file_uploaded:
        try:
            _process_upload(image)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.error("Please check the logs for more details.")
            print(f"Error in upload_page: {str(e)}")  # This will appear in the Streamlit logs
            traceback.print_exc()

    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()


def _reset_to_upload():
    """Clear per-upload session state and rerun the app on the upload page."""
    st.session_state.page = 'upload'
    st.session_state.file_uploaded = False
    st.session_state.filename = None
    st.session_state.mask_filename = None
    st.rerun()


def result_page():
    """Render the original image, the predicted mask, and the area overlay."""
    st.title('Analysis Result')

    # The upload page creates these keys as None, so check values, not just presence.
    filename = st.session_state.get('filename')
    mask_filename = st.session_state.get('mask_filename')
    if filename is None or mask_filename is None:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            _reset_to_upload()
        return

    col1, col2 = st.columns(2)

    # Display original image
    original_img_path = os.path.join(UPLOAD_DIR, filename)
    if os.path.exists(original_img_path):
        original_img = Image.open(original_img_path)
        col1.image(original_img, caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    # Display predicted mask
    mask_path = os.path.join(MASK_DIR, mask_filename)
    if os.path.exists(mask_path):
        mask = Image.open(mask_path)
        col2.image(mask, caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")

    # cv2.imread returns None (rather than raising) for unreadable files.
    original_np = cv2.imread(original_img_path) if os.path.exists(original_img_path) else None
    mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) if os.path.exists(mask_path) else None

    if original_np is not None and mask_np is not None:
        # Ensure mask is binary
        _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)

        # Resize mask to match original image size if necessary. Nearest
        # neighbour keeps the mask binary; linear interpolation would not.
        if original_np.shape[:2] != mask_np.shape[:2]:
            mask_np = cv2.resize(
                mask_np,
                (original_np.shape[1], original_np.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

        # Process and overlay image
        overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
        st.image(overlay_img, caption='Overlay Image', use_column_width=True)
    else:
        st.error("Image or mask file not found for overlay.")

    if st.button('Back to Upload'):
        # Reset (not just delete) the scratch dirs so they always exist.
        _reset_dir(PATCHES_DIR)
        _reset_dir(PRED_PATCHES_DIR)
        _reset_to_upload()


def main():
    """Entry point: route between the upload and result pages via session state."""
    st.title('Building area estimation')

    if 'page' not in st.session_state:
        st.session_state.page = 'upload'

    if st.session_state.page == 'upload':
        upload_page()
    elif st.session_state.page == 'result':
        result_page()


if __name__ == '__main__':
    main()