Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import sys | |
| import os | |
| import shutil | |
| import time | |
| from datetime import datetime | |
| import csv | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from huggingface_hub import HfApi | |
| # Adjust import paths as needed | |
| sys.path.append('Utils') | |
| sys.path.append('model') | |
| from model.CBAM.reunet_cbam import reunet_cbam | |
| from model.transform import transforms | |
| from model.unet import UNET | |
| from Utils.area import pixel_to_sqft, process_and_overlay_image | |
| from split_merge import split, merge | |
| from Utils.convert import read_pansharpened_rgb | |
# Client used by save_to_hf_repo to push files back into the Space repo.
hf_api = HfApi()

# Write token for the Space repository, read from the Space's secrets.
# Nothing can be uploaded without it, so the script halts right away when it is missing.
HF_TOKEN = st.secrets.get("HF_TOKEN")
if not HF_TOKEN:
    st.error("HF_TOKEN not found in secrets. Please set it in your Space's Configuration > Secrets.")
    st.stop()

# Upload target: must match the Hugging Face Space URL exactly.
REPO_ID, REPO_TYPE = "Pavan2k4/Building_area", "space"

# Local working area; every artefact the app produces lives below BASE_DIR.
BASE_DIR = "DATA/"
UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_images")      # raw and converted uploads
MASK_DIR = os.path.join(BASE_DIR, "generated_masks")        # final predicted masks
PATCHES_DIR = os.path.join(BASE_DIR, "patches")             # temporary input tiles
PRED_PATCHES_DIR = os.path.join(BASE_DIR, "pred_patches")   # temporary predicted mask tiles
CSV_LOG_PATH = os.path.join(BASE_DIR, "image_log.csv")      # one row per processed upload
def split(image, destination=PATCHES_DIR, patch_size=256):
    """Cut the image file at path *image* into square tiles saved as PNGs.

    Tiles are taken row by row on a ``patch_size`` grid. Tiles on the right
    and bottom edges are smaller when the image size is not a multiple of
    ``patch_size``, because array slicing truncates at the border. Each tile
    is written to ``destination`` as ``patch_{y}_{x}.png``, where ``(y, x)``
    is its top-left pixel. That naming is compatible with the coordinate
    parsing done by ``merge``.

    Args:
        image: Path of the image file, read with ``cv2.imread`` (BGR, 3 channels).
        destination: Folder that receives the tiles; created if missing.
        patch_size: Tile edge length in pixels; must be positive.

    Raises:
        ValueError: If ``patch_size`` is not positive.
        FileNotFoundError: If OpenCV cannot read ``image``.
        OSError: If a tile cannot be written.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    img = cv2.imread(image)
    # cv2.imread reports failure by returning None rather than raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image}")
    # cv2.imwrite silently returns False when the folder does not exist,
    # so make sure it is there before writing any tile.
    os.makedirs(destination, exist_ok=True)
    height, width = img.shape[:2]
    for y in range(0, height, patch_size):
        for x in range(0, width, patch_size):
            patch = img[y:y + patch_size, x:x + patch_size]
            patch_path = os.path.join(destination, f"patch_{y}_{x}.png")
            if not cv2.imwrite(patch_path, patch):
                raise OSError(f"Could not write patch: {patch_path}")
| def merge(patch_folder , dest_image = 'out.png', image_shape = None): | |
| merged = np.zeros(image_shape[:-1] + (3,), dtype=np.uint8) | |
| for filename in os.listdir(patch_folder): | |
| if filename.endswith(".png"): | |
| patch_path = os.path.join(patch_folder, filename) | |
| patch = cv2.imread(patch_path) | |
| patch_height, patch_width, _ = patch.shape | |
| # Extract patch coordinates from filename | |
| parts = filename.split("_") | |
| x, y = None, None | |
| for part in parts: | |
| if part.endswith(".png"): | |
| x = int(part.split(".")[0]) | |
| elif part.isdigit(): | |
| y = int(part) | |
| if x is None or y is None: | |
| raise ValueError(f"Invalid filename: {filename}") | |
| # Check if patch fits within image boundaries | |
| if x + patch_width > image_shape[1] or y + patch_height > image_shape[0]: | |
| # Adjust patch position to fit within image boundaries | |
| if x + patch_width > image_shape[1]: | |
| x = image_shape[1] - patch_width | |
| if y + patch_height > image_shape[0]: | |
| y = image_shape[0] - patch_height | |
| # Merge patch into the main image | |
| merged[y:y+patch_height, x:x+patch_width, :] = patch | |
| cv2.imwrite(dest_image, merged) | |
# Create every working folder up front; exist_ok makes this safe to run repeatedly.
for _working_dir in (UPLOAD_DIR, MASK_DIR, PATCHES_DIR, PRED_PATCHES_DIR):
    os.makedirs(_working_dir, exist_ok=True)
# Load model
@st.cache_resource
def load_model(checkpoint_path='latest.pth'):
    """Build the ReUNet-CBAM network, load its weights and switch it to eval mode.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` makes the checkpoint load once per process, and the
    single model instance is reused on every rerun and by every session. It
    is only used for inference under ``torch.no_grad`` in ``predict``.

    Args:
        checkpoint_path: Checkpoint file containing a ``'model_state_dict'`` entry.

    Returns:
        The model in eval mode, on CPU.
    """
    net = reunet_cbam()
    # torch.load unpickles the file: only point this at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    return net


model = load_model()
def predict(image):
    """Run the module-level ``model`` on one unbatched tensor and return a NumPy array.

    A leading batch axis is added before the forward pass. Every size-1 axis
    is squeezed out of the result, which is then moved to CPU. Gradient
    tracking is disabled for the forward pass.
    """
    batch = image.unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    return output.squeeze().cpu().numpy()
def save_to_hf_repo(local_path, repo_path):
    """Upload *local_path* to *repo_path* inside the Space repository.

    The outcome is reported in the Streamlit UI. Any exception raised while
    uploading is shown (message and traceback) instead of being propagated,
    so a failed upload never aborts the calling page.
    """
    try:
        upload_kwargs = {
            "path_or_fileobj": local_path,
            "path_in_repo": repo_path,
            "repo_id": REPO_ID,
            "repo_type": REPO_TYPE,
            "token": HF_TOKEN,
        }
        hf_api.upload_file(**upload_kwargs)
        st.success(f"File uploaded successfully to {repo_path}")
    except Exception as err:
        st.error(f"Error uploading file: {err}")
        st.error("Detailed error information:")
        st.exception(err)
def log_image_details(image_id, image_filename, mask_filename):
    """Append one row for a processed upload to the CSV log, then upload the log.

    Columns are S.No, Date, Time, Image ID, Image Filename and Mask Filename.
    The header row is written when the log is missing or empty. S.No is one
    more than the number of data rows already present. Date and time come
    from ``datetime.now()`` (naive local time).
    """
    now = datetime.now()
    log_date = now.strftime('%Y-%m-%d')
    log_time = now.strftime('%H:%M:%S')  # not named `time`: that would shadow the module

    # Count existing records before opening for append, so the read never
    # overlaps the write handle. newline='' is what the csv module expects.
    existing_rows = 0
    if os.path.exists(CSV_LOG_PATH):
        with open(CSV_LOG_PATH, mode='r', newline='') as log_file:
            existing_rows = sum(1 for _ in csv.reader(log_file))

    # An empty (or missing) log still needs its header.
    needs_header = existing_rows == 0
    # With the header occupying one row, the record count is the next S.No.
    sno = 1 if needs_header else existing_rows

    with open(CSV_LOG_PATH, mode='a', newline='') as log_file:
        writer = csv.writer(log_file)
        if needs_header:
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
        writer.writerow([sno, log_date, log_time, image_id, image_filename, mask_filename])

    # Save CSV to Hugging Face repo
    save_to_hf_repo(CSV_LOG_PATH, 'image_log.csv')
def _reset_patch_dirs():
    """Delete and recreate the temporary tile folders so they exist and are empty."""
    # NOTE(review): these folders are fixed paths shared by every session in
    # the process; concurrent uploads can interfere. Consider per-session temp dirs.
    for folder in (PATCHES_DIR, PRED_PATCHES_DIR):
        shutil.rmtree(folder, ignore_errors=True)
        os.makedirs(folder, exist_ok=True)


def _predict_large_image(converted_filepath, img_array, timestamp, patch_size=512):
    """Predict a mask tile by tile, merge it into MASK_DIR and return its filename.

    The tile folders are emptied before splitting, because stale tiles from
    an earlier failed run would otherwise be merged into this mask. The
    result page may also have deleted the folders entirely. They are emptied
    again afterwards, even when prediction fails.
    """
    _reset_patch_dirs()
    try:
        split(converted_filepath, patch_size=patch_size)
        with st.spinner('Analyzing...'):
            for patch_filename in os.listdir(PATCHES_DIR):
                if not patch_filename.endswith(".png"):
                    continue
                with Image.open(os.path.join(PATCHES_DIR, patch_filename)) as patch_img:
                    prediction = predict(transforms(patch_img))
                mask = (prediction > 0.5).astype(np.uint8) * 255
                mask_filepath = os.path.join(PRED_PATCHES_DIR, f"mask_{patch_filename}")
                Image.fromarray(mask).save(mask_filepath)

        merged_mask_filename = f"mask_{timestamp}.png"
        merged_mask_path = os.path.join(MASK_DIR, merged_mask_filename)
        # merge only needs (height, width); an explicit 3-channel shape also keeps
        # 2-D (grayscale) uploads working.
        merge(PRED_PATCHES_DIR, merged_mask_path, img_array.shape[:2] + (3,))
    finally:
        st.info('Cleaning up temporary files...')
        _reset_patch_dirs()
        st.success('Temporary files cleaned up')
    return merged_mask_filename


def _predict_whole_image(img, timestamp):
    """Predict a mask for the whole image in one pass, save it in MASK_DIR and return its filename."""
    st.session_state.tr_img = transforms(img)
    prediction = predict(st.session_state.tr_img)
    mask = (prediction > 0.5).astype(np.uint8) * 255
    mask_filename = f"mask_{timestamp}.png"
    Image.fromarray(mask).save(os.path.join(MASK_DIR, mask_filename))
    return mask_filename


def upload_page():
    """Render the upload page and process a newly uploaded satellite image.

    For an upload not yet processed (``file_uploaded`` is False), the page:
    1. Saves the raw file to UPLOAD_DIR and pushes it to the Space repo.
    2. For GeoTIFFs, converts the file to a PNG via ``read_pansharpened_rgb``.
    3. Predicts a building mask, tiling the image when either side exceeds
       650 px.
    4. Saves the mask, uploads it, and logs the run.

    Errors are shown in the UI and printed to the server log rather than
    raised. Once an image has been processed, a "View result" button
    switches to the result page.
    """
    large_image_limit = 650  # px; above this on either side the image is tiled

    for key, default in (('file_uploaded', False), ('filename', None), ('mask_filename', None)):
        if key not in st.session_state:
            st.session_state[key] = default

    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])

    if image is not None and not st.session_state.file_uploaded:
        try:
            bytes_data = image.getvalue()
            timestamp = int(time.time())
            file_extension = os.path.splitext(image.name)[1].lower()
            is_tiff = file_extension in ['.tiff', '.tif']

            if is_tiff:
                filename = f"image_{timestamp}.tif"
                converted_filename = f"image_{timestamp}_converted.png"
            else:
                filename = f"image_{timestamp}.png"
                converted_filename = filename

            filepath = os.path.join(UPLOAD_DIR, filename)
            converted_filepath = os.path.join(UPLOAD_DIR, converted_filename)

            with open(filepath, "wb") as f:
                f.write(bytes_data)
            st.success(f"Image saved to {filepath}")

            # Save image to Hugging Face repo
            save_to_hf_repo(filepath, f'uploaded_images/{filename}')

            if is_tiff:
                st.info('Processing GeoTIFF image...')
                rgb_image = read_pansharpened_rgb(filepath)
                # cv2.imwrite reports failure via its return value, not an exception.
                if not cv2.imwrite(converted_filepath, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)):
                    raise OSError(f"Could not write converted image: {converted_filepath}")
                st.success(f'GeoTIFF converted to 8-bit image and saved as {converted_filename}')
                img = Image.open(converted_filepath)
            else:
                img = Image.open(filepath)

            st.image(img, caption='Uploaded Image', use_column_width=True)
            st.success(f'Image processed and saved as {converted_filename}')

            # The result page looks this name up inside UPLOAD_DIR.
            st.session_state.filename = converted_filename

            img_array = np.array(img)
            if img_array.shape[0] > large_image_limit or img_array.shape[1] > large_image_limit:
                mask_filename = _predict_large_image(converted_filepath, img_array, timestamp)
            else:
                mask_filename = _predict_whole_image(img, timestamp)
            st.session_state.mask_filename = mask_filename

            # Save mask to Hugging Face repo
            save_to_hf_repo(os.path.join(MASK_DIR, mask_filename), f'generated_masks/{mask_filename}')

            log_image_details(timestamp, converted_filename, mask_filename)

            st.session_state.file_uploaded = True

        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.error("Please check the logs for more details.")
            print(f"Error in upload_page: {str(e)}")  # This will appear in the Streamlit logs
            # The UI sends users to the logs, so put the full stack trace there too.
            import traceback
            traceback.print_exc()

    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()
def _return_to_upload(clear_temp_dirs=False):
    """Reset the per-upload session state and rerun the script on the upload page.

    Args:
        clear_temp_dirs: Also empty the temporary tile folders. They are
            recreated rather than only deleted, because the upload page lists
            and writes into them on the next large upload. ``rmtree`` is told
            to ignore folders that are already missing.
    """
    if clear_temp_dirs:
        # NOTE(review): these folders are shared by every session in the process;
        # clearing them here can disturb another user's in-flight upload.
        for folder in (PATCHES_DIR, PRED_PATCHES_DIR):
            shutil.rmtree(folder, ignore_errors=True)
            os.makedirs(folder, exist_ok=True)
    st.session_state.page = 'upload'
    st.session_state.file_uploaded = False
    st.session_state.filename = None
    st.session_state.mask_filename = None
    st.rerun()


def result_page():
    """Show the uploaded image, its predicted mask, and the area overlay.

    Reads ``filename`` (relative to UPLOAD_DIR) and ``mask_filename``
    (relative to MASK_DIR) from session state. When either is unset, it
    shows an error and offers a way back to the upload page.
    """
    st.title('Analysis Result')

    filename = st.session_state.get('filename')
    mask_filename = st.session_state.get('mask_filename')
    # The upload page initialises both keys to None, so checking for the keys'
    # presence alone would let None reach os.path.join below.
    if filename is None or mask_filename is None:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            _return_to_upload()
        return

    col1, col2 = st.columns(2)

    # Display original image
    original_img_path = os.path.join(UPLOAD_DIR, filename)
    if os.path.exists(original_img_path):
        original_img = Image.open(original_img_path)
        col1.image(original_img, caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    # Display predicted mask
    mask_path = os.path.join(MASK_DIR, mask_filename)
    if os.path.exists(mask_path):
        mask = Image.open(mask_path)
        col2.image(mask, caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")

    if os.path.exists(original_img_path) and os.path.exists(mask_path):
        original_np = cv2.imread(original_img_path)
        mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None (instead of raising) for unreadable files.
        if original_np is None or mask_np is None:
            st.error("Image or mask file could not be read for overlay.")
        else:
            if original_np.shape[:2] != mask_np.shape[:2]:
                # Nearest-neighbour keeps the mask binary; bilinear would create
                # intermediate grey levels along building edges.
                mask_np = cv2.resize(
                    mask_np,
                    (original_np.shape[1], original_np.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )
            # Ensure mask is binary (after any resize).
            _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)

            # NOTE(review): original_np is BGR (cv2.imread) and st.image assumes RGB
            # by default; confirm process_and_overlay_image returns RGB.
            overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
            st.image(overlay_img, caption='Overlay Image', use_column_width=True)
    else:
        st.error("Image or mask file not found for overlay.")

    if st.button('Back to Upload'):
        _return_to_upload(clear_temp_dirs=True)
def main():
    """Render the app title, then draw whichever page is recorded in session state.

    A first visit starts on the upload page. Unknown page values render
    nothing beyond the title.
    """
    st.title('Building area estimation')
    if 'page' not in st.session_state:
        st.session_state.page = 'upload'

    renderers = {'upload': upload_page, 'result': result_page}
    render = renderers.get(st.session_state.page)
    if render is not None:
        render()


if __name__ == '__main__':
    main()