File size: 4,390 Bytes
8166792 0dd36da 8166792 661e202 8166792 2609a96 8adc978 2609a96 8adc978 661e202 8166792 661e202 8166792 8adc978 0dd36da 2609a96 0dd36da 2609a96 0dd36da 2609a96 0dd36da 2609a96 8166792 0dd36da 2609a96 8166792 2609a96 8166792 8adc978 8166792 661e202 8166792 2609a96 8166792 2609a96 8adc978 2609a96 8adc978 864e7db 2609a96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import cv2
import torch
import numpy as np
import time
from midas.model_loader import default_models, load_model
import os
import urllib.request
import spaces
# Download URL of each supported model file, keyed by model type.
# Every file is published as "<release tag>/<model type>.pt" under the same
# GitHub releases root, so the table is generated from (model type, tag) pairs.
_MIDAS_RELEASE_ROOT = "https://github.com/isl-org/MiDaS/releases/download"
MODEL_FILE_URL = {
    model_name: f"{_MIDAS_RELEASE_ROOT}/{release_tag}/{model_name}.pt"
    for model_name, release_tag in (
        ("midas_v21_small_256", "v2_1"),
        ("dpt_hybrid_384", "v3"),
        ("dpt_large_384", "v3"),
        ("dpt_swin2_large_384", "v3_1"),
        ("dpt_beit_large_512", "v3_1"),
    )
}
class MonocularDepthEstimator:
    """Single-image depth estimator backed by a lazily loaded MiDaS model.

    Construction only records the configuration and makes sure the weights
    file is present on disk (downloading it when missing).  The network
    itself is loaded on first use from the ``@spaces.GPU``-decorated methods,
    so ``__init__`` never touches CUDA.
    """

    def __init__(self,
                 model_type="midas_v21_small_256",
                 model_weights_path="models/",
                 optimize=False,
                 side_by_side=False,
                 height=None,
                 square=False,
                 grayscale=False):
        """Store the configuration and fetch the weights file if it is missing.

        Args:
            model_type: Key into ``MODEL_FILE_URL``; also the stem of the
                ``.pt`` weights file name.
            model_weights_path: Directory holding the weights files.  It is
                created when it does not exist.
            optimize: Forwarded to ``load_model``; when true, ``predict``
                feeds the model channels-last, half-precision inputs.
            side_by_side: Stored on the instance; not used by this class.
            height: Forwarded to ``load_model``.
            square: Forwarded to ``load_model``.
            grayscale: Stored on the instance; not used by this class.

        Raises:
            KeyError: If the weights file is missing and ``model_type`` has
                no entry in ``MODEL_FILE_URL``.
        """
        # Don't initialize any CUDA/GPU stuff here
        self.model_type = model_type
        self.model_weights_path = model_weights_path
        self.is_optimize = optimize
        self.is_square = square
        self.is_grayscale = grayscale
        self.height = height
        self.side_by_side = side_by_side

        # Filled in by load_model_if_needed() on first use.
        self.model = None
        self.transform = None
        self.net_w = None
        self.net_h = None

        print("Initializing parameters...")
        weights_file = self._weights_file()
        if not os.path.exists(weights_file):
            print("Model file not found. Downloading...")
            download_url = MODEL_FILE_URL[model_type]

            # urlretrieve does not create missing directories.
            weights_dir = os.path.dirname(weights_file)
            if weights_dir:
                os.makedirs(weights_dir, exist_ok=True)

            # Download next to the target and rename only on success, so an
            # interrupted download never leaves a truncated file that the
            # os.path.exists() check above would later accept as valid.
            partial_file = weights_file + ".part"
            try:
                urllib.request.urlretrieve(download_url, partial_file)
                os.replace(partial_file, weights_file)
            finally:
                # After a successful rename the partial file is gone and this
                # is a no-op; after a failure it removes the leftovers.
                if os.path.exists(partial_file):
                    os.remove(partial_file)
            print("Model file downloaded successfully.")

    def _weights_file(self):
        """Return the path of ``<model_type>.pt`` inside the weights directory."""
        # os.path.join tolerates a directory given without a trailing slash,
        # which plain string concatenation does not.
        return os.path.join(self.model_weights_path, self.model_type + ".pt")

    @spaces.GPU
    def load_model_if_needed(self):
        """Load the model and its input transform on first call; no-op afterwards."""
        if self.model is None:
            print("Loading MiDaS model...")
            # NOTE(review): pass a torch.device rather than the bare string.
            # The MiDaS loader appears to gate its half-precision conversion
            # on `device == torch.device("cuda")`, which is False for a plain
            # str, while predict() below always feeds half inputs when
            # optimize is set -- confirm against midas/model_loader.py.
            self.model, self.transform, self.net_w, self.net_h = load_model(
                torch.device("cuda"),
                self._weights_file(),
                self.model_type,
                self.is_optimize,
                self.height,
                self.is_square
            )
            print("Model loaded successfully")

    @spaces.GPU
    def predict(self, image, target_size):
        """Run the network on one preprocessed image and resize the result.

        Args:
            image: NumPy array as produced by ``self.transform``.
            target_size: ``(width, height)`` of the desired output; it is
                reversed to ``(height, width)`` for the interpolation call.

        Returns:
            The prediction as a NumPy array, bicubically resized to
            ``target_size`` and squeezed of singleton dimensions.
        """
        self.load_model_if_needed()

        # Add the batch dimension after moving the data to the GPU.
        batch = torch.from_numpy(image).to('cuda').unsqueeze(0)
        if self.is_optimize:
            batch = batch.to(memory_format=torch.channels_last)
            batch = batch.half()

        with torch.no_grad():
            raw_prediction = self.model.forward(batch)
            # interpolate() needs a channel axis, hence unsqueeze(1).
            resized = torch.nn.functional.interpolate(
                raw_prediction.unsqueeze(1),
                size=target_size[::-1],
                mode="bicubic",
                align_corners=False,
            )
            prediction = resized.squeeze().cpu().numpy()
        return prediction

    def process_prediction(self, depth_map):
        """Normalise a raw depth map and render it with the INFERNO colormap.

        Args:
            depth_map: 2-D NumPy array of raw predictions.

        Returns:
            ``(normalized_depth, depth_colormap)``: the depth map min-max
            scaled to ``[0, 1]``, and the 3-channel colormap image as
            returned by ``cv2.applyColorMap``, also scaled to ``[0, 1]``.
            A depth map with no value range yields all zeros instead of NaNs.
        """
        depth_min = depth_map.min()
        depth_max = depth_map.max()
        depth_range = depth_max - depth_min

        if depth_range > np.finfo("float").eps:
            normalized_depth = 255 * (depth_map - depth_min) / depth_range
        else:
            # Constant (or non-finite) input: dividing by the range would
            # give NaN/inf and an undefined uint8 cast below.
            normalized_depth = np.zeros(depth_map.shape, dtype=depth_map.dtype)

        grayscale_depthmap = np.repeat(np.expand_dims(normalized_depth, 2), 3, axis=2)
        depth_colormap = cv2.applyColorMap(np.uint8(grayscale_depthmap), cv2.COLORMAP_INFERNO)
        return normalized_depth/255, depth_colormap/255

    @spaces.GPU
    def make_prediction(self, image):
        """Estimate depth for one image.

        Args:
            image: 3-D NumPy array ``(height, width, channels)``.  The
                channel axis is reversed before inference (presumably
                BGR -> RGB for OpenCV frames -- verify against callers).

        Returns:
            ``(depthmap, depth_colormap)`` as returned by
            ``process_prediction``, at the input's height and width.

        Raises:
            ValueError: If ``image`` is ``None`` or not a 3-D array.
            Exception: Any inference error is logged with its traceback and
                re-raised unchanged.
        """
        if image is None or getattr(image, "ndim", None) != 3:
            raise ValueError(
                "make_prediction expects a 3-D image array (height, width, channels)"
            )

        try:
            print("Starting depth estimation...")
            # np.flip returns a view; the division below allocates a fresh
            # array, so the caller's image is never modified.
            original_image_rgb = np.flip(image, 2)

            self.load_model_if_needed()
            image_transformed = self.transform({"image": original_image_rgb/255})["image"]

            # shape[1::-1] is (width, height); predict() reverses it again.
            pred = self.predict(image_transformed, target_size=original_image_rgb.shape[1::-1])
            depthmap, depth_colormap = self.process_prediction(pred)
            print("Depth estimation complete")
            return depthmap, depth_colormap
        except Exception as e:
            print(f"Error in make_prediction: {str(e)}")
            import traceback
            print(traceback.format_exc())
            raise
if __name__ == "__main__":
    # Smoke test: run the estimator over every frame of a sample video.
    # MonocularDepthEstimator defines no `run` method, so the frame loop lives
    # here and goes through the public make_prediction() entry point.
    video_path = "assets/videos/testvideo2.mp4"
    depth_estimator = MonocularDepthEstimator(model_type="dpt_hybrid_384")

    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise SystemExit(f"Error opening video file: {video_path}")

    frame_count = 0
    try:
        while True:
            has_frame, frame = capture.read()
            if not has_frame:
                break
            started = time.time()
            depthmap, _ = depth_estimator.make_prediction(frame)
            elapsed = time.time() - started
            frame_count += 1
            print(f"Frame {frame_count}: depth map {depthmap.shape} in {elapsed:.3f}s")
    finally:
        # Always release the capture handle, even if inference fails.
        capture.release()
    print(f"Processed {frame_count} frame(s) from {video_path}")
|