"""Streamlit front end for the image processing pipeline.

An uploaded image is segmented, each segmented object is identified, text is
extracted and summarized per object, and the results are written under
``data/output`` and shown on demand via toggle buttons.
"""
import shutil
import streamlit as st
import os
import sys
import pandas as pd
import json
from PIL import Image
import logging

# The project packages (models/, utils/) live one directory above this file,
# so that directory must be on sys.path before the imports below run.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.segmentation_model import SegmentationModel
from models.identification_model import IdentificationModel
from models.text_extraction_model import TextExtractionModel
from models.summarization_model import SummarizationModel
from utils.postprocessing import save_segmented_objects
from utils.data_mapping import map_data, save_mapped_data
from utils.visualization import visualize_detections, visualize_segmentation, create_summary_table

# Set up logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# Working directories used by the pipeline (relative to the process cwd).
INPUT_DIR = os.path.join("data", "input_images")
SEGMENTED_DIR = "data/segmented_objects"
OUTPUT_DIR = "data/output"

# Size every displayed image is resized to before rendering.
IMAGE_WIDTH = 600
IMAGE_HEIGHT = 400

# (button label, session-state flag) for each toggleable view, in column order.
VIEW_TOGGLES = (
    ("Show Original Image", "show_original_image"),
    ("Show Segmented Image", "show_segmented_image"),
    ("Show Detected Objects", "show_detected_objects"),
    ("Show Summary Table", "show_summary_table"),
)


@st.cache_resource
def load_segmentation_model():
    """Return the SegmentationModel, cached by Streamlit across reruns."""
    return SegmentationModel()


@st.cache_resource
def load_identification_model():
    """Return the IdentificationModel, cached by Streamlit across reruns."""
    return IdentificationModel()


@st.cache_resource
def load_text_extraction_model():
    """Return the TextExtractionModel, cached by Streamlit across reruns."""
    return TextExtractionModel()


@st.cache_resource
def load_summarization_model():
    """Return the SummarizationModel, cached by Streamlit across reruns."""
    return SummarizationModel()


def clear_segmented_objects_folder(folder_path):
    """Delete every file, symlink and sub-directory inside *folder_path*.

    The folder itself is kept. A missing folder is not an error: the step is
    skipped and a message is logged. A failure to delete an individual entry
    is reported in the UI with ``st.error`` and the remaining entries are
    still processed.
    """
    if not os.path.isdir(folder_path):
        logging.info(f"Folder '{folder_path}' does not exist, skipping the clearing step.")
        return

    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)  # Remove the file
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)  # Remove the directory
        except OSError as e:
            st.error(f'Failed to delete {file_path}. Reason: {e}')


def resize_image(image_path, target_width, target_height):
    """Open *image_path* and return a copy resized to the given dimensions.

    The aspect ratio is not preserved. The source file handle is closed
    before returning; the resized image is an independent in-memory copy.
    """
    with Image.open(image_path) as source:
        return source.resize((target_width, target_height))


