"""Streamlit app for building segmentation and area estimation on satellite images.

The upload page saves a satellite image (JPEG/PNG, or a GeoTIFF that is first
converted to an 8-bit RGB PNG), tiles it into square patches, runs each patch
through the ReUNet-CBAM model, stitches the predicted patch masks into one
full-size mask and logs the run to a CSV file. The result page shows the
original image, the predicted mask, and an overlay produced by
``process_and_overlay_image``.
"""
import csv
import os
import shutil
import sys
import time
import traceback
from datetime import datetime

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image

# NOTE(review): the Hugging Face function of this name lives in the
# ``huggingface_hub`` package; unless a local ``hf_hub_download`` module exists,
# this import fails. Confirm and switch the import if appropriate.
from hf_hub_download import hf_hub_download

# Presumably lets modules inside Utils/ and model/ import their siblings by bare
# name -- TODO confirm. Must run before the project imports below.
sys.path.append('Utils')
sys.path.append('model')

from model.CBAM.reunet_cbam import reunet_cbam
from model.transform import transforms
from model.unet import UNET
from Utils.area import pixel_to_sqft, process_and_overlay_image
from Utils.convert import read_pansharpened_rgb

# NOTE(review): huggingface_hub's ``hf_hub_download`` takes a required
# ``filename`` and returns the path of one cached file (``snapshot_download``
# is the whole-repo variant). Everything below treats BASE_DIR as a writable
# directory -- confirm what this call is meant to return.
BASE_DIR = hf_hub_download(repo_id="Pavan2k4/Building_area", repo_type="space")

# Define subdirectories
UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_images")
MASK_DIR = os.path.join(BASE_DIR, "generated_masks")
PATCHES_DIR = os.path.join(BASE_DIR, "patches")
PRED_PATCHES_DIR = os.path.join(BASE_DIR, "pred_patches")
CSV_LOG_PATH = os.path.join(BASE_DIR, "image_log.csv")

# Create directories
for directory in [UPLOAD_DIR, MASK_DIR, PATCHES_DIR, PRED_PATCHES_DIR]:
    os.makedirs(directory, exist_ok=True)


