File size: 13,624 Bytes
2c63a77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import streamlit as st
import torch
import torchvision.transforms as T
import numpy as np
import cv2
from PIL import Image, UnidentifiedImageError
import os
import subprocess # For Detectron2 installation check
import sys # For Detectron2 installation check

# --- Detectron2 Imports (handle potential import errors) ---
d2_imported_successfully = False
try:
    import detectron2
    from detectron2.engine import DefaultPredictor
    from detectron2.config import get_cfg
    from detectron2 import model_zoo
    from detectron2.utils.visualizer import Visualizer, ColorMode
    from detectron2.data import MetadataCatalog
    from detectron2.structures import Boxes # For Bounding Boxes
    d2_imported_successfully = True
    print("Detectron2 utilities imported successfully.")
except ImportError:
    st.error("Detectron2 not found or not installed correctly. Please ensure it's installed in your environment.")
    print("❌ Failed to import Detectron2 utilities.")
except Exception as e:
    st.error(f"An error occurred during Detectron2 imports: {e}")
    print(f"❌ An error occurred during Detectron2 imports: {e}")


# --- PyTorch Model Imports ---
from torchvision import models as torchvision_models
import torch.nn as nn

# ------------------------------
# Configuration
# ------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CNN_INPUT_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Ensure this path is correct for your environment
MODEL_PATH = r"pix3d_dimension_estimator_mask_crop.pth"
OUTPUT_DIR = 'streamlit_d2_output'
os.makedirs(OUTPUT_DIR, exist_ok=True)


# ------------------------------
# Dimension Estimation CNN
# ------------------------------
def create_dimension_estimator_cnn_for_inference(num_outputs=4):
    model = torchvision_models.resnet50(weights=None) # Load architecture only
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_ftrs, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_outputs) # Outputs L, W, H, V
    )
    return model

@st.cache_resource
def load_dimension_model():
    model = None
    if not os.path.exists(MODEL_PATH):
        st.error(f"Dimension estimation model not found at {MODEL_PATH}. Please check the path.")
        return None
    try:
        model = create_dimension_estimator_cnn_for_inference()
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
        model.to(DEVICE)
        model.eval()
        print(f"Dimension estimation model loaded from {MODEL_PATH}")
    except Exception as e:
        st.error(f"Error loading dimension estimation model: {e}")
        return None
    return model

# ------------------------------
# Detectron2 Model
# ------------------------------
@st.cache_resource
def load_detectron2_model():
    if not d2_imported_successfully:
        return None, None
    try:
        cfg = get_cfg()
        # Example: Mask R-CNN for instance segmentation
        cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # Set threshold for detection
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
        if not torch.cuda.is_available():
            cfg.MODEL.DEVICE = "cpu"
        else:
            cfg.MODEL.DEVICE = "cuda" # Explicitly set
        predictor = DefaultPredictor(cfg)
        print("Detectron2 predictor created.")
        return predictor, cfg
    except Exception as e:
        st.error(f"Error loading Detectron2 model: {e}")
        return None, None

# ------------------------------
# Helper Functions
# ------------------------------
def get_largest_instance_index(instances):
    """Returns the index of the largest instance based on mask area or box area."""
    if not len(instances):
        return -1 # No instances

    if instances.has("pred_masks"):
        areas = instances.pred_masks.sum(dim=(1,2)) # Sum of True pixels in boolean mask
        if len(areas) > 0:
            return areas.argmax().item()
    elif instances.has("pred_boxes"):
        boxes_tensor = instances.pred_boxes.tensor
        areas = (boxes_tensor[:, 2] - boxes_tensor[:, 0]) * (boxes_tensor[:, 3] - boxes_tensor[:, 1])
        if len(areas) > 0:
            return areas.argmax().item()
    return 0 # Default to first instance if area calculation fails or no masks/boxes