def _save_upload(uploaded_file):
    """Write the uploaded file under INPUT_DIR and return the saved path."""
    # The input directory may not exist on a fresh checkout.
    os.makedirs(INPUT_DIR, exist_ok=True)
    # Keep only the final path component so a crafted file name cannot
    # escape INPUT_DIR.
    safe_name = os.path.basename(uploaded_file.name)
    input_path = os.path.join(INPUT_DIR, safe_name)
    with open(input_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return input_path


def _identify_segmented_objects(identification_model, class_name):
    """Identify each file in SEGMENTED_DIR and return the flat detection list.

    Files are visited in sorted name order. After an object is identified,
    the description of its first detection is removed from *class_name*
    (mutated in place, as the caller's list) before the next file is
    processed.
    """
    detections = []
    for file in sorted(os.listdir(SEGMENTED_DIR)):
        object_path = os.path.join(SEGMENTED_DIR, file)
        obj_detections = identification_model.identify_objects(object_path, class_name)
        if not obj_detections:
            # Only keep results for objects that were identified.
            continue
        description = obj_detections[0]['description']
        # Guarded: an unconditional remove() raises ValueError and aborts the
        # whole run when the description is not (or no longer) in the list.
        if description in class_name:
            class_name.remove(description)
        detections.extend(obj_detections)
    return detections


def _show_centered_image(image_path, caption):
    """Render the image at *image_path*, resized, in the middle of three columns."""
    _, middle, _ = st.columns([0.3, 0.4, 0.3])
    with middle:
        resized_image = resize_image(image_path, IMAGE_WIDTH, IMAGE_HEIGHT)
        st.image(resized_image, caption=caption, use_column_width=True)


def main():
    """Run the Streamlit page: upload, process, persist and display results."""
    st.set_page_config(layout="wide")
    st.markdown(""" """, unsafe_allow_html=True)

    # Start every run from an empty segmented-objects folder so results from a
    # previous image are not picked up by the identification step.
    clear_segmented_objects_folder(SEGMENTED_DIR)

    st.title("Image Processing Pipeline 🤖")

    # File upload
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
    logging.debug(f"Uploaded file: {uploaded_file}")
    if uploaded_file is None:
        return

    # NOTE(review): Streamlit reruns this script on every widget interaction,
    # so the whole pipeline below is recomputed on each button click.

    # Save uploaded file
    input_path = _save_upload(uploaded_file)
    logging.debug(f"File saved to: {input_path}")
    image = Image.open(input_path)

    # Segmentation
    segmentation_model = load_segmentation_model()
    masks, boxes, labels, class_name = segmentation_model.segment_image(input_path)
    logging.debug(f"Segmentation results: {len(masks)} masks, {len(boxes)} boxes, {len(labels)} labels")

    # Save segmented objects
    objects = save_segmented_objects(image, masks, boxes, SEGMENTED_DIR)
    logging.debug(f"Saved {len(objects)} segmented objects")

    # Object identification
    identification_model = load_identification_model()
    detections = _identify_segmented_objects(identification_model, class_name)
    logging.debug(f"Detections: {len(detections)} objects identified")

    # Match detections to segmented objects.
    # NOTE(review): pairing is positional; since unidentified objects add no
    # entry to `detections`, later pairs may be misaligned — confirm against
    # what map_data expects before changing.
    object_descriptions = []
    for _, det in zip(objects, detections):
        if det:
            object_descriptions.append(
                f"This is a {det['description']} with confidence {det['probability']:.2f}"
            )
        else:
            object_descriptions.append("Unidentified object")
    logging.debug(f"Object description: {detections}")

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Save detections
    with open(f"{OUTPUT_DIR}/detections.json", "w") as f:
        json.dump(detections, f)
    logging.debug("Detections saved to data/output/detections.json")

    # Text extraction
    text_extraction_model = load_text_extraction_model()
    extracted_texts = [text_extraction_model.extract_text(obj[1]) for obj in objects]
    logging.debug(f"Extracted texts: {extracted_texts}")

    # Summarization
    summarization_model = load_summarization_model()
    summaries = [
        summarization_model.summarize(f"{desc} {text}")
        for desc, text in zip(object_descriptions, extracted_texts)
    ]
    logging.debug(f"Summaries: {summaries}")

    # Data mapping
    mapped_data = map_data(objects, detections, object_descriptions, extracted_texts, summaries)
    save_mapped_data(mapped_data, f"{OUTPUT_DIR}/mapped_data.json")

    # Visualization
    visualize_segmentation(image, masks, f"{OUTPUT_DIR}/segmented_image.png")
    visualize_detections(input_path, f"{OUTPUT_DIR}/detected_objects.png")
    create_summary_table(mapped_data, f"{OUTPUT_DIR}/summary_table.csv")

    # Initialize session state if not already done
    for _, state_key in VIEW_TOGGLES:
        if state_key not in st.session_state:
            st.session_state[state_key] = False

    # One toggle button per view; each click flips that view's flag.
    for column, (label, state_key) in zip(st.columns(len(VIEW_TOGGLES)), VIEW_TOGGLES):
        with column:
            if st.button(label):
                st.session_state[state_key] = not st.session_state[state_key]

    # Display components based on session state
    if st.session_state.show_original_image:
        _show_centered_image(input_path, "Original Image")

    if st.session_state.show_segmented_image:
        _show_centered_image(f"{OUTPUT_DIR}/segmented_image.png", "Segmented Image")

    if st.session_state.show_detected_objects:
        _show_centered_image(f"{OUTPUT_DIR}/detected_objects.png", "Detected Objects")

    if st.session_state.show_summary_table:
        _, middle, _ = st.columns([1, 3, 1])
        with middle:
            summary_table = pd.read_csv(f"{OUTPUT_DIR}/summary_table.csv")
            st.table(summary_table)


if __name__ == "__main__":
    main()