@st.cache_resource
def load_model():
    """Build the ReUNet-CBAM network and load its weights from ``latest.pth``.

    The checkpoint is loaded onto the CPU and must be a mapping with a
    ``'model_state_dict'`` entry. The model is returned in eval mode;
    ``st.cache_resource`` reuses the instance across Streamlit reruns.
    """
    net = reunet_cbam()
    checkpoint = torch.load('latest.pth', map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    return net


model = load_model()


def predict(image, net=None):
    """Run the segmentation network on a single transformed image tensor.

    Args:
        image: tensor produced by ``transforms``; a batch axis is prepended here.
            Presumably (C, H, W) -- TODO confirm against model.transform.
        net: model to run; defaults to the module-level ``model``.

    Returns:
        The network output with size-1 axes squeezed away, as a NumPy array.
    """
    net = model if net is None else net
    with torch.no_grad():
        output = net(image.unsqueeze(0))
    return output.squeeze().cpu().numpy()


def split_image(image, patch_size=512):
    """Tile an (H, W, C) array into non-overlapping ``patch_size`` squares.

    Tiles on the right/bottom edges are smaller when H or W is not a multiple
    of ``patch_size``. Returns ``(filename, patch)`` pairs in row-major order;
    the filename ``patch_{y}_{x}.png`` encodes the tile's top-left corner so
    ``merge`` can put it back.
    """
    h, w, _ = image.shape
    # NumPy slicing clips at the array edge, so no explicit min() is needed.
    return [
        (f"patch_{y}_{x}.png", image[y:y + patch_size, x:x + patch_size])
        for y in range(0, h, patch_size)
        for x in range(0, w, patch_size)
    ]


def _parse_patch_coords(filename):
    """Return the (y, x) offsets encoded in a ``..._{y}_{x}.png`` tile name."""
    parts = os.path.splitext(filename)[0].split("_")
    if len(parts) < 2 or not (parts[-2].isdigit() and parts[-1].isdigit()):
        raise ValueError(f"Invalid filename: {filename}")
    return int(parts[-2]), int(parts[-1])


def merge(patch_folder, dest_image='out.png', image_shape=None):
    """Stitch ``patch_{y}_{x}.png`` tiles from ``patch_folder`` into one image.

    Args:
        patch_folder: directory holding the tiles; every ``.png`` in it is used.
        dest_image: path the stitched image is also written to.
        image_shape: shape of the source image; its first two axes (H, W) size
            the output. Required despite the ``None`` default.

    Returns:
        The stitched ``uint8`` array of shape (H, W, 3).

    Raises:
        ValueError: if ``image_shape`` is missing, a tile name cannot be parsed
            or a tile cannot be read.
    """
    if image_shape is None:
        raise ValueError("image_shape is required to size the merged image")
    height, width = image_shape[0], image_shape[1]
    merged = np.zeros((height, width, 3), dtype=np.uint8)

    for filename in sorted(os.listdir(patch_folder)):
        if not filename.endswith(".png"):
            continue
        y, x = _parse_patch_coords(filename)
        patch = cv2.imread(os.path.join(patch_folder, filename))
        if patch is None:
            raise ValueError(f"Failed to read patch: {filename}")
        # Crop first so the clamp below can never produce a negative offset,
        # then shift tiles that would overrun the image back inside it.
        patch = patch[:height, :width]
        patch_height, patch_width = patch.shape[:2]
        x = min(x, width - patch_width)
        y = min(y, height - patch_height)
        merged[y:y + patch_height, x:x + patch_width, :] = patch

    cv2.imwrite(dest_image, merged)
    return merged


def _clear_directory(path):
    """Delete every regular file directly inside ``path``."""
    for name in os.listdir(path):
        file_path = os.path.join(path, name)
        if os.path.isfile(file_path):
            os.remove(file_path)


def process_large_image(model, image_path, patch_size=512):
    """Segment an image of any size by tiling, predicting and stitching.

    Each tile is converted BGR->RGB, passed through ``transforms`` and
    ``model``, thresholded at 0.5 into a 0/255 mask and written to
    ``PRED_PATCHES_DIR``; ``merge`` then stitches the tiles (also writing
    ``merged_mask.png`` to the working directory).

    Returns:
        The stitched mask as an (H, W, 3) ``uint8`` array.

    Raises:
        ValueError: if the image cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Failed to read image from {image_path}")
    h, w, _ = img.shape
    st.write(f"Processing image of size {w}x{h}")

    # merge() stitches *every* .png in this folder, so tiles left over from an
    # earlier (e.g. failed) run would otherwise be pasted into this mask.
    _clear_directory(PRED_PATCHES_DIR)

    for filename, patch in split_image(img, patch_size):
        patch_pil = Image.fromarray(cv2.cvtColor(patch, cv2.COLOR_BGR2RGB))
        prediction = predict(transforms(patch_pil), model)
        mask = (prediction > 0.5).astype(np.uint8) * 255
        # If the transform changed the tile size (not visible from here), bring
        # the mask back to the tile's true size so merge() places it correctly.
        if mask.ndim == 2 and mask.shape != patch.shape[:2]:
            mask = cv2.resize(mask, (patch.shape[1], patch.shape[0]),
                              interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(os.path.join(PRED_PATCHES_DIR, filename), mask)

    return merge(PRED_PATCHES_DIR, dest_image='merged_mask.png', image_shape=img.shape)


def log_image_details(image_id, image_filename, mask_filename):
    """Append one row describing a processed image to ``CSV_LOG_PATH``.

    A header row is written first when the log is missing or empty. ``S.No``
    is the existing row count (header included), i.e. the 1-based index of the
    new data row.
    """
    needs_header = (not os.path.exists(CSV_LOG_PATH)
                    or os.path.getsize(CSV_LOG_PATH) == 0)
    if needs_header:
        sno = 1
    else:
        # Count rows before reopening for append rather than reading the file
        # while it is open for writing.
        with open(CSV_LOG_PATH, mode='r', newline='') as f:
            sno = sum(1 for _ in csv.reader(f))

    now = datetime.now()
    with open(CSV_LOG_PATH, mode='a', newline='') as file:
        writer = csv.writer(file)
        if needs_header:
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
        writer.writerow([sno, now.strftime('%Y-%m-%d'), now.strftime('%H:%M:%S'),
                         image_id, image_filename, mask_filename])


def _init_upload_state():
    """Ensure the session-state keys shared by both pages exist."""
    for key, default in (('file_uploaded', False), ('filename', None), ('mask_filename', None)):
        if key not in st.session_state:
            st.session_state[key] = default


def _process_upload(image):
    """Save an uploaded file, convert GeoTIFFs, segment it and log the run."""
    bytes_data = image.getvalue()
    timestamp = int(time.time())
    file_extension = os.path.splitext(image.name)[1].lower()
    is_geotiff = file_extension in ['.tiff', '.tif']

    if is_geotiff:
        filename = f"image_{timestamp}.tif"
        converted_filename = f"image_{timestamp}_converted.png"
    else:
        # Non-TIFF uploads are stored under a .png name whatever their format.
        filename = f"image_{timestamp}.png"
        converted_filename = filename
    filepath = os.path.join(UPLOAD_DIR, filename)
    converted_filepath = os.path.join(UPLOAD_DIR, converted_filename)

    with open(filepath, "wb") as f:
        f.write(bytes_data)
    st.success(f"Image saved to {filepath}")

    if is_geotiff:
        st.info('Processing GeoTIFF image...')
        # Assumes read_pansharpened_rgb returns an 8-bit RGB array -- TODO confirm.
        rgb_image = read_pansharpened_rgb(filepath)
        cv2.imwrite(converted_filepath, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
        st.success(f'GeoTIFF converted to 8-bit image and saved as {converted_filename}')

    # For non-TIFF uploads converted_filepath is the saved file itself.
    with Image.open(converted_filepath) as img:
        st.image(img, caption='Uploaded Image', use_column_width=True)
    st.success(f'Image processed and saved as {converted_filename}')

    # Only the bare filename is stored; result_page joins it with UPLOAD_DIR.
    st.session_state.filename = converted_filename

    st.write("Processing image...")
    with st.spinner('Analyzing...'):
        full_mask = process_large_image(model, converted_filepath)

    mask_filename = f"mask_{timestamp}.png"
    cv2.imwrite(os.path.join(MASK_DIR, mask_filename), full_mask)
    st.session_state.mask_filename = mask_filename
    st.success("Image processed successfully")

    log_image_details(timestamp, converted_filename, mask_filename)
    st.session_state.file_uploaded = True


def _cleanup_pred_patches():
    """Remove temporary predicted tiles, reporting progress in the UI."""
    st.info('Cleaning up temporary files...')
    try:
        _clear_directory(PRED_PATCHES_DIR)
    except OSError as err:
        st.warning(f"Could not remove temporary files: {err}")
    else:
        st.success('Temporary files cleaned up')


def upload_page():
    """Render the upload page: save, convert, segment and log an uploaded image.

    Processing runs once per upload (guarded by ``session_state.file_uploaded``).
    On success ``session_state.filename`` / ``mask_filename`` hold the names of
    the saved image (in ``UPLOAD_DIR``) and mask (in ``MASK_DIR``), and a
    "View result" button switches to the result page.
    """
    _init_upload_state()

    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])
    if image is not None and not st.session_state.file_uploaded:
        try:
            _process_upload(image)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.error("Please check the logs for more details.")
            print(f"Error in upload_page: {str(e)}")  # This will appear in the Streamlit logs
            traceback.print_exc()
        finally:
            # Tiles are only needed to build the mask; remove them even when
            # processing failed so they cannot leak into the next run.
            _cleanup_pred_patches()

    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()


def _reset_to_upload():
    """Clear the stored upload state and rerun on the upload page."""
    st.session_state.page = 'upload'
    st.session_state.file_uploaded = False
    st.session_state.filename = None
    st.session_state.mask_filename = None
    st.rerun()


def _build_overlay(original_img_path, mask_path):
    """Return the area overlay for an image/mask pair, or None if either is unavailable."""
    if not (os.path.exists(original_img_path) and os.path.exists(mask_path)):
        return None
    original_np = cv2.imread(original_img_path)
    mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if original_np is None or mask_np is None:
        return None

    # Ensure mask is binary
    _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
    # Resize mask to match the image if necessary; nearest keeps it binary.
    if original_np.shape[:2] != mask_np.shape[:2]:
        mask_np = cv2.resize(mask_np, (original_np.shape[1], original_np.shape[0]),
                             interpolation=cv2.INTER_NEAREST)
    return process_and_overlay_image(original_np, mask_np, 'output.png')


def result_page():
    """Render the original image, the predicted mask and the area overlay."""
    st.title('Analysis Result')

    filename = st.session_state.get('filename')
    mask_filename = st.session_state.get('mask_filename')
    # upload_page initialises both keys to None, so a membership test alone
    # would not detect the "nothing processed yet" case.
    if filename is None or mask_filename is None:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            _reset_to_upload()
        return

    col1, col2 = st.columns(2)

    # Display original image
    original_img_path = os.path.join(UPLOAD_DIR, filename)
    if os.path.exists(original_img_path):
        with Image.open(original_img_path) as original_img:
            col1.image(original_img, caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    # Display predicted mask
    mask_path = os.path.join(MASK_DIR, mask_filename)
    if os.path.exists(mask_path):
        with Image.open(mask_path) as mask:
            col2.image(mask, caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")
    overlay_img = _build_overlay(original_img_path, mask_path)
    if overlay_img is None:
        st.error("Image or mask file not found for overlay.")
    else:
        st.image(overlay_img, caption='Overlay Image', use_column_width=True)

    if st.button('Back to Upload'):
        _reset_to_upload()


def main():
    """Entry point: route to the upload or result page based on session state."""
    st.title('Building area estimation')
    if 'page' not in st.session_state:
        st.session_state.page = 'upload'
    if st.session_state.page == 'upload':
        upload_page()
    elif st.session_state.page == 'result':
        result_page()


if __name__ == '__main__':
    main()