Spaces:
Runtime error
Runtime error
File size: 1,743 Bytes
b847bc7 7500064 b847bc7 581a214 b847bc7 581a214 b847bc7 b10122d b847bc7 f48d218 b847bc7 b10122d 581a214 b847bc7 581a214 b847bc7 581a214 b847bc7 581a214 b847bc7 581a214 b847bc7 581a214 b847bc7 581a214 b847bc7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import streamlit as st
import torch
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
@st.cache_resource
def load_models():
    """Construct the SAM predictor and the MiDaS depth network.

    Both networks are placed on CUDA when ``torch.cuda.is_available()``
    reports a GPU, otherwise on the CPU. The function is wrapped in
    ``st.cache_resource``.

    Returns:
        A 3-tuple ``(predictor, midas, transform)``: the ``SamPredictor``
        wrapping the ViT-B SAM model, the MiDaS ``DPT_Large`` model switched
        to eval mode, and the ``dpt_transform`` preprocessing callable taken
        from the MiDaS hub ``transforms`` entry point.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"

    # Segment Anything (ViT-B backbone), weights read from a local checkpoint.
    build_sam = sam_model_registry["vit_b"]
    sam_net = build_sam(checkpoint="sam_vit_b_01ec64.pth").to(target)
    sam_predictor = SamPredictor(sam_net)

    # MiDaS depth model and its matching preprocessing, both via torch.hub.
    hub_repo = "intel-isl/MiDaS"
    depth_net = torch.hub.load(hub_repo, "DPT_Large").to(target)
    depth_net.eval()
    dpt_transform = torch.hub.load(hub_repo, "transforms").dpt_transform

    return sam_predictor, depth_net, dpt_transform
# --- Streamlit page: upload an image, show it, and display its MiDaS depth ---
predictor, midas_model, midas_transform = load_models()

st.title("SAM + MiDaS Depth App")
uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    # getvalue() returns the full payload regardless of the stream position,
    # unlike read(), which yields nothing once the buffer has been consumed.
    raw_bytes = uploaded_file.getvalue()
    # imdecode raises on an empty buffer and returns None for bytes that are
    # not a decodable image, so guard both cases before colour conversion.
    image = (
        cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)
        if raw_bytes
        else None
    )
    if image is None:
        st.error("Could not decode the uploaded file as an image.")
        st.stop()

    # OpenCV decodes to BGR; Streamlit and MiDaS are given RGB here.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    st.image(image_rgb, caption="Original Image", use_column_width=True)

    # Ask for click input
    # TODO: st.image only renders the picture and does not report click
    # coordinates, so point-prompted SAM segmentation is not wired up yet.
    st.write("Click a point for segmentation")
    st.image(image_rgb, use_column_width=True)

    # For now, run depth estimation directly.
    # Run on whatever device the cached model actually lives on.
    device = next(midas_model.parameters()).device
    # Per the MiDaS hub usage, dpt_transform already returns a batched
    # (1, 3, H, W) tensor, so no extra unsqueeze is applied before the model.
    input_batch = midas_transform(image_rgb).to(device)

    with torch.no_grad():
        prediction = midas_model(input_batch)
        # Resize the prediction back to the uploaded image's height/width.
        resized = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image_rgb.shape[:2],
            mode="bicubic",
            align_corners=False,
        )
    # Explicit indexing instead of a blanket squeeze() so a 1-pixel-wide or
    # 1-pixel-tall image still yields a 2-D array.
    depth = resized[0, 0].cpu().numpy()

    # The raw prediction is not limited to [0, 1]; rescale it so the float
    # image is displayed with its full contrast instead of being clipped.
    depth_min = float(depth.min())
    depth_span = float(depth.max()) - depth_min
    if depth_span > 0:
        depth_vis = (depth - depth_min) / depth_span
    else:
        depth_vis = np.zeros_like(depth)

    st.image(depth_vis, caption="Estimated Depth", use_column_width=True, clamp=True)
|