Spaces:
Sleeping
Sleeping
import shutil | |
import streamlit as st | |
import os | |
import sys | |
import pandas as pd | |
import json | |
from PIL import Image | |
import logging | |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
from models.segmentation_model import SegmentationModel | |
from models.identification_model import IdentificationModel | |
from models.text_extraction_model import TextExtractionModel | |
from models.summarization_model import SummarizationModel | |
from utils.postprocessing import save_segmented_objects | |
from utils.data_mapping import map_data, save_mapped_data | |
from utils.visualization import visualize_detections, visualize_segmentation, create_summary_table | |
# Set up logging | |
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') | |
def load_segmentation_model(): | |
return SegmentationModel() | |
def load_identification_model(): | |
return IdentificationModel() | |
def load_text_extraction_model(): | |
return TextExtractionModel() | |
def load_summarization_model(): | |
return SummarizationModel() | |
def main(): | |
st.set_page_config(layout="wide") | |
st.markdown(""" | |
<style> | |
.stImage > div { | |
margin-left: auto; | |
margin-right: auto; | |
} | |
.stTable > div { | |
margin-left: auto; | |
margin-right: auto; | |
} | |
h1{ /* Title style */ | |
text-align: center; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
def clear_segmented_objects_folder(folder_path): | |
# Remove all files in the segmented_objects folder | |
if os.path.exists(folder_path) and os.path.isdir(folder_path): | |
for filename in os.listdir(folder_path): | |
file_path = os.path.join(folder_path, filename) | |
try: | |
if os.path.isfile(file_path) or os.path.islink(file_path): | |
os.unlink(file_path) # Remove the file | |
elif os.path.isdir(file_path): | |
shutil.rmtree(file_path) # Remove the directory | |
except Exception as e: | |
st.error(f'Failed to delete {file_path}. Reason: {e}') | |
else: | |
print(f"Folder '{folder_path}' does not exist, skipping the clearing step.") | |
clear_segmented_objects_folder("data/segmented_objects") | |
st.title("Image Processing Pipeline 🤖") | |
# File upload | |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"]) | |
logging.debug(f"Uploaded file: {uploaded_file}") | |
if uploaded_file is not None: | |
# Save uploaded file | |
input_path = os.path.join("data", "input_images", uploaded_file.name) | |
with open(input_path, "wb") as f: | |
f.write(uploaded_file.getbuffer()) | |
logging.debug(f"File saved to: {input_path}") | |
image = Image.open(input_path) | |
# Segmentation | |
segmentation_model = load_segmentation_model() | |
masks, boxes, labels, class_name = segmentation_model.segment_image(input_path) | |
logging.debug(f"Segmentation results: {len(masks)} masks, {len(boxes)} boxes, {len(labels)} labels") | |
# Save segmented objects | |
objects = save_segmented_objects(image, masks, boxes, "data/segmented_objects") | |
logging.debug(f"Saved {len(objects)} segmented objects") | |
# Object identification | |
identification_model = load_identification_model() | |
detections = [] | |
for file in sorted(os.listdir("data/segmented_objects")): | |
f = os.path.join("data/segmented_objects", file) | |
obj_detections = identification_model.identify_objects(f, class_name) | |
if obj_detections: # Only append if the object was identified | |
class_name.remove(obj_detections[0]['description']) | |
detections.extend(obj_detections) | |
logging.debug(f"Detections: {len(detections)} objects identified") | |
# Match detections to segmented objects | |
object_descriptions = [] | |
for obj, det in zip(objects, detections): | |
if det: | |
object_descriptions.append(f"This is a {det['description']} with confidence {det['probability']:.2f}") | |
else: | |
object_descriptions.append("Unidentified object") | |
logging.debug(f"Object description: {detections}") | |
output_dir = "data/output" | |
if not os.path.exists(output_dir): | |
os.makedirs(output_dir) | |
# Save detections | |
with open("data/output/detections.json", "w") as f: | |
json.dump(detections, f) | |
logging.debug("Detections saved to data/output/detections.json") | |
# Text extraction | |
text_extraction_model = load_text_extraction_model() | |
extracted_texts = [text_extraction_model.extract_text(obj[1]) for obj in objects] | |
logging.debug(f"Extracted texts: {extracted_texts}") | |
# Summarization | |
summarization_model = load_summarization_model() | |
summaries = [summarization_model.summarize(f"{desc} {text}") for desc, text in zip(object_descriptions, extracted_texts)] | |
logging.debug(f"Summaries: {summaries}") | |
# Data mapping | |
mapped_data = map_data(objects, detections, object_descriptions, extracted_texts, summaries) | |
save_mapped_data(mapped_data, "data/output/mapped_data.json") | |
# Visualization | |
visualize_segmentation(image, masks, "data/output/segmented_image.png") | |
visualize_detections(input_path, "data/output/detected_objects.png") | |
create_summary_table(mapped_data, "data/output/summary_table.csv") | |
# Load the images and table | |
# Initialize session state if not already done | |
if 'show_original_image' not in st.session_state: | |
st.session_state.show_original_image = False | |
if 'show_segmented_image' not in st.session_state: | |
st.session_state.show_segmented_image = False | |
if 'show_detected_objects' not in st.session_state: | |
st.session_state.show_detected_objects = False | |
if 'show_summary_table' not in st.session_state: | |
st.session_state.show_summary_table = False | |
button_col1, button_col2, button_col3, button_col4 = st.columns(4) | |
with button_col1: | |
if st.button("Show Original Image"): | |
st.session_state.show_original_image = not st.session_state.show_original_image | |
with button_col2: | |
if st.button("Show Segmented Image"): | |
st.session_state.show_segmented_image = not st.session_state.show_segmented_image | |
with button_col3: | |
if st.button("Show Detected Objects"): | |
st.session_state.show_detected_objects = not st.session_state.show_detected_objects | |
with button_col4: | |
if st.button("Show Summary Table"): | |
st.session_state.show_summary_table = not st.session_state.show_summary_table | |
# Display components based on session state | |
def resize_image(image_path, target_width, target_height): | |
image = Image.open(image_path) | |
resized_image = image.resize((target_width, target_height)) | |
return resized_image | |
# Set desired width and height | |
IMAGE_WIDTH = 600 | |
IMAGE_HEIGHT = 400 | |
if st.session_state.show_original_image: | |
col1, col2, col3 = st.columns([0.3, 0.4, 0.3]) | |
with col2: | |
resized_image = resize_image(input_path, IMAGE_WIDTH, IMAGE_HEIGHT) | |
st.image(resized_image, caption="Original Image", use_column_width=True) | |
if st.session_state.show_segmented_image: | |
col1, col2, col3 = st.columns([0.3, 0.4, 0.3]) | |
with col2: | |
resized_image = resize_image("data/output/segmented_image.png", IMAGE_WIDTH, IMAGE_HEIGHT) | |
st.image(resized_image, caption="Segmented Image", use_column_width=True) | |
if st.session_state.show_detected_objects: | |
col1, col2, col3 = st.columns([0.3, 0.4, 0.3]) | |
with col2: | |
resized_image = resize_image("data/output/detected_objects.png", IMAGE_WIDTH, IMAGE_HEIGHT) | |
st.image(resized_image, caption="Detected Objects", use_column_width=True) | |
if st.session_state.show_summary_table: | |
col1, col2, col3 = st.columns([1, 3, 1]) | |
with col2: | |
summary_table = pd.read_csv("data/output/summary_table.csv") | |
st.table(summary_table) | |
if __name__ == "__main__": | |
main() | |