# sam_object/app.py
import streamlit as st
import torch
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
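# Assumed dependencies for this Space (not pinned in the source):
#   pip install streamlit torch timm opencv-python-headless numpy
#   pip install git+https://github.com/facebookresearch/segment-anything.git
# The SAM checkpoint sam_vit_b_01ec64.pth must be present at the repo root.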
@st.cache_resource
def load_models():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load SAM (vit_b) from a local checkpoint
    sam_checkpoint = "sam_vit_b_01ec64.pth"
    model_type = "vit_b"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
    predictor = SamPredictor(sam)
    # Load MiDaS (DPT_Large) and its matching input transform from torch.hub
    midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large").to(device)
    midas.eval()
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
    transform = midas_transforms.dpt_transform
    return predictor, midas, transform
predictor, midas_model, midas_transform = load_models()
st.title("SAM + MiDaS Depth App")
uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
if uploaded_file:
    # Decode the uploaded bytes into a BGR image, then convert to RGB for display
    image = cv2.imdecode(np.frombuffer(uploaded_file.read(), np.uint8), cv2.IMREAD_COLOR)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    st.image(image_rgb, caption="Original Image", use_column_width=True)
    # st.image cannot capture clicks, so take the point prompt as
    # explicit pixel coordinates instead.
    st.write("Enter a point (x, y) for segmentation")
    x = int(st.number_input("x", min_value=0, max_value=image_rgb.shape[1] - 1, value=image_rgb.shape[1] // 2))
    y = int(st.number_input("y", min_value=0, max_value=image_rgb.shape[0] - 1, value=image_rgb.shape[0] // 2))
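    # Sketch: the source loads SamPredictor but never invokes it; the
    # following applies the standard segment_anything API (set_image /
    # predict) to segment the object at the chosen point. The red overlay
    # and picking the top-scoring mask are illustrative choices, not
    # taken from the original app.
    predictor.set_image(image_rgb)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[x, y]]),
        point_labels=np.array([1]),  # 1 marks a foreground point
        multimask_output=True,
    )
    best_mask = masks[np.argmax(scores)]  # (H, W) boolean mask
    overlay = image_rgb.copy()
    overlay[best_mask] = (0.5 * overlay[best_mask] + 0.5 * np.array([255, 0, 0])).astype(np.uint8)
    st.image(overlay, caption="Segmented object (red overlay)", use_column_width=True)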
    # Run MiDaS depth estimation on the full image.
    # dpt_transform already returns a batched (1, 3, H, W) tensor,
    # so it can be passed to the model directly.
    input_tensor = midas_transform(image_rgb).to("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        depth = midas_model(input_tensor)  # (1, H', W')
        depth = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=image_rgb.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze().cpu().numpy()
    # MiDaS outputs unbounded relative inverse depth; normalize to [0, 1]
    # so st.image can render it.
    depth_vis = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    st.image(depth_vis, caption="Estimated Depth", use_column_width=True, clamp=True)
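    # Minimal sketch (an assumption about the app's intent, not in the
    # source) combining the two models: report the median MiDaS value
    # inside the SAM mask. MiDaS predicts relative inverse depth, so the
    # number is unitless.
    if best_mask.any():
        st.write(f"Median relative inverse depth inside mask: {np.median(depth[best_mask]):.3f}")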