File size: 1,743 Bytes
b847bc7
7500064
 
b847bc7
 
581a214
b847bc7
581a214
b847bc7
b10122d
b847bc7
f48d218
b847bc7
b10122d
581a214
 
b847bc7
 
 
 
 
 
 
581a214
 
 
b847bc7
581a214
b847bc7
 
 
 
581a214
b847bc7
581a214
b847bc7
 
 
581a214
b847bc7
 
581a214
b847bc7
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import streamlit as st
import torch
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

@st.cache_resource
def load_models():
    """Build the SAM predictor and the MiDaS depth network with its transform.

    Both networks are moved to CUDA when ``torch.cuda.is_available()`` is
    true, otherwise to the CPU. The function is wrapped in
    ``st.cache_resource``.

    Returns:
        A ``(predictor, midas, transform)`` tuple: a ``SamPredictor`` built
        around the ViT-B SAM model, the MiDaS ``DPT_Large`` model in eval
        mode, and the ``dpt_transform`` from the MiDaS hub transforms.
    """
    target = "cpu"
    if torch.cuda.is_available():
        target = "cuda"

    # Segment Anything (ViT-B); the checkpoint file name is a relative path.
    build_sam = sam_model_registry["vit_b"]
    sam_net = build_sam(checkpoint="sam_vit_b_01ec64.pth")
    sam_net = sam_net.to(target)
    sam_predictor = SamPredictor(sam_net)

    # MiDaS DPT_Large fetched through torch.hub, switched to inference mode.
    depth_net = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
    depth_net = depth_net.to(target)
    depth_net.eval()

    # The matching preprocessing pipeline lives in the same hub repository.
    dpt_preprocess = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform

    return sam_predictor, depth_net, dpt_preprocess

# Streamlit page: upload an image, show it, and display a MiDaS depth estimate.
# `predictor` (SAM) is loaded but not used yet; see the NOTE below.
predictor, midas_model, midas_transform = load_models()

st.title("SAM + MiDaS Depth App")

uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    # Decode the uploaded bytes as a 3-channel BGR image. cv2.imdecode returns
    # None (it does not raise) when the data is not a decodable image.
    image = cv2.imdecode(
        np.frombuffer(uploaded_file.read(), np.uint8), cv2.IMREAD_COLOR
    )
    if image is None:
        st.error("Could not decode the uploaded file as an image.")
    else:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        st.image(image_rgb, caption="Original Image", use_column_width=True)

        # NOTE: st.image only renders the picture; it does not report click
        # coordinates, so point-prompted segmentation is not wired up yet.
        st.write("Click a point for segmentation")
        st.image(image_rgb, use_column_width=True)

        # Run the input on whatever device the model's weights live on instead
        # of re-deriving it from torch.cuda.is_available().
        device = next(midas_model.parameters()).device
        input_batch = midas_transform(image_rgb).to(device)
        # The MiDaS hub transform already adds the batch dimension; only add
        # one if a transform variant returned an unbatched (C, H, W) tensor.
        if input_batch.dim() == 3:
            input_batch = input_batch.unsqueeze(0)

        with torch.no_grad():
            prediction = midas_model(input_batch)
            # Resize the prediction back to the original image resolution.
            resized = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=image_rgb.shape[:2],
                mode="bicubic",
                align_corners=False,
            )
        # Index rather than squeeze() so a 1-pixel-high/wide image keeps 2 dims.
        depth = resized[0, 0].cpu().numpy()

        # Raw model output is not in [0, 1]; min-max normalise it so st.image
        # shows a gradient instead of clamping everything to white.
        depth_min = float(depth.min())
        depth_span = float(depth.max()) - depth_min
        if depth_span > 0:
            depth_vis = (depth - depth_min) / depth_span
        else:
            depth_vis = np.zeros_like(depth)

        st.image(
            depth_vis, caption="Estimated Depth", use_column_width=True, clamp=True
        )