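"""Gradio front-end for an AI object size estimator.

Pipeline sketch: a monocular depth model (Depth Anything V2) supplies a depth
map, SAM supplies object masks, and a user-provided reference size corrects
the global scale of all measurements.
"""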
import gradio as gr
import numpy as np
import cv2
import torch
import pathlib
import sys
from PIL import Image
from PIL.ExifTags import TAGS
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
import warnings
warnings.filterwarnings('ignore')

# Add the agent module to path
ROOT = pathlib.Path(__file__).resolve().parent
sys.path.insert(0, str(ROOT / "goal2" / "src"))
from agent import models, geometry

# Device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Camera presets for common devices
CAMERA_PRESETS = {
    "iPhone 12/13/14 (Main Camera)": {"fx": 1840, "fy": 1840, "description": "26mm equivalent, f/1.6"},
    "iPhone 12/13/14 (Ultra Wide)": {"fx": 920, "fy": 920, "description": "13mm equivalent, f/2.4"},
    "Samsung Galaxy S21/S22": {"fx": 1950, "fy": 1950, "description": "26mm equivalent"},
    "Google Pixel 6/7": {"fx": 1800, "fy": 1800, "description": "27mm equivalent"},
    "Generic Smartphone": {"fx": 1500, "fy": 1500, "description": "Typical smartphone camera"},
    "Custom": {"fx": 1500, "fy": 1500, "description": "Enter your own focal length values"}
}
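
# Note: fx/fy above are focal lengths expressed in pixels, not millimetres.
# As a rough conversion sketch (assuming a 35mm-equivalent focal length is
# known, e.g. from EXIF, and using the full-frame reference width of 36 mm):
#
#     fx ~= (focal_length_35mm_equiv / 36.0) * image_width_px
#
# e.g. a 26 mm-equivalent lens on a ~2500 px-wide photo gives fx ~= 1800.
# The presets are ballpark values, not calibrated intrinsics.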

class SizeEstimatorApp:
    def __init__(self):
        self.depth_net = None
        self.mask_gen = None
        self.current_image = None
        self.current_depth = None
        self.current_masks = None
        self.reference_object = None
        
    def detect_camera_from_exif(self, image_pil: Image.Image) -> Tuple[str, Dict]:
        """Try to detect camera type from EXIF data"""
        try:
            # _getexif() is Pillow's private decoded-EXIF helper; it returns
            # None (or is absent on some formats) when there is no EXIF block
            exif = image_pil._getexif()
            if not exif:
                return "Unknown", {}
            
            # Extract relevant EXIF data
            exif_data = {}
            for tag_id, value in exif.items():
                tag = TAGS.get(tag_id, tag_id)
                exif_data[tag] = value
            
            # Try to identify the camera make/model (coerced to str: some
            # EXIF writers store these fields as bytes)
            make = str(exif_data.get('Make', '')).lower()
            model = str(exif_data.get('Model', '')).lower()
            
            # Match against known camera presets
            if 'apple' in make or 'iphone' in model:
                if any(x in model for x in ['12', '13', '14']):
                    return "iPhone 12/13/14 (Main Camera)", exif_data
                else:
                    return "Generic Smartphone", exif_data
            elif 'samsung' in make:
                return "Samsung Galaxy S21/S22", exif_data
            elif 'google' in make or 'pixel' in model:
                return "Google Pixel 6/7", exif_data
            else:
                return "Generic Smartphone", exif_data
                
        except Exception as e:
            print(f"EXIF detection failed: {e}")
            return "Unknown", {}
        
    def load_models(self):
        """Load the depth and segmentation models"""
        if self.depth_net is None:
            print("Loading Depth Anything V2...")
            self.depth_net = models.load_depth(DEVICE)
        if self.mask_gen is None:
            print("Loading SAM...")
            self.mask_gen = models.load_sam(DEVICE)
        return "βœ… Models loaded successfully!"
    
    def process_image(self, image: np.ndarray, camera_preset: str, fx_custom: float, fy_custom: float) -> Tuple[np.ndarray, str]:
        """Process uploaded image and generate depth + segmentation"""
        try:
            # Input validation
            if image is None:
                return None, "❌ No image provided. Please upload an image."
            
            if len(image.shape) != 3 or image.shape[2] != 3:
                return None, "❌ Invalid image format. Please upload a color image (RGB)."
            
            # Check image size constraints
            h, w = image.shape[:2]
            if h < 100 or w < 100:
                return None, "❌ Image too small. Please upload an image at least 100x100 pixels."
            
            if h > 4000 or w > 4000:
                status_msg = "⚠️ Large image detected. Resizing for processing...\n"
                # Resize very large images
                max_size = 2000
                scale = min(max_size/w, max_size/h)
                if scale < 1:
                    new_w, new_h = int(w * scale), int(h * scale)
                    image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
                    status_msg += f"πŸ“ Resized from {w}×{h} to {new_w}×{new_h}\n"
            else:
                status_msg = ""
            
            # Ensure models are loaded
            if self.depth_net is None or self.mask_gen is None:
                self.load_models()
            
            # Store the original image
            self.current_image = image.copy()
            
            # Validate camera parameters
            if camera_preset == "Custom":
                if fx_custom <= 0 or fy_custom <= 0:
                    return None, "❌ Invalid focal length values. Must be greater than 0."
                if fx_custom < 100 or fy_custom < 100 or fx_custom > 5000 or fy_custom > 5000:
                    return None, "❌ Focal length values seem unrealistic. Typical range: 100-5000 pixels."
                fx, fy = fx_custom, fy_custom
            else:
                preset = CAMERA_PRESETS[camera_preset]
                fx, fy = preset["fx"], preset["fy"]
            
            # Generate depth and masks using the robust approach
            depth, masks, processed_img = models.predict_depth_and_masks(
                self.depth_net, self.mask_gen, image, DEVICE, approach="aligned"
            )
            
            # Validate results
            if depth is None or len(depth.shape) != 2:
                return None, "❌ Failed to generate depth map. Please try a different image."
            
            if not masks or len(masks) == 0:
                return None, "❌ No objects detected in the image. Try an image with clearer objects."
            
            # Filter out very small masks (likely noise)
            min_area = (image.shape[0] * image.shape[1]) * 0.001  # 0.1% of image area
            filtered_masks = [m for m in masks if m['area'] > min_area]
            
            if len(filtered_masks) == 0:
                return None, "❌ No significant objects detected. Try an image with larger, clearer objects."
            
            self.current_depth = depth
            self.current_masks = filtered_masks
            
            # Create visualization
            vis_image = self.create_mask_visualization(processed_img, filtered_masks)
            
            status = status_msg + f"✅ Processed successfully! Found {len(filtered_masks)} objects.\n"
            status += f"πŸ“· Camera: {camera_preset} (fx={fx:.0f}, fy={fy:.0f})\n"
            status += f"πŸ–ΌοΈ Image size: {image.shape[1]}×{image.shape[0]}\n"
            if len(masks) > len(filtered_masks):
                status += f"πŸ” Filtered out {len(masks) - len(filtered_masks)} small objects\n"
            status += "πŸ“ Ready for size estimation - select an object number and known size below"
            
            return vis_image, status
            
        except Exception as e:
            import traceback
            error_details = traceback.format_exc()
            print("Full error:", error_details)  # For debugging
            return None, f"❌ Error processing image: {str(e)}\nPlease try a different image."
    
    def create_mask_visualization(self, image: np.ndarray, masks: List[Dict]) -> np.ndarray:
        """Create visualization with colored masks and labels"""
        vis_img = image.copy()
        
        # Sort masks by area (largest first)
        sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)
        
        # Color each mask with different colors
        colors = plt.cm.Set3(np.linspace(0, 1, len(sorted_masks)))
        
        for i, mask_data in enumerate(sorted_masks):
            mask = mask_data['segmentation']
            color = colors[i][:3]  # RGB values
            
            # Blend the color into the masked pixels only; blending the whole
            # frame would progressively darken unmasked regions, once per mask
            colored_mask = np.zeros_like(vis_img)
            colored_mask[mask] = [int(c * 255) for c in color]
            blended = cv2.addWeighted(vis_img, 0.7, colored_mask, 0.3, 0)
            vis_img[mask] = blended[mask]
            
            # Add number label
            y, x = np.where(mask)
            if len(x) > 0 and len(y) > 0:
                center_x, center_y = int(np.mean(x)), int(np.mean(y))
                cv2.putText(vis_img, str(i+1), (center_x-10, center_y+5), 
                           cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
                cv2.putText(vis_img, str(i+1), (center_x-10, center_y+5), 
                           cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 1)
        
        return vis_img
    
    def select_reference_object(self, mask_number: int, reference_size_cm: float, dimension: str) -> str:
        """Select a mask as reference object and specify its known size"""
        try:
            if self.current_masks is None:
                return "❌ No image processed yet. Please upload and process an image first."
            
            if mask_number < 1 or mask_number > len(self.current_masks):
                return f"❌ Invalid mask number. Choose between 1 and {len(self.current_masks)}"
            
            if reference_size_cm <= 0:
                return "❌ Reference size must be greater than 0"
            
            # Get the selected mask (convert to 0-based index)
            sorted_masks = sorted(self.current_masks, key=lambda x: x['area'], reverse=True)
            selected_mask = sorted_masks[mask_number - 1]
            
            # Store reference object info
            self.reference_object = {
                'mask_data': selected_mask,
                'known_size_cm': reference_size_cm,
                'dimension': dimension  # 'width' or 'height'
            }
            
            return f"βœ… Reference object #{mask_number} selected!\nπŸ“ Known {dimension}: {reference_size_cm} cm"
            
        except Exception as e:
            return f"❌ Error selecting reference: {str(e)}"
    
    def calculate_all_sizes(self, camera_preset: str, fx_custom: float, fy_custom: float) -> str:
        """Calculate sizes of all objects using the reference object for scale"""
        try:
            if self.current_masks is None:
                return "❌ No image processed yet."
            
            if self.reference_object is None:
                return "❌ No reference object selected. Please select a reference object first."
            
            # Get camera parameters
            if camera_preset == "Custom":
                fx, fy = fx_custom, fy_custom
            else:
                preset = CAMERA_PRESETS[camera_preset]
                fx, fy = preset["fx"], preset["fy"]
            
            # Measure the reference object from the depth map first
            ref_mask = self.reference_object['mask_data']['segmentation']
            ref_stats = geometry.pixel_to_metric(ref_mask, self.current_depth, fx, fy)
            
            # Get the reference object's measured dimension (meters -> cm);
            # this is a depth-based metric estimate, not a pixel count
            if self.reference_object['dimension'] == 'width':
                ref_measured_cm = ref_stats['width_m'] * 100
            else:  # height
                ref_measured_cm = ref_stats['height_m'] * 100
            
            if ref_measured_cm <= 0:
                return "❌ Reference measurement is non-positive. Try a different reference object."
            
            # Scale factor: known size / depth-based measured size
            scale_factor = self.reference_object['known_size_cm'] / ref_measured_cm
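            # A single multiplicative factor suffices under the assumption
            # that the depth map is correct up to one global scale: every
            # metric quantity derived from it (width, height, distance) is
            # then off by the same constant, which the known size pins down.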
            
            # Calculate sizes for all objects
            results = []
            sorted_masks = sorted(self.current_masks, key=lambda x: x['area'], reverse=True)
            
            for i, mask_data in enumerate(sorted_masks):
                mask = mask_data['segmentation']
                stats = geometry.pixel_to_metric(mask, self.current_depth, fx, fy)
                
                # Apply scale correction
                corrected_width = stats['width_m'] * 100 * scale_factor  # cm
                corrected_height = stats['height_m'] * 100 * scale_factor  # cm
                corrected_distance = stats['distance_m'] * scale_factor  # meters
                
                # Check if this is the reference object by comparing mask data
                is_reference = np.array_equal(mask_data['segmentation'], self.reference_object['mask_data']['segmentation'])
                ref_marker = " (REFERENCE)" if is_reference else ""
                
                results.append(f"Object #{i+1}{ref_marker}:")
                results.append(f"  πŸ“ Width: {corrected_width:.1f} cm")
                results.append(f"  πŸ“ Height: {corrected_height:.1f} cm") 
                results.append(f"  πŸ“ Distance: {corrected_distance:.2f} m")
                results.append(f"  πŸ“ Area: {mask_data['area']} pixels")
                results.append("")
            
            # Find reference object number for display
            ref_object_num = None
            for i, mask_data in enumerate(sorted_masks):
                if np.array_equal(mask_data['segmentation'], self.reference_object['mask_data']['segmentation']):
                    ref_object_num = i + 1
                    break
            
            # Add calibration info
            results.append("=" * 40)
            results.append("πŸ“Š Calibration Info:")
            results.append(f"πŸ“· Camera: {camera_preset}")
            results.append(f"πŸ” Scale factor: {scale_factor:.3f}")
            results.append(f"πŸ“ Reference: Object #{ref_object_num if ref_object_num else 'Unknown'}")
            results.append(f"πŸ“ Known {self.reference_object['dimension']}: {self.reference_object['known_size_cm']} cm")
            
            return "\n".join(results)
            
        except Exception as e:
            return f"❌ Error calculating sizes: {str(e)}"

# Initialize the app
app = SizeEstimatorApp()

# Gradio interface
def create_interface():
    with gr.Blocks(title="πŸ“ Smart Object Size Estimator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # πŸ“ Smart Object Size Estimator
        
        Upload an image and get real-world size measurements of objects using AI-powered depth estimation and segmentation.
        
        ## How to use:
        1. **Upload an image** and select your camera type
        2. **Click Process** to detect objects
        3. **Select a reference object** by entering its number and known size
        4. **Calculate sizes** to get measurements of all objects
        """)
        
        with gr.Row():
            with gr.Column(scale=1):
                # Input section
                gr.Markdown("### πŸ“€ Input")
                image_input = gr.Image(type="numpy", label="Upload Image")
                
                # Camera settings
                gr.Markdown("### πŸ“· Camera Settings")
                camera_preset = gr.Dropdown(
                    choices=list(CAMERA_PRESETS.keys()),
                    value="iPhone 12/13/14 (Main Camera)",
                    label="Camera Type",
                    info="Select your camera or choose 'Custom' for manual input"
                )
                
                with gr.Row():
                    fx_custom = gr.Number(value=1500, label="Focal Length X (pixels)", visible=False)
                    fy_custom = gr.Number(value=1500, label="Focal Length Y (pixels)", visible=False)
                
                process_btn = gr.Button("πŸ”„ Process Image", variant="primary", size="lg")
                
                # Reference object selection
                gr.Markdown("### πŸ“ Reference Object")
                with gr.Row():
                    mask_number = gr.Number(value=1, label="Object Number", precision=0, minimum=1)
                    reference_size = gr.Number(value=10.0, label="Known Size (cm)", minimum=0.1)
                
                dimension_choice = gr.Radio(
                    choices=["width", "height"], 
                    value="width", 
                    label="Which dimension is the known size?"
                )
                
                select_ref_btn = gr.Button("πŸ“Œ Set as Reference", variant="secondary")
                calculate_btn = gr.Button("πŸ“Š Calculate All Sizes", variant="primary", size="lg")
            
            with gr.Column(scale=2):
                # Output section
                gr.Markdown("### πŸ–ΌοΈ Results")
                image_output = gr.Image(label="Detected Objects")
                status_output = gr.Textbox(label="Status", lines=4, max_lines=10)
                results_output = gr.Textbox(label="Size Measurements", lines=15, max_lines=25)
        
        # Event handlers
        def toggle_custom_focal(preset):
            if preset == "Custom":
                return gr.update(visible=True), gr.update(visible=True)
            else:
                return gr.update(visible=False), gr.update(visible=False)
        
        camera_preset.change(
            toggle_custom_focal,
            inputs=[camera_preset],
            outputs=[fx_custom, fy_custom]
        )
        
        # Load models on startup
        demo.load(app.load_models, outputs=[status_output])
        
        process_btn.click(
            app.process_image,
            inputs=[image_input, camera_preset, fx_custom, fy_custom],
            outputs=[image_output, status_output]
        )
        
        select_ref_btn.click(
            app.select_reference_object,
            inputs=[mask_number, reference_size, dimension_choice],
            outputs=[status_output]
        )
        
        calculate_btn.click(
            app.calculate_all_sizes,
            inputs=[camera_preset, fx_custom, fy_custom],
            outputs=[results_output]
        )
        
        # Additional controls and info
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🎯 Quick Actions")
                clear_btn = gr.Button("πŸ—‘οΈ Clear All", variant="secondary")
                
            with gr.Column():
                gr.Markdown("### πŸ“Š Session Info")
                session_info = gr.Textbox(label="Current Session", value="No image processed", interactive=False)
        
        # Event handlers for additional features
        def clear_session():
            app.current_image = None
            app.current_depth = None
            app.current_masks = None
            app.reference_object = None
            return (
                None,  # image_output
                "πŸ—‘οΈ Session cleared. Upload a new image to start.",  # status_output
                "",  # results_output
                "No image processed"  # session_info
            )
        
        def update_session_info(camera_preset, fx_custom, fy_custom):
            if app.current_masks is None:
                return "No image processed"
            
            if camera_preset == "Custom":
                cam_info = f"Custom (fx={fx_custom:.0f}, fy={fy_custom:.0f})"
            else:
                cam_info = camera_preset
            
            ref_info = "None selected"
            if app.reference_object:
                ref_info = f"Object with {app.reference_object['known_size_cm']} cm {app.reference_object['dimension']}"
            
            return f"πŸ“· Camera: {cam_info}\nπŸ“ Reference: {ref_info}\n🎯 Objects: {len(app.current_masks)}"
        
        clear_btn.click(
            clear_session,
            outputs=[image_output, status_output, results_output, session_info]
        )
        
        # Update session info when things change
        for component in [camera_preset, fx_custom, fy_custom]:
            component.change(
                update_session_info,
                inputs=[camera_preset, fx_custom, fy_custom],
                outputs=[session_info]
            )
        
        gr.Markdown("""
        ### πŸ’‘ Tips for best results:
        - Use good lighting and avoid shadows
        - Ensure objects are clearly visible and separated
        - Choose a reference object you know the exact size of
        - For phones, try the camera-specific presets first
        - Custom focal lengths can be calibrated using camera calibration tools
        """)
    
    return demo

if __name__ == "__main__":
    demo = create_interface()
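    # share=True requests a temporary public *.gradio.live link; drop it
    # (and debug=True) for purely local use.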
    demo.launch(share=True, debug=True)