def crop_from_mask(image_np_rgb, mask_tensor):
    """Crops an object from an image using a boolean mask tensor."""
    mask_np = mask_tensor.cpu().numpy().astype(np.uint8) # Ensure mask is on CPU and uint8
    if mask_np.sum() == 0: return None # Empty mask

    rows = np.any(mask_np, axis=1)
    cols = np.any(mask_np, axis=0)
    if not np.any(rows) or not np.any(cols): return None

    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]

    # Add padding to the bounding box from mask
    padding = 5
    ymin = max(0, ymin - padding)
    xmin = max(0, xmin - padding)
    ymax = min(image_np_rgb.shape[0] - 1, ymax + padding)
    xmax = min(image_np_rgb.shape[1] - 1, xmax + padding)


    if ymin >= ymax or xmin >= xmax : return None
    cropped_image = image_np_rgb[ymin:ymax+1, xmin:xmax+1, :]
    return cropped_image

def predict_dimensions_cnn(image_patch_np_rgb, model):
    """Predicts dimensions using the custom CNN."""
    if model is None:
        return {"L": "N/A", "W": "N/A", "H": "N/A", "V": "N/A", "Note": "DimCNN not loaded"}
    try:
        if image_patch_np_rgb.dtype != np.uint8:
            image_patch_np_rgb = image_patch_np_rgb.astype(np.uint8)

        transform = T.Compose([
            T.ToPILImage(),
            T.Resize((CNN_INPUT_SIZE, CNN_INPUT_SIZE)),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
        ])
        input_tensor = transform(image_patch_np_rgb).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            pred = model(input_tensor)
        dims = pred.squeeze().cpu().tolist()

        if not isinstance(dims, list): dims = [dims]
        while len(dims) < 4: dims.append(0.0) # Pad if model outputs fewer

        # Assuming model was trained to output in meters, convert to cm for display
        L_cm = dims[0] * 100
        W_cm = dims[1] * 100
        H_cm = dims[2] * 100
        V_cm3 = dims[3] * 1_000_000 # Convert m^3 to cm^3

        return {
            "Length (cm)": f"{L_cm:.1f}",
            "Width (cm)": f"{W_cm:.1f}",
            "Height (cm)": f"{H_cm:.1f}",
            "Volume (cm³)": f"{V_cm3:.1f}",
            "Note": "CustomCNN (Pix3D Scale)"
        }
    except Exception as e:
        print(f"Error in predict_dimensions_cnn: {e}")
        return {"L": "N/A", "W": "N/A", "H": "N/A", "V": "N/A", "Note": "CNN Predict Error"}

# ------------------------------
# Streamlit UI
# ------------------------------
st.set_page_config(layout="wide", page_title="Object Dimension Estimator")
st.title("📦 Object Dimension & Volume Estimation")
st.write("Upload an image. The system will detect objects using Detectron2, draw bounding boxes and masks, and estimate dimensions for the largest detected object using a custom-trained CNN.")

# Load models
dim_model = load_dimension_model()
if d2_imported_successfully:
    d2_predictor, d2_cfg = load_detectron2_model()
    if d2_cfg is not None:
        # Attempt to get metadata, handle potential KeyErrors
        try:
            d2_metadata = MetadataCatalog.get(d2_cfg.DATASETS.TRAIN[0] if d2_cfg.DATASETS.TRAIN else "coco_2017_val")
        except KeyError:
            st.warning("Default COCO metadata not found. Trying 'coco_2017_train'. Class names might be generic if this also fails.")
            try:
                d2_metadata = MetadataCatalog.get("coco_2017_train")
            except KeyError:
                st.warning("Could not load standard COCO metadata. Using dummy metadata.")
                dummy_name = "streamlit_dummy_coco_dataset_main"
                if dummy_name not in MetadataCatalog.list():
                    MetadataCatalog.get(dummy_name).thing_classes = [f"class_{i}" for i in range(80)] # COCO has 80 classes
                d2_metadata = MetadataCatalog.get(dummy_name)
    else:
        d2_metadata = None # Set to None if cfg is None
else:
    d2_predictor = None
    d2_cfg = None
    d2_metadata = None


uploaded_file = st.file_uploader("Upload a single image (JPG/PNG)", accept_multiple_files=False, type=['jpg', 'jpeg', 'png'])

