"""Streamlit app for building segmentation and building-area estimation.

Uploaded satellite images are saved under ``data/``, segmented with a
``reunet_cbam`` model (images larger than ``MAX_DIRECT_SIZE`` are split into
patches first), and the predicted mask is shown next to an overlay produced by
``process_and_overlay_image``. Every processed upload is appended to a CSV log.
"""

import csv
import hashlib
import os
import shutil
import sys
import time
from datetime import datetime

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from PIL import Image

# Make modules inside Utils/ and model/ importable by their bare names too.
sys.path.append('Utils')
sys.path.append('model')

from model.CBAM.reunet_cbam import reunet_cbam
from model.transform import transforms
from model.unet import UNET
from Utils.area import pixel_to_sqft, process_and_overlay_image
from Utils.convert import convert_gtiff_to_8bit
from Utils.split_merge import split, merge

# Define directories
UPLOAD_DIR = "data/uploaded_images/"
MASK_DIR = "data/generated_masks/"
PATCHES_DIR = 'data/Patches/'
PRED_PATCHES_DIR = 'data/Patch_pred/'
CSV_LOG_PATH = "image_log.csv"

# Images whose height or width exceeds this many pixels are segmented patch by patch.
MAX_DIRECT_SIZE = 650
# Side length passed to split() when tiling large images.
PATCH_SIZE = 256
# Model outputs strictly above this value become foreground (255) in the mask.
MASK_THRESHOLD = 0.5

# Create directories
for directory in [UPLOAD_DIR, MASK_DIR, PATCHES_DIR, PRED_PATCHES_DIR]:
    os.makedirs(directory, exist_ok=True)


@st.cache_resource
def load_model(checkpoint_path='latest.pth'):
    """Build ``reunet_cbam``, load its weights on the CPU and switch it to eval mode.

    Streamlit re-executes this script on every widget interaction. Caching the
    model as a resource means the checkpoint is read once per server process
    instead of on every rerun.

    NOTE: ``torch.load`` unpickles the checkpoint, so only point this at
    trusted files.
    """
    net = reunet_cbam()
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    return net


# Load model
model = load_model()


def predict(image):
    """Run the model on one transformed image and return the output as a NumPy array.

    ``image`` is the tensor produced by ``transforms``. A batch dimension is added
    before the forward pass, and all singleton dimensions are squeezed out of the
    result. Presumably the result is a 2-D per-pixel score map; TODO confirm
    against the model head.
    """
    with torch.no_grad():
        output = model(image.unsqueeze(0))
    return output.squeeze().cpu().numpy()


def _binary_mask(pil_image):
    """Segment ``pil_image`` and return a uint8 mask whose values are 0 or 255."""
    prediction = predict(transforms(pil_image))
    return (prediction > MASK_THRESHOLD).astype(np.uint8) * 255


def log_image_details(image_id, image_filename, mask_filename):
    """Append one row describing a processed upload to ``CSV_LOG_PATH``.

    A header row is written first when the log is missing or empty. ``S.No``
    equals the number of rows already in the file (header included), so the
    first data row gets 1, the second 2, and so on.
    """
    needs_header = not os.path.exists(CSV_LOG_PATH) or os.path.getsize(CSV_LOG_PATH) == 0
    if needs_header:
        sno = 1
    else:
        # Count inside a context manager so the read handle is always closed.
        with open(CSV_LOG_PATH, newline='') as existing:
            sno = sum(1 for _ in csv.reader(existing))

    now = datetime.now()
    with open(CSV_LOG_PATH, mode='a', newline='') as file:
        writer = csv.writer(file)
        if needs_header:
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
        writer.writerow([
            sno,
            now.strftime('%Y-%m-%d'),
            now.strftime('%H:%M:%S'),
            image_id,
            image_filename,
            mask_filename,
        ])


def reset_state():
    """Forget the current upload and result so the app starts again at the upload page."""
    st.session_state.file_uploaded = False
    st.session_state.filename = None
    st.session_state.mask_filename = None
    st.session_state.tr_img = None
    # Also forget which upload was processed, so re-uploading the same file
    # after "Back to Upload" runs the pipeline again.
    st.session_state.processed_upload_key = None
    if 'page' in st.session_state:
        del st.session_state.page


def process_image(image, timestamp):
    """Save an uploaded file under ``UPLOAD_DIR`` and return ``(filename, filepath)``.

    The stored name is ``image_<timestamp><original extension>``. TIFF uploads
    are then passed to ``convert_gtiff_to_8bit``.
    """
    filename = f"image_{timestamp}{os.path.splitext(image.name)[1]}"
    filepath = os.path.join(UPLOAD_DIR, filename)
    with open(filepath, "wb") as f:
        f.write(image.getvalue())
    if filename.lower().endswith(('.tiff', '.tif')):
        st.info('Processing GeoTIFF image...')
        convert_gtiff_to_8bit(filepath)
        st.success('GeoTIFF converted to 8-bit image')
    return filename, filepath


def _clear_temp_dirs():
    """Empty the patch working directories, recreating them if they are missing."""
    for temp_dir in (PATCHES_DIR, PRED_PATCHES_DIR):
        shutil.rmtree(temp_dir, ignore_errors=True)
        os.makedirs(temp_dir, exist_ok=True)


