todap committed · Commit 26932f6 · verified · 1 Parent(s): 9cfa91c

Upload 6 files

app.py ADDED
@@ -0,0 +1,211 @@
+ import shutil
+ import streamlit as st
+ import os
+ import sys
+ import pandas as pd
+ import json
+ from PIL import Image
+ import logging
+
+
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+ from models.segmentation_model import SegmentationModel
+ from models.identification_model import IdentificationModel
+ from models.text_extraction_model import TextExtractionModel
+ from models.summarization_model import SummarizationModel
+ from utils.postprocessing import save_segmented_objects
+ from utils.data_mapping import map_data, save_mapped_data
+ from utils.visualization import visualize_detections, visualize_segmentation, create_summary_table
+
+ # Set up logging
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
+
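+ # Each loader is cached with @st.cache_resource, so the model weights are loaded
+ # once per Streamlit process and reused across reruns.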
+ @st.cache_resource
+ def load_segmentation_model():
+     return SegmentationModel()
+
+ @st.cache_resource
+ def load_identification_model():
+     return IdentificationModel()
+
+ @st.cache_resource
+ def load_text_extraction_model():
+     return TextExtractionModel()
+
+ @st.cache_resource
+ def load_summarization_model():
+     return SummarizationModel()
+
+ def main():
+     st.set_page_config(layout="wide")
+     st.markdown("""
+         <style>
+         .stImage > div {
+             margin-left: auto;
+             margin-right: auto;
+         }
+         .stTable > div {
+             margin-left: auto;
+             margin-right: auto;
+         }
+         h1 { /* Title style */
+             text-align: center;
+         }
+         </style>
+     """, unsafe_allow_html=True)
+
+     def clear_segmented_objects_folder(folder_path):
+         # Remove all files in the segmented_objects folder
+         if os.path.exists(folder_path) and os.path.isdir(folder_path):
+             for filename in os.listdir(folder_path):
+                 file_path = os.path.join(folder_path, filename)
+                 try:
+                     if os.path.isfile(file_path) or os.path.islink(file_path):
+                         os.unlink(file_path)  # Remove the file
+                     elif os.path.isdir(file_path):
+                         shutil.rmtree(file_path)  # Remove the directory
+                 except Exception as e:
+                     st.error(f'Failed to delete {file_path}. Reason: {e}')
+         else:
+             print(f"Folder '{folder_path}' does not exist, skipping the clearing step.")
+
+     clear_segmented_objects_folder("data/segmented_objects")
+
+     st.title("Image Processing Pipeline 🤖")
+
+     # File upload
+     uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
+     logging.debug(f"Uploaded file: {uploaded_file}")
+
+     if uploaded_file is not None:
+         # Save uploaded file (ensure the target folder exists first)
+         os.makedirs(os.path.join("data", "input_images"), exist_ok=True)
+         input_path = os.path.join("data", "input_images", uploaded_file.name)
+         with open(input_path, "wb") as f:
+             f.write(uploaded_file.getbuffer())
+         logging.debug(f"File saved to: {input_path}")
+
+         image = Image.open(input_path)
+
+         # Segmentation
+         segmentation_model = load_segmentation_model()
+         masks, boxes, labels, class_name = segmentation_model.segment_image(input_path)
+         logging.debug(f"Segmentation results: {len(masks)} masks, {len(boxes)} boxes, {len(labels)} labels")
+
+         # Save segmented objects
+         objects = save_segmented_objects(image, masks, boxes, "data/segmented_objects")
+         logging.debug(f"Saved {len(objects)} segmented objects")
+
+         # Object identification
+         identification_model = load_identification_model()
+         detections = []
+         # Process the crops in the same order as `objects` so the detections stay
+         # positionally aligned with the segmented objects below.
+         for _, file_path, _ in objects:
+             obj_detections = identification_model.identify_objects(file_path, class_name)
+             if obj_detections:  # Only append if the object was identified
+                 if obj_detections[0]['description'] in class_name:
+                     class_name.remove(obj_detections[0]['description'])
+                 detections.extend(obj_detections)
+             else:
+                 detections.append(None)  # Placeholder keeps objects and detections aligned
+         logging.debug(f"Detections: {len(detections)} objects identified")
+
+         # Match detections to segmented objects
+         object_descriptions = []
+         for obj, det in zip(objects, detections):
+             if det:
+                 object_descriptions.append(f"This is a {det['description']} with confidence {det['probability']:.2f}")
+             else:
+                 object_descriptions.append("Unidentified object")
+         logging.debug(f"Object descriptions: {object_descriptions}")
+
+         output_dir = "data/output"
+         if not os.path.exists(output_dir):
+             os.makedirs(output_dir)
+         # Save detections
+         with open("data/output/detections.json", "w") as f:
+             json.dump(detections, f)
+         logging.debug("Detections saved to data/output/detections.json")
+
+         # Text extraction
+         text_extraction_model = load_text_extraction_model()
+         extracted_texts = [text_extraction_model.extract_text(obj[1]) for obj in objects]  # obj[1] is the crop's file path
+         logging.debug(f"Extracted texts: {extracted_texts}")
+
+         # Summarization
+         summarization_model = load_summarization_model()
+         summaries = [summarization_model.summarize(f"{desc} {text}") for desc, text in zip(object_descriptions, extracted_texts)]
+         logging.debug(f"Summaries: {summaries}")
+
+         # Data mapping
+         mapped_data = map_data(objects, detections, object_descriptions, extracted_texts, summaries)
+         save_mapped_data(mapped_data, "data/output/mapped_data.json")
+
+         # Visualization
+         visualize_segmentation(image, masks, "data/output/segmented_image.png")
+         visualize_detections(input_path, "data/output/detected_objects.png")
+         create_summary_table(mapped_data, "data/output/summary_table.csv")
+
+         # Display the processed images and the summary table on demand
+
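+         # Each button below flips a boolean flag in st.session_state, so the chosen
+         # panels stay open or closed across Streamlit reruns.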
+         # Initialize session state if not already done
+         if 'show_original_image' not in st.session_state:
+             st.session_state.show_original_image = False
+         if 'show_segmented_image' not in st.session_state:
+             st.session_state.show_segmented_image = False
+         if 'show_detected_objects' not in st.session_state:
+             st.session_state.show_detected_objects = False
+         if 'show_summary_table' not in st.session_state:
+             st.session_state.show_summary_table = False
+
+         button_col1, button_col2, button_col3, button_col4 = st.columns(4)
+
+         with button_col1:
+             if st.button("Show Original Image"):
+                 st.session_state.show_original_image = not st.session_state.show_original_image
+
+         with button_col2:
+             if st.button("Show Segmented Image"):
+                 st.session_state.show_segmented_image = not st.session_state.show_segmented_image
+
+         with button_col3:
+             if st.button("Show Detected Objects"):
+                 st.session_state.show_detected_objects = not st.session_state.show_detected_objects
+
+         with button_col4:
+             if st.button("Show Summary Table"):
+                 st.session_state.show_summary_table = not st.session_state.show_summary_table
+
+         # Display components based on session state
+         def resize_image(image_path, target_width, target_height):
+             image = Image.open(image_path)
+             resized_image = image.resize((target_width, target_height))
+             return resized_image
+
+         # Set desired width and height
+         IMAGE_WIDTH = 600
+         IMAGE_HEIGHT = 400
+
+         if st.session_state.show_original_image:
+             col1, col2, col3 = st.columns([0.3, 0.4, 0.3])
+             with col2:
+                 resized_image = resize_image(input_path, IMAGE_WIDTH, IMAGE_HEIGHT)
+                 st.image(resized_image, caption="Original Image", use_column_width=True)
+
+         if st.session_state.show_segmented_image:
+             col1, col2, col3 = st.columns([0.3, 0.4, 0.3])
+             with col2:
+                 resized_image = resize_image("data/output/segmented_image.png", IMAGE_WIDTH, IMAGE_HEIGHT)
+                 st.image(resized_image, caption="Segmented Image", use_column_width=True)
+
+         if st.session_state.show_detected_objects:
+             col1, col2, col3 = st.columns([0.3, 0.4, 0.3])
+             with col2:
+                 resized_image = resize_image("data/output/detected_objects.png", IMAGE_WIDTH, IMAGE_HEIGHT)
+                 st.image(resized_image, caption="Detected Objects", use_column_width=True)
+
+         if st.session_state.show_summary_table:
+             col1, col2, col3 = st.columns([1, 3, 1])
+             with col2:
+                 summary_table = pd.read_csv("data/output/summary_table.csv")
+                 st.table(summary_table)
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch
+ torchvision
+ clip
+ easyocr
+ transformers
+ matplotlib
+ pandas
+ streamlit
+ Pillow
+ ultralytics
+ opencv-python-headless
utils/.DS_Store ADDED
Binary file (6.15 kB).
 
utils/data_mapping.py ADDED
@@ -0,0 +1,18 @@
+ import json
+
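+ # map_data bundles the parallel per-object lists into one dict keyed by object id.
+ # The raw detection entries are used only for alignment; the stored fields are the
+ # crop path, bounding box, description, extracted text and summary.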
+ def map_data(objects, detections, descriptions, extracted_texts, summaries):
+     mapped_data = {}
+     for (obj_id, file_path, box), det, description, text, summary in zip(objects, detections, descriptions, extracted_texts, summaries):
+         mapped_data[obj_id] = {
+             "file_path": file_path,
+             "box": box,
+             "description": description,
+             "extracted_text": text,
+             "summary": summary
+         }
+
+     return mapped_data
+
+ def save_mapped_data(mapped_data, output_file):
+     with open(output_file, "w") as f:
+         json.dump(mapped_data, f, indent=2)
utils/postprocessing.py ADDED
@@ -0,0 +1,12 @@
+ import os
+ from PIL import Image
+
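+ # Note: the masks are accepted alongside the boxes, but the crops below are cut
+ # from the bounding boxes only; the mask pixels themselves are not applied.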
+ def save_segmented_objects(image, masks, boxes, output_dir):
+     os.makedirs(output_dir, exist_ok=True)
+     objects = []
+     for i, (mask, box) in enumerate(zip(masks, boxes)):
+         obj_image = image.crop(box)
+         file_path = os.path.join(output_dir, f"object_{i}.png")
+         obj_image.save(file_path)
+         objects.append((f"object_{i}", file_path, box.tolist()))
+     return objects
utils/visualization.py ADDED
@@ -0,0 +1,53 @@
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import cv2
+ from ultralytics import YOLO
+ from PIL import Image
+
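+ # Note: this helper runs a fresh YOLOv8 pass over the original image to draw
+ # labelled boxes; it is independent of the segmentation masks used elsewhere.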
+ def visualize_detections(image_path, output_path):
+
+     model = YOLO('yolov8s.pt')  # You can change this to other YOLOv8 models as needed
+     # Read the image
+     image = cv2.imread(image_path)
+
+     # Run YOLOv8 inference on the image
+     results = model(image)
+
+     # Process the results and draw bounding boxes
+     for result in results:
+         boxes = result.boxes.cpu().numpy()
+         for box in boxes:
+             x1, y1, x2, y2 = map(int, box.xyxy[0])
+             confidence = float(box.conf[0])
+             class_id = int(box.cls[0])
+             class_name = model.names[class_id]
+
+             # Draw bounding box
+             cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
+
+             # Prepare label
+             label = f"{class_name}"
+
+             # Get label size
+             (label_width, label_height), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+
+             # Draw filled rectangle for label background
+             cv2.rectangle(image, (x1, y1 - label_height - baseline), (x1 + label_width, y1), (0, 255, 0), cv2.FILLED)
+
+             # Put label text
+             cv2.putText(image, label, (x1, y1 - baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
+
+     # Save the output image
+     cv2.imwrite(output_path, image)
+
+ def visualize_segmentation(image, masks, output_file):
43
+ #plt.imshow(image)
44
+ for mask in masks:
45
+ plt.imshow(mask, alpha=0.5)
46
+ plt.axis('off')
47
+ plt.savefig(output_file,bbox_inches='tight', pad_inches=0)
48
+ plt.close()
49
+
50
+
51
+ def create_summary_table(mapped_data, output_file):
52
+ df = pd.DataFrame.from_dict(mapped_data, orient='index')
53
+ df.to_csv(output_file)