lakshmi082024 committed
Commit b847bc7 · verified · 1 Parent(s): b10122d

Update app.py

Files changed (1)
  1. app.py +34 -86
app.py CHANGED
@@ -1,103 +1,51 @@
-import gradio as gr
+import streamlit as st
 import torch
-import numpy as np
 import cv2
-from PIL import Image
-import pandas as pd
-from torchvision.transforms import Compose, Resize, ToTensor, Normalize
-from segment_anything import SamPredictor, sam_model_registry
-import os

-# Load SAM and MiDaS models
+@st.cache_resource
 def load_models():
+    device = "cuda" if torch.cuda.is_available() else "cpu"

-
+    # Load SAM (vit_b)
     sam_checkpoint = "sam_vit_b_01ec64.pth"
-    model_type = "vit_b"  # <-- Must match checkpoint
-
+    model_type = "vit_b"
     sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
-
-    if not os.path.exists(sam_checkpoint):
-        raise FileNotFoundError("Please upload the SAM checkpoint file to the working directory.")
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint).to(device)
     predictor = SamPredictor(sam)

-    midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
-    midas.eval().to(device)
-    midas_transform = Compose([
-        Resize(384),
-        ToTensor(),
-        Normalize(mean=[0.5]*3, std=[0.5]*3)
-    ])
-    return predictor, midas, midas_transform
+    # Load MiDaS
+    midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large").to(device)
+    midas.eval()
+    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
+    transform = midas_transforms.dpt_transform
+
+    return predictor, midas, transform

 predictor, midas_model, midas_transform = load_models()

-# Processing function
-def process_image(image_pil):
-    image_np = np.array(image_pil)
-    img_h, img_w = image_np.shape[:2]
+st.title("SAM + MiDaS Depth App")

-    # Real-world reference dimensions (adjust as needed)
-    real_image_width_cm = 100
-    real_image_height_cm = 75
-    assumed_max_depth_cm = 100
+uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
+if uploaded_file:
+    image = cv2.imdecode(np.frombuffer(uploaded_file.read(), np.uint8), 1)
+    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

-    pixel_to_cm_x = real_image_width_cm / img_w
-    pixel_to_cm_y = real_image_height_cm / img_h
+    st.image(image_rgb, caption="Original Image", use_column_width=True)

-    # SAM segmentation
-    predictor.set_image(image_np)
-    masks, _, _ = predictor.predict(multimask_output=False)
+    # Ask for click input
+    st.write("Click a point for segmentation")
+    coords = st.image(image_rgb, use_column_width=True)

-    # MiDaS depth estimation
-    input_tensor = midas_transform(image_pil).unsqueeze(0).to(next(midas_model.parameters()).device)
+    # For now, run depth estimation directly
+    input_tensor = midas_transform(image_rgb).to("cuda" if torch.cuda.is_available() else "cpu")
     with torch.no_grad():
-        depth_prediction = midas_model(input_tensor).squeeze().cpu().numpy()
-    depth_resized = cv2.resize(depth_prediction, (img_w, img_h))
-
-    # Object volume computation
-    volume_data = []
-    for i, mask in enumerate(masks):
-        x, y, w, h = cv2.boundingRect(mask.astype(np.uint8))
-        width_px = w
-        height_px = h
-        width_cm = width_px * pixel_to_cm_x
-        height_cm = height_px * pixel_to_cm_y
-
-        depth_masked = depth_resized[mask > 0.5]
-        if depth_masked.size == 0:
-            continue
-
-        normalized_depth = (depth_masked - np.min(depth_resized)) / (np.max(depth_resized) - np.min(depth_resized) + 1e-6)
-        depth_cm = np.mean(normalized_depth) * assumed_max_depth_cm
-        volume_cm3 = round(depth_cm * width_cm * height_cm, 2)
-
-        volume_data.append([
-            f"Object #{i+1}",
-            round(depth_cm, 2),
-            round(width_cm, 2),
-            round(height_cm, 2),
-            volume_cm3
-        ])
-
-    if not volume_data:
-        return image_pil, "No objects segmented."
-
-    df = pd.DataFrame(volume_data, columns=["Object", "Length (Depth) cm", "Breadth (Width) cm", "Height cm", "Volume cm³"])
-    return image_pil, df
-
-# Gradio Interface
-with gr.Blocks() as demo:
-    gr.Markdown("# 📦 Volume Estimation using SAM + MiDaS")
-    with gr.Row():
-        image_input = gr.Image(type="pil", label="Upload Image")
-        run_btn = gr.Button("Estimate Volume")
-    with gr.Row():
-        output_image = gr.Image(label="Original Image")
-        volume_table = gr.Dataframe(headers=["Object", "Length (Depth) cm", "Breadth (Width) cm", "Height cm", "Volume cm³"])
-    run_btn.click(fn=process_image, inputs=image_input, outputs=[output_image, volume_table])
-
-demo.launch()
+        depth = midas_model(input_tensor.unsqueeze(0))
+        depth = torch.nn.functional.interpolate(
+            depth.unsqueeze(1),
+            size=image_rgb.shape[:2],
+            mode="bicubic",
+            align_corners=False,
+        ).squeeze().cpu().numpy()
+
+    st.image(depth, caption="Estimated Depth", use_column_width=True, clamp=True)
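
A note on the new depth-estimation block: the dpt_transform loaded from the intel-isl/MiDaS hub already returns a batched tensor of shape (1, 3, H, W), so the extra unsqueeze(0) in the added code would hand the model a 5-D input. Below is a minimal sketch of the upstream MiDaS usage pattern, assuming the same DPT_Large weights; image_rgb is an HxWx3 RGB array and the path "example.jpg" is only a placeholder, not something from this repo.

import cv2
import torch

# Load MiDaS and its matching transform from torch.hub (same calls as in the commit)
device = "cuda" if torch.cuda.is_available() else "cpu"
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large").to(device).eval()
transform = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform

# Placeholder input image (hypothetical path)
image_rgb = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)

# dpt_transform already adds the batch dimension: shape (1, 3, H', W')
input_batch = transform(image_rgb).to(device)

with torch.no_grad():
    prediction = midas(input_batch)            # (1, H', W') relative depth
    depth = torch.nn.functional.interpolate(
        prediction.unsqueeze(1),               # (1, 1, H', W')
        size=image_rgb.shape[:2],              # back to the input resolution
        mode="bicubic",
        align_corners=False,
    ).squeeze().cpu().numpy()                  # (H, W) depth map

The prediction is a relative, unitless depth map; that is why the old app scaled it by an assumed maximum depth to report centimetres.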
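
The rewrite also keeps the SamPredictor but does not yet use it: st.image does not return click coordinates, so the "Click a point for segmentation" step is a stub for now. The sketch below shows how a point prompt could drive SAM once a coordinate is available; the hard-coded (x, y) point and the image_rgb variable are assumptions standing in for whatever click-capture mechanism the app eventually adopts.

import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Same vit_b checkpoint as app.py
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)

predictor.set_image(image_rgb)  # image_rgb: HxWx3 RGB uint8 array

# Hypothetical click location (x, y); a real app would capture this from the UI
point_coords = np.array([[320, 240]])
point_labels = np.array([1])  # 1 = foreground point

masks, scores, _ = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=True,
)
best_mask = masks[np.argmax(scores)]  # boolean HxW mask for the clicked object

With multimask_output=True the predictor returns up to three candidate masks with quality scores; taking the highest-scoring one is a common default before any UI for choosing between them exists.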