# Captured page header (not part of the program):
#   Spaces: Runtime error
import asyncio
import json
import time
import warnings

import cv2
import numpy as np
import torch
import websockets
from transformers import DPTForDepthEstimation, DPTImageProcessor
# Silence the "already rescaled images" warning (presumably raised by the
# image processor when given 0-255 frames); the message text is matched exactly.
warnings.filterwarnings("ignore", message="It looks like you are trying to rescale already rescaled images.")

# Checkpoint shared by the depth model and its image processor.
MODEL_ID = "Intel/dpt-swinv2-tiny-256"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): weights are loaded in float16 on every device; half-precision
# support for CPU kernels varies across torch releases — confirm when no CUDA.
model = DPTForDepthEstimation.from_pretrained(MODEL_ID, torch_dtype=torch.float16).to(device)
processor = DPTImageProcessor.from_pretrained(MODEL_ID)

# Default capture device (index 0). If it cannot be opened, cap.read() only
# ever returns False, so say so instead of failing silently.
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    warnings.warn(
        "Could not open video capture device 0; no frames will be processed.",
        RuntimeWarning,
    )
def resize_image(image, target_size=(256, 256)):
    """Return *image* rescaled by OpenCV to ``target_size``.

    ``target_size`` is handed to ``cv2.resize`` unchanged as its ``dsize``
    argument, which OpenCV interprets as (width, height).
    """
    resized = cv2.resize(image, target_size)
    return resized
def manual_normalize(depth_map):
    """Min-max scale *depth_map* onto 0-255 and return it as ``uint8``.

    The input is converted to float64 before any arithmetic, so integer
    inputs cannot overflow when the range (max - min) is computed. An empty
    or constant input has no range to stretch and yields an all-zero
    ``uint8`` array of the same shape. Scaled values are truncated toward
    zero by the final ``astype``, not rounded.
    """
    values = np.asarray(depth_map, dtype=np.float64)
    if values.size == 0:
        # np.min/np.max raise on zero-size arrays.
        return np.zeros(values.shape, dtype=np.uint8)
    min_val = values.min()
    max_val = values.max()
    if min_val == max_val:
        return np.zeros(values.shape, dtype=np.uint8)
    normalized = (values - min_val) / (max_val - min_val)
    return (normalized * 255).astype(np.uint8)
# Only one captured frame out of every `frame_skip` is sent through the model.
frame_skip = 4

# 256-entry INFERNO lookup table produced by OpenCV from the gray levels 0..255.
# NOTE(review): color_map is not referenced anywhere else in this file — confirm
# it is still needed.
color_map = cv2.applyColorMap(
    np.arange(256, dtype=np.uint8),
    cv2.COLORMAP_INFERNO,
)

# Open WebSocket connections: handler() adds/removes them, broadcast() sends to them.
connected = set()
async def broadcast(message):
    """Send *message* to every connected client.

    Iterates over a snapshot of ``connected``. The set is mutated here, when
    closed connections are dropped, and by handler() while this coroutine is
    suspended in ``send``. Mutating a set while iterating it raises
    RuntimeError. ``discard`` tolerates a connection that was already removed
    elsewhere.
    """
    for websocket in list(connected):
        try:
            await websocket.send(message)
        except websockets.exceptions.ConnectionClosed:
            connected.discard(websocket)
async def handler(websocket, path=None):
    """Register *websocket* for broadcasts until the client disconnects.

    ``path`` defaults to None: newer websockets releases invoke the handler
    with the connection only, older ones also pass the request path. It is
    not used here.
    """
    connected.add(websocket)
    try:
        await websocket.wait_closed()
    finally:
        # discard(), not remove(): broadcast() may already have dropped this
        # connection after a ConnectionClosed.
        connected.discard(websocket)
def _depth_to_uint8(depth_map):
    """Min-max scale a float depth map to ``uint8`` (0-255).

    An empty map is replaced by a 256x256 zero image; a constant map (no
    range to stretch) becomes all zeros of the same shape.
    """
    if depth_map.size == 0:
        return np.zeros((256, 256), dtype=np.uint8)
    if depth_map.min() == depth_map.max():
        return np.zeros_like(depth_map, dtype=np.uint8)
    return cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)


def _estimate_depth(rgb_frame):
    """Run the depth model on one RGB frame and return an 8-bit depth map.

    Blocking: preprocessing, the forward pass and the device-to-host copy all
    happen synchronously, so process_frames() calls this from a worker thread.
    """
    inputs = processor(images=resize_image(rgb_frame), return_tensors="pt")
    # Move to the model's device and cast floating tensors to the model's own
    # dtype, so inputs and weights always agree (float16 or float32).
    inputs = {
        key: (value.to(device=device, dtype=model.dtype)
              if value.is_floating_point() else value.to(device))
        for key, value in inputs.items()
    }
    with torch.no_grad():
        predicted_depth = model(**inputs).predicted_depth
    depth_map = predicted_depth.squeeze().float().cpu().numpy()
    depth_map = np.nan_to_num(depth_map, nan=0.0, posinf=0.0, neginf=0.0)
    return _depth_to_uint8(depth_map)


def _encode_frame(rgb_frame):
    """Build the JSON message for one frame: its depth map plus the RGB pixels."""
    depth_map = _estimate_depth(rgb_frame)
    return json.dumps({
        'depthMap': depth_map.tolist(),
        'rgbFrame': rgb_frame.tolist(),
    })


async def process_frames():
    """Capture camera frames, estimate depth, and broadcast the results.

    Every ``frame_skip``-th frame is processed, and only while at least one
    client is connected. Inference and JSON encoding run in the default
    executor so the event loop stays free to accept and serve WebSocket
    clients. Returns when the camera stops delivering frames; the capture
    device is released on every exit path.
    """
    loop = asyncio.get_running_loop()
    # NOTE(review): half precision is not implemented for every CPU kernel in
    # every torch release; upcast the weights once when running without CUDA.
    if device.type == "cpu" and model.dtype == torch.float16:
        model.float()
    frame_count = 0
    try:
        while True:
            # When no client is connected nothing below awaits, so yield
            # explicitly; otherwise the server could never accept a client.
            await asyncio.sleep(0)
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            if frame_count % frame_skip != 0 or not connected:
                continue
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            message = await loop.run_in_executor(None, _encode_frame, rgb_frame)
            await broadcast(message)
    finally:
        # No HighGUI window is created here, so there is nothing for
        # cv2.waitKey / cv2.destroyAllWindows to act on (and both can raise on
        # headless OpenCV builds); only the capture device needs releasing.
        cap.release()
async def main(host="localhost", port=8765):
    """Serve WebSocket clients on ``host:port`` while streaming depth frames.

    The server is an async context manager, so it is closed as soon as
    process_frames() returns or raises. The program then exits instead of
    waiting forever on a server that will never broadcast again.
    """
    async with websockets.serve(handler, host, port):
        await process_frames()


if __name__ == "__main__":
    asyncio.run(main())