if uploaded_file:
    st.subheader(f"🖼️ Processing: {uploaded_file.name}")
    try:
        image_pil = Image.open(uploaded_file).convert("RGB")
        image_np_rgb = np.array(image_pil) # Convert PIL to NumPy RGB
        image_bgr = cv2.cvtColor(image_np_rgb, cv2.COLOR_RGB2BGR) # OpenCV BGR for Detectron2
    except UnidentifiedImageError:
        st.error("Cannot identify image file. Please upload a valid image.")
        image_bgr = None
    except Exception as e:
        st.error(f"Error loading image: {e}")
        image_bgr = None

    if image_bgr is not None:
        st.image(image_pil, caption="Uploaded Image", use_container_width=True)

        if d2_predictor is None or dim_model is None:
            st.error("One or more models (Detectron2, Dimension CNN) failed to load. Cannot process.")
        else:
            with st.spinner("Detecting objects and estimating dimensions..."):
                # --- Detectron2 Inference ---
                outputs = d2_predictor(image_bgr) # Detectron2 expects BGR
                instances = outputs["instances"].to("cpu")

                if len(instances) == 0:
                    st.warning("No objects detected by Detectron2.")
                else:
                    # --- Visualization with Bounding Boxes and Masks ---
                    # Create a copy for drawing Detectron2's full visualization
                    viz_image_bgr = image_bgr.copy()
                    v = Visualizer(viz_image_bgr[:, :, ::-1], metadata=d2_metadata, scale=0.8, instance_mode=ColorMode.IMAGE_BW)
                    out_vis = v.draw_instance_predictions(instances)
                    annotated_img_d2_rgb = out_vis.get_image()[:, :, ::-1] # Visualizer gives RGB

                    st.image(annotated_img_d2_rgb, caption="Detectron2 Detections (Masks & Boxes)", use_container_width=True)

                    # --- Process the largest detected instance for dimension estimation ---
                    largest_idx = get_largest_instance_index(instances)
                    if largest_idx != -1:
                        instance = instances[largest_idx]

                        class_name = "Unknown"
                        if instance.has("pred_classes") and d2_metadata and hasattr(d2_metadata, 'thing_classes'):
                            class_id = instance.pred_classes.item()
                            if class_id < len(d2_metadata.thing_classes):
                                class_name = d2_metadata.thing_classes[class_id]
                        score = instance.scores.item() if instance.has("scores") else 0.0
                        st.write(f"**Processing largest detected object:** {class_name} (Confidence: {score:.2f})")

                        # --- Crop from Mask for Custom CNN ---
                        if instance.has("pred_masks"):
                            mask_tensor = instance.pred_masks[0] # Get the mask for the largest instance
                            # Crop from the RGB numpy array
                            object_crop_rgb = crop_from_mask(image_np_rgb, mask_tensor)

                            if object_crop_rgb is not None and object_crop_rgb.shape[0] > 0 and object_crop_rgb.shape[1] > 0:
                                st.image(object_crop_rgb, caption="Cropped Object Patch for Dimension CNN", width=250) # Smaller display

                                # --- Predict Dimensions with Custom CNN ---
                                dims = predict_dimensions_cnn(object_crop_rgb, dim_model)
                                st.write("📏 **Predicted Dimensions (from Custom CNN):**")
                                st.json(dims)
                            else:
                                st.error("Could not crop a valid object patch from the mask.")
                        else:
                            st.warning("No segmentation mask found for the largest instance. Cannot estimate dimensions with custom CNN.")
                    else:
                        st.info("Could not determine the largest object to process for dimensions.")
    else:
        if uploaded_file: # If a file was uploaded but image_bgr is None
             st.error("Image could not be loaded for processing.")


# --- Status Footer ---
st.sidebar.markdown("---")
st.sidebar.subheader("ℹ️ System Status")
st.sidebar.markdown(f"**Processing Device:** `{DEVICE}`")
st.sidebar.markdown(f"**Detectron2 Predictor:** `{'Loaded' if d2_predictor else 'Not Loaded'}`")
st.sidebar.markdown(f"**Dimension CNN:** `{'Loaded' if dim_model else 'Not Loaded'}`")
if not os.path.exists(MODEL_PATH):
    st.sidebar.warning(f"Dimension CNN weights file not found at the specified path.")