def predict_image(img_array, filename, timestamp, patch_size=PATCH_SIZE, max_direct_size=MAX_DIRECT_SIZE):
    """Segment an uploaded image and save its mask as ``MASK_DIR/mask_<timestamp>.png``.

    Args:
        img_array: Pixel array of the uploaded image. Only ``shape`` is used, to
            pick the strategy and to size the merged mask.
        filename: Name of the image inside ``UPLOAD_DIR``.
        timestamp: Value embedded in the mask filename.
        patch_size: Patch side length passed to ``split`` for large images.
        max_direct_size: Images whose height or width exceeds this are
            processed patch by patch; smaller ones in a single pass.

    Returns:
        Path of the saved mask.
    """
    image_path = os.path.join(UPLOAD_DIR, filename)
    merged_mask_filepath = os.path.join(MASK_DIR, f"mask_{timestamp}.png")

    if img_array.shape[0] <= max_direct_size and img_array.shape[1] <= max_direct_size:
        with Image.open(image_path) as img:
            mask = _binary_mask(img)
        Image.fromarray(mask).save(merged_mask_filepath)
        return merged_mask_filepath

    # Remove leftovers from an interrupted earlier run. Otherwise the loop below
    # (and merge) could pick up patches that belong to a different image.
    _clear_temp_dirs()
    try:
        split(image_path, patch_size=patch_size)
        with st.spinner('Analyzing...'):
            for patch_filename in os.listdir(PATCHES_DIR):
                if not patch_filename.endswith(".png"):
                    continue
                with Image.open(os.path.join(PATCHES_DIR, patch_filename)) as patch_img:
                    mask = _binary_mask(patch_img)
                mask_filepath = os.path.join(PRED_PATCHES_DIR, f"mask_{patch_filename}")
                Image.fromarray(mask).save(mask_filepath)
        merge(PRED_PATCHES_DIR, merged_mask_filepath, img_array.shape)
    finally:
        # Clean up even if segmentation fails part-way through.
        st.info('Cleaning up temporary files...')
        _clear_temp_dirs()
        st.success('Temporary files cleaned up')
    return merged_mask_filepath


def _upload_key(image):
    """Return a key identifying an uploaded file by its name and content hash."""
    return image.name, hashlib.sha256(image.getvalue()).hexdigest()


def upload_page():
    """Render the upload page: save, segment and log a new upload, then offer the result."""
    if 'file_uploaded' not in st.session_state:
        st.session_state.file_uploaded = False

    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])
    if image is not None:
        upload_key = _upload_key(image)
        # st.file_uploader returns the same file on every rerun (e.g. when
        # "View result" is clicked). Run the expensive pipeline only once per upload.
        if st.session_state.get('processed_upload_key') != upload_key:
            reset_state()
            timestamp = int(time.time())
            filename, filepath = process_image(image, timestamp)
            img = Image.open(filepath)
            st.image(img, caption='Uploaded Image', use_column_width=True)
            st.success(f'Image saved as {filename}')
            st.session_state.filename = filename
            img_array = np.array(img)
            mask_filepath = predict_image(img_array, filename, timestamp)
            st.session_state.mask_filename = mask_filepath
            log_image_details(timestamp, filename, os.path.basename(mask_filepath))
            st.session_state.file_uploaded = True
            st.session_state.processed_upload_key = upload_key
        elif st.session_state.get('filename'):
            # Already processed: just show the saved copy again.
            saved_path = os.path.join(UPLOAD_DIR, st.session_state.filename)
            if os.path.exists(saved_path):
                st.image(Image.open(saved_path), caption='Uploaded Image', use_column_width=True)

    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()


def result_page():
    """Render the original image, the predicted mask and the building-area overlay."""
    st.title('Analysis Result')

    filename = st.session_state.get('filename')
    mask_path = st.session_state.get('mask_filename')
    # reset_state() stores None in these keys, so check the values, not just key presence.
    if filename is None or mask_path is None:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            reset_state()
            st.rerun()
        return

    col1, col2 = st.columns(2)
    original_img_path = os.path.join(UPLOAD_DIR, filename)

    if os.path.exists(original_img_path):
        original_img = Image.open(original_img_path)
        col1.image(original_img, caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    if os.path.exists(mask_path):
        mask = Image.open(mask_path)
        col2.image(mask, caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")
    if os.path.exists(original_img_path) and os.path.exists(mask_path):
        original_np = cv2.imread(original_img_path)
        mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None instead of raising when it cannot decode a file.
        if original_np is None or mask_np is None:
            st.error("Image or mask file could not be decoded for overlay.")
        else:
            _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
            if original_np.shape[:2] != mask_np.shape[:2]:
                # Nearest-neighbour interpolation keeps the resized mask strictly 0/255.
                mask_np = cv2.resize(
                    mask_np,
                    (original_np.shape[1], original_np.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )
            # NOTE(review): cv2.imread yields BGR while st.image assumes RGB by
            # default; confirm process_and_overlay_image converts the colour order.
            overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
            st.image(overlay_img, caption='Overlay Image', use_column_width=True)
    else:
        st.error("Image or mask file not found for overlay.")

    if st.button('Back to Upload'):
        reset_state()
        st.rerun()


def main():
    """Route to the upload or result page based on ``st.session_state.page``."""
    st.title('Building area estimation')
    if 'page' not in st.session_state:
        st.session_state.page = 'upload'

    if st.session_state.page == 'upload':
        upload_page()
    elif st.session_state.page == 'result':
        result_page()


if __name__ == '__main__':
    main()