Building_area / app.py
Pavan2k4's picture
Update app.py
c928a9a verified
raw
history blame
12.4 kB
import streamlit as st
import sys
import os
import shutil
import time
from datetime import datetime
import csv
import cv2
import numpy as np
from PIL import Image
import torch
sys.path.append('Utils')
sys.path.append('model')
from model.CBAM.reunet_cbam import reunet_cbam
from model.transform import transforms
from model.unet import UNET
from Utils.area import pixel_to_sqft, process_and_overlay_image
from Utils.convert import read_pansharpened_rgb
@st.cache_resource
def load_model():
    """Build the ReUNet-CBAM network and load its trained weights.

    Weights come from the ``'model_state_dict'`` entry of ``latest.pth``,
    read onto the CPU with ``weights_only=True``. The network is put in
    evaluation mode before it is returned. ``st.cache_resource`` lets
    Streamlit reuse the returned object across reruns.
    """
    net = reunet_cbam()
    checkpoint = torch.load('latest.pth', map_location='cpu', weights_only = True)
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    return net


# Module-level network used by predict(); built when the script is imported.
model = load_model()
def refine_mask(mask, blur_kernel=5, threshold_value=127, morph_kernel_size=3, min_object_size=100):
    """Refine and clean the mask with Gaussian blur, thresholding, morphological operations, and small object removal.

    Pipeline:
      1. Convert to grayscale (COLOR_BGR2GRAY) if the input has a channel axis.
      2. Gaussian blur with a ``blur_kernel`` x ``blur_kernel`` kernel.
      3. Binary threshold: pixels above ``threshold_value`` become 255, the rest 0.
      4. Morphological opening, then closing, with a square kernel.
      5. Erase 8-connected foreground components smaller than ``min_object_size`` pixels.

    Args:
        mask: 2-D mask or a 3-channel BGR image. OpenCV's connected-components
            step needs an 8-bit image; the callers in this file pass uint8
            0/255 masks.
        blur_kernel: Gaussian kernel side length. OpenCV requires it to be
            positive and odd here, because sigma is 0.
        threshold_value: Binarisation threshold (0-255).
        morph_kernel_size: Side length of the square structuring element.
        min_object_size: Minimum component area, in pixels, to keep.

    Returns:
        The refined single-channel mask with values 0 or 255.
    """
    # Ensure mask is grayscale
    if len(mask.shape) > 2:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # Smooth ragged edges before binarising
    mask = cv2.GaussianBlur(mask, (blur_kernel, blur_kernel), 0)
    _, mask = cv2.threshold(mask, threshold_value, 255, cv2.THRESH_BINARY)
    # Opening removes speckle noise; closing fills small holes
    kernel = np.ones((morph_kernel_size, morph_kernel_size), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    # Remove small objects in one vectorised pass instead of scanning the
    # whole label image once per component.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    too_small = stats[:, cv2.CC_STAT_AREA] < min_object_size
    too_small[0] = False  # label 0 is the background; leave it untouched
    mask[too_small[labels]] = 0
    return mask
# Storage locations, resolved against the working directory at import time.
base = os.getcwd()
UPLOAD_DIR, MASK_DIR = (os.path.join(base, sub) for sub in ("Images", "Masks"))
CSV_LOG_PATH = "image_log.csv"  # relative path, so it also lands in the working directory

# Create the image and mask folders if missing; existing folders are kept.
for directory in (UPLOAD_DIR, MASK_DIR):
    os.makedirs(directory, exist_ok=True)
def predict(image, net=None):
    """Run the segmentation network on a single, un-batched image tensor.

    Args:
        image: Tensor for one image. A batch axis of size 1 is added before
            the forward pass. Presumably (C, H, W) as produced by
            ``transforms``; TODO confirm.
        net: Callable model to run. Defaults to the module-level ``model``.

    Returns:
        The network output as a CPU NumPy array with all size-1 axes squeezed out.
    """
    if net is None:
        net = model
    with torch.no_grad():  # inference only: skip building an autograd graph
        output = net(image.unsqueeze(0))
    return output.squeeze().cpu().numpy()
def split_image(image, patch_size=512):
    """Tile ``image`` into non-overlapping square patches.

    Tiles are produced row by row: top to bottom, and left to right within a
    row. Tiles on the right and bottom edges are cropped to the image bounds,
    so they may be smaller than ``patch_size``.

    Args:
        image: Array indexed ``image[row, col, ...]``. Both 2-D (grayscale)
            and channel-last 3-D arrays are accepted.
        patch_size: Tile side length in pixels; must be >= 1.

    Returns:
        List of ``(name, patch)`` tuples. ``name`` is ``"patch_{y}_{x}.png"``,
        where ``y``/``x`` are the tile's top-left row and column.

    Raises:
        ValueError: If ``patch_size`` is less than 1.
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}")
    # Only the spatial axes matter; any channel axis is carried along by slicing.
    h, w = image.shape[:2]
    # numpy slicing clamps at the array edge, so edge tiles are cropped automatically.
    return [
        (f"patch_{y}_{x}.png", image[y:y + patch_size, x:x + patch_size])
        for y in range(0, h, patch_size)
        for x in range(0, w, patch_size)
    ]
def _patch_origin(filename):
    """Return the ``(y, x)`` offset encoded in a ``..._{y}_{x}.png`` patch filename."""
    parts = filename[:-len(".png")].split("_")
    if len(parts) >= 2 and parts[-2].isdigit() and parts[-1].isdigit():
        return int(parts[-2]), int(parts[-1])
    raise ValueError(f"Invalid filename: {filename}")


def merge(patch_folder, dest_image='out.png', image_shape=None):
    """Reassemble the patch PNGs in ``patch_folder`` into one 3-channel image.

    Each ``*.png`` file is read with ``cv2.imread`` (3-channel BGR). It is
    pasted at the offset encoded in its name, ``patch_{y}_{x}.png``, as
    produced by ``split_image``. A patch that overhangs the right or bottom
    edge is shifted back so it ends at that edge. A patch larger than the
    canvas is also cropped. The merged image is written to ``dest_image``
    and returned.

    Args:
        patch_folder: Directory containing the patch PNGs.
        dest_image: Output path for the merged image.
        image_shape: Shape of the full image. Only ``image_shape[:2]``
            (height, width) is used. Required despite the ``None`` default.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)``. Pixels not covered
        by any patch stay 0.

    Raises:
        ValueError: If ``image_shape`` is missing, a patch filename cannot be
            parsed, or a patch file cannot be read.
    """
    if image_shape is None:
        raise ValueError("image_shape is required to size the merged image")
    height, width = image_shape[:2]
    merged = np.zeros((height, width, 3), dtype=np.uint8)
    # Sorted so overlapping patches (if any) are applied in a deterministic order.
    for filename in sorted(os.listdir(patch_folder)):
        if not filename.endswith(".png"):
            continue
        y, x = _patch_origin(filename)
        patch = cv2.imread(os.path.join(patch_folder, filename))
        if patch is None:
            raise ValueError(f"Could not read patch image: {filename}")
        patch_height, patch_width = patch.shape[:2]
        # Shift overhanging patches back inside the canvas, never past the origin.
        x = max(0, min(x, width - patch_width))
        y = max(0, min(y, height - patch_height))
        # Crop whatever still does not fit (only when a patch exceeds the canvas).
        patch = patch[:height - y, :width - x]
        merged[y:y + patch.shape[0], x:x + patch.shape[1], :] = patch
    cv2.imwrite(dest_image, merged)
    return merged
def process_large_image(model, image_path, patch_size=512):
    """Segment a large image tile by tile and stitch the tile masks together.

    The image is read with ``cv2.imread`` and split into ``patch_size``
    tiles by ``split_image``. Each tile is converted to RGB, run through
    ``transforms`` and ``predict``, and thresholded at 0.5 into a 0/255
    mask. The tile masks are written to a private temporary directory, so
    patches from earlier images can never leak into this merge. They are
    then combined with ``merge``.

    Args:
        model: Accepted for API compatibility. Inference goes through
            ``predict``, which uses the module-level model.
        image_path: Path of the image to segment.
        patch_size: Tile side length in pixels.

    Returns:
        The merged 3-channel ``uint8`` mask from ``merge``. ``merge`` also
        writes it to ``merged_mask.png`` in the working directory.

    Raises:
        ValueError: If the image cannot be read.
    """
    import tempfile

    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Failed to read image from {image_path}")
    h, w = img.shape[:2]
    st.write(f"Processing image of size {w}x{h}")
    with tempfile.TemporaryDirectory(prefix="pred_patches_") as patch_dir:
        for filename, patch in split_image(img, patch_size):
            patch_pil = Image.fromarray(cv2.cvtColor(patch, cv2.COLOR_BGR2RGB))
            prediction = predict(transforms(patch_pil))
            mask = (prediction > 0.5).astype(np.uint8) * 255
            # If the model output resolution differs from the tile (e.g. the
            # transform resizes), bring the mask back to the tile size so merge
            # places it correctly. Nearest-neighbour keeps it binary.
            tile_h, tile_w = patch.shape[:2]
            if mask.ndim == 2 and mask.shape != (tile_h, tile_w):
                mask = cv2.resize(mask, (tile_w, tile_h), interpolation=cv2.INTER_NEAREST)
            cv2.imwrite(os.path.join(patch_dir, filename), mask)
        merged_mask = merge(patch_dir, dest_image='merged_mask.png', image_shape=img.shape)
    return merged_mask
def log_image_details(image_id, image_filename, mask_filename, log_path=None):
    """Append one row describing a processed image to the CSV log.

    Columns: S.No, Date (YYYY-MM-DD), Time (HH:MM:SS, local time), Image ID,
    Image Filename, Mask Filename. A header row is written first when the log
    file does not exist or is empty. S.No continues from the number of rows
    already in the file.

    Args:
        image_id: Value for the 'Image ID' column.
        image_filename: Name of the saved input image.
        mask_filename: Name of the saved mask.
        log_path: CSV file to append to. Defaults to the module's ``CSV_LOG_PATH``.
    """
    path = CSV_LOG_PATH if log_path is None else log_path
    now = datetime.now()
    # An empty file has no header yet, so treat it like a missing one.
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
    if needs_header:
        sno = 1
    else:
        # Count before opening for append. The header takes one row, so the
        # total row count equals the next serial number.
        with open(path, mode='r', newline='') as existing:
            sno = sum(1 for _ in csv.reader(existing))
    with open(path, mode='a', newline='') as log_file:
        writer = csv.writer(log_file)
        if needs_header:
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
        writer.writerow([
            sno,
            now.strftime('%Y-%m-%d'),
            now.strftime('%H:%M:%S'),
            image_id,
            image_filename,
            mask_filename,
        ])
def _ensure_upload_state():
    """Create the session-state keys the upload and result pages rely on."""
    defaults = {'file_uploaded': False, 'filename': None, 'mask_filename': None}
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _store_upload(image, timestamp):
    """Save an uploaded file to UPLOAD_DIR and return ``(img, converted_filepath, converted_filename)``.

    GeoTIFF uploads (.tif/.tiff) are written as-is, then converted with
    ``read_pansharpened_rgb`` and saved as a separate PNG. Any other upload is
    written under a ``.png`` name and re-saved through PIL to that same path.
    """
    file_extension = os.path.splitext(image.name)[1].lower()
    is_tiff = file_extension in ['.tiff', '.tif']
    if is_tiff:
        filename = f"image_{timestamp}.tif"
        converted_filename = f"image_{timestamp}_converted.png"
    else:
        filename = f"image_{timestamp}.png"
        converted_filename = filename
    filepath = os.path.join(UPLOAD_DIR, filename)
    converted_filepath = os.path.join(UPLOAD_DIR, converted_filename)
    with open(filepath, "wb") as f:
        f.write(image.getvalue())
    if is_tiff:
        st.info('Processing GeoTIFF image...')
        rgb_image = read_pansharpened_rgb(filepath)
        cv2.imwrite(converted_filepath, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
        st.success(f'GeoTIFF converted to 8-bit image and saved as {converted_filename}')
        img = Image.open(converted_filepath)
    else:
        img = Image.open(filepath)
        img.save(converted_filepath)
    if os.path.exists(converted_filepath):
        st.success(f"Image saved successfully: {converted_filepath}")
        file_size = os.path.getsize(converted_filepath)
        st.write(f"File size: {file_size} bytes")
    else:
        st.error(f"Failed to save image: {converted_filepath}")
    return img, converted_filepath, converted_filename


def _segment(img, image_path):
    """Return a 0/255 building mask for ``img``; images over 650 px on either side are tiled."""
    width, height = img.size  # PIL reports (width, height); no array copy needed
    if height > 650 or width > 650:
        st.info('Large image detected. Using patch-based processing.')
        with st.spinner('Analyzing large image...'):
            return process_large_image(model, image_path)
    st.info('Small image detected. Processing whole image at once.')
    with st.spinner('Analyzing image...'):
        # Feed 3-channel RGB like the patch path does (cv2 BGR -> RGB), so
        # RGBA / palette / grayscale PNGs reach the model in the same layout.
        prediction = predict(transforms(img.convert('RGB')))
    return (prediction > 0.5).astype(np.uint8) * 255


def upload_page():
    """Upload page: accept a satellite image, segment it once, then offer the result view.

    Processing runs only while ``st.session_state.file_uploaded`` is False,
    so Streamlit reruns do not reprocess the same upload. The saved image and
    mask filenames are stored in session state for ``result_page``, and each
    processed upload is appended to the CSV log.
    """
    import traceback

    _ensure_upload_state()
    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])
    if image is not None and not st.session_state.file_uploaded:
        try:
            timestamp = int(time.time())
            img, converted_filepath, converted_filename = _store_upload(image, timestamp)
            st.image(img, caption='Uploaded Image', use_column_width=True)
            st.success(f'Image processed and saved as {converted_filename}')
            st.session_state.filename = converted_filename
            full_mask = refine_mask(_segment(img, converted_filepath))
            mask_filename = f"mask_{timestamp}.png"
            cv2.imwrite(os.path.join(MASK_DIR, mask_filename), full_mask)
            st.session_state.mask_filename = mask_filename
            log_image_details(timestamp, converted_filename, mask_filename)
            st.session_state.file_uploaded = True
            st.success("Image processed successfully")
        except Exception as e:
            # UI boundary: report the failure to the user and the server log
            # instead of crashing the app.
            st.error(f"An error occurred: {str(e)}")
            st.error("Please check the logs for more details.")
            print(f"Error in upload_page: {str(e)}")
            traceback.print_exc()
    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()
def _reset_to_upload():
    """Forget the current upload, switch to the upload page and rerun the script."""
    st.session_state.page = 'upload'
    st.session_state.file_uploaded = False
    st.session_state.filename = None
    st.session_state.mask_filename = None
    st.rerun()


def result_page():
    """Result page: show the original image, its predicted mask, and the area overlay.

    File names come from ``st.session_state.filename`` (in UPLOAD_DIR) and
    ``st.session_state.mask_filename`` (in MASK_DIR). The overlay is built by
    ``process_and_overlay_image`` from the original image and a binarised
    mask resized to the image.
    """
    st.title('Analysis Result')
    filename = st.session_state.get('filename')
    mask_filename = st.session_state.get('mask_filename')
    # upload_page initialises both keys to None, so check the values, not
    # only whether the keys exist.
    if filename is None or mask_filename is None:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            _reset_to_upload()
        return

    col1, col2 = st.columns(2)
    original_img_path = os.path.join(UPLOAD_DIR, filename)
    mask_path = os.path.join(MASK_DIR, mask_filename)

    if os.path.exists(original_img_path):
        col1.image(Image.open(original_img_path), caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    if os.path.exists(mask_path):
        col2.image(Image.open(mask_path), caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")
    # cv2.imread returns None for unreadable files, so guard on the result too.
    original_np = cv2.imread(original_img_path) if os.path.exists(original_img_path) else None
    mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) if os.path.exists(mask_path) else None
    if original_np is None or mask_np is None:
        st.error("Image or mask file not found for overlay.")
    else:
        # Ensure mask is binary
        _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
        if original_np.shape[:2] != mask_np.shape[:2]:
            # Nearest-neighbour keeps the resized mask strictly 0/255.
            mask_np = cv2.resize(
                mask_np,
                (original_np.shape[1], original_np.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
        # NOTE(review): if this returns a BGR array, st.image (RGB by default)
        # will swap colours; confirm against Utils.area.
        overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
        st.image(overlay_img, caption='Overlay Image', use_column_width=True)

    if st.button('Back to Upload'):
        _reset_to_upload()
def main():
    """Draw the app title, then render the page named in session state.

    The page defaults to 'upload'. 'result' shows the analysis view. Any
    other value renders only the title.
    """
    st.title('Building area estimation')
    if 'page' not in st.session_state:
        st.session_state.page = 'upload'
    page_renderers = {'upload': upload_page, 'result': result_page}
    render = page_renderers.get(st.session_state.page)
    if render is not None:
        render()


if __name__ == '__main__':
    main()