ghost233lism committed
Commit 343f44f · verified · 1 Parent(s): b1afa51

Upload app.py

Files changed (1)
  1. app.py +411 -503

app.py CHANGED
@@ -1,504 +1,412 @@
- import gradio as gr
- import os
- import cv2
- import numpy as np
- import torch
- import torch.nn.functional as F
- from PIL import Image
- import tempfile
- import io
-
- from depth_anything.dpt import DepthAnything_AC
-
-
- def normalize_depth(disparity_tensor):
-     """Standard normalization method to convert disparity to depth"""
-     eps = 1e-6
-     disparity_min = disparity_tensor.min()
-     disparity_max = disparity_tensor.max()
-     normalized_disparity = (disparity_tensor - disparity_min) / (disparity_max - disparity_min + eps)
-     return normalized_disparity
-
-
- def is_video_file(filepath):
-     """Check if the given file is a video file based on its extension"""
-     if filepath is None:
-         return False
-     video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v']
-     _, ext = os.path.splitext(filepath.lower())
-     return ext in video_extensions
-
-
- def load_model(model_path='checkpoints/depth_anything_AC_vits.pth', encoder='vits'):
-     """Load trained depth estimation model"""
-     model_configs = {
-         'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024], 'version': 'v2'},
-         'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768], 'version': 'v2'},
-         'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384], 'version': 'v2'}
-     }
-
-     model = DepthAnything_AC(model_configs[encoder])
-
-     if os.path.exists(model_path):
-         checkpoint = torch.load(model_path, map_location='cpu')
-         model.load_state_dict(checkpoint, strict=False)
-     else:
-         print(f"Warning: Model file {model_path} not found")
-
-     model.eval()
-     if torch.cuda.is_available():
-         model.cuda()
-
-     return model
-
-
- def preprocess_image(image, target_size=518):
-     """Preprocess input image (supports both PIL Image and numpy array)"""
-     if isinstance(image, str):
-         raw_image = cv2.imread(image)
-         if raw_image is None:
-             raise ValueError(f"Cannot read image: {image}")
-         image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
-     elif isinstance(image, Image.Image):
-         image = np.array(image)
-         image = image.astype(np.float32) / 255.0
-     elif isinstance(image, np.ndarray):
-         if image.dtype == np.uint8:
-             image = image.astype(np.float32) / 255.0
-     else:
-         raise ValueError(f"Unsupported image type: {type(image)}")
-
-     if len(image.shape) == 3 and image.shape[2] == 3:
-         pass
-     elif len(image.shape) == 3 and image.shape[2] == 4:
-         image = image[:, :, :3]
-
-     h, w = image.shape[:2]
-     scale = target_size / min(h, w)
-     new_h, new_w = int(h * scale), int(w * scale)
-
-     new_h = ((new_h + 13) // 14) * 14
-     new_w = ((new_w + 13) // 14) * 14
-     image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
-
-     mean = np.array([0.485, 0.456, 0.406])
-     std = np.array([0.229, 0.224, 0.225])
-     image = (image - mean) / std
-
-     image = torch.from_numpy(image.transpose(2, 0, 1)).float()
-     image = image.unsqueeze(0)
-
-     return image, (h, w)
-
-
- def postprocess_depth(depth_tensor, original_size):
-     """Post-process depth map"""
-     if depth_tensor.dim() == 3:
-         depth_tensor = depth_tensor.unsqueeze(1)
-     elif depth_tensor.dim() == 2:
-         depth_tensor = depth_tensor.unsqueeze(0).unsqueeze(1)
-
-     h, w = original_size
-     depth = F.interpolate(depth_tensor, size=(h, w), mode='bilinear', align_corners=True)
-     depth = depth.squeeze().cpu().numpy()
-
-     return depth
-
-
- def create_colored_depth_map(depth, colormap='spectral'):
-     """Create colored depth map"""
-     if colormap == 'inferno':
-         depth_colored = cv2.applyColorMap((depth * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
-         depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
-     elif colormap == 'spectral':
-         from matplotlib import cm
-         spectral_cmap = cm.get_cmap('Spectral_r')
-         depth_colored = (spectral_cmap(depth) * 255).astype(np.uint8)
-         depth_colored = depth_colored[:, :, :3]
-     else:
-         depth_colored = (depth * 255).astype(np.uint8)
-         depth_colored = np.stack([depth_colored] * 3, axis=2)
-
-     return depth_colored
-
-
- def process_video(video_path, colormap_choice, progress=gr.Progress()):
-     """Process video file for depth estimation"""
-     try:
-         print(f"Processing video: {video_path}")
-
-         cap = cv2.VideoCapture(video_path)
-         if not cap.isOpened():
-             raise ValueError(f"Cannot open video file: {video_path}")
-
-         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-         input_fps = cap.get(cv2.CAP_PROP_FPS)
-         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-
-         print(f"Video properties: {total_frames} frames, {input_fps} FPS, {width}x{height}")
-
-         temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
-         output_path = temp_output.name
-         temp_output.close()
-
-         fourcc = cv2.VideoWriter.fourcc(*'mp4v')
-         out = cv2.VideoWriter(output_path, fourcc, input_fps, (width, height))
-
-         if not out.isOpened():
-             cap.release()
-             raise ValueError("Cannot create output video file")
-
-         frame_count = 0
-
-         try:
-             while True:
-                 ret, frame = cap.read()
-                 if not ret:
-                     break
-
-                 frame_count += 1
-
-                 frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-
-                 try:
-                     image_tensor, original_size = preprocess_image(frame_rgb)
-
-                     if torch.cuda.is_available():
-                         image_tensor = image_tensor.cuda()
-
-                     with torch.no_grad():
-                         prediction = model(image_tensor)
-                         disparity_tensor = prediction['out']
-                         depth_tensor = normalize_depth(disparity_tensor)
-
-                     depth = postprocess_depth(depth_tensor, original_size)
-
-                     if depth is None:
-                         if depth_tensor.dim() == 1:
-                             h, w = original_size
-                             expected_size = h * w
-                             if depth_tensor.shape[0] == expected_size:
-                                 depth_tensor = depth_tensor.view(1, 1, h, w)
-                             else:
-                                 import math
-                                 side_length = int(math.sqrt(depth_tensor.shape[0]))
-                                 if side_length * side_length == depth_tensor.shape[0]:
-                                     depth_tensor = depth_tensor.view(1, 1, side_length, side_length)
-                         depth = postprocess_depth(depth_tensor, original_size)
-
-                     if depth is None:
-                         print(f"Warning: Frame {frame_count} processing failed, using black frame")
-                         depth_frame = np.zeros((height, width, 3), dtype=np.uint8)
-                     else:
-                         if colormap_choice.lower() == 'inferno':
-                             depth_frame = cv2.applyColorMap((depth * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
-                         elif colormap_choice.lower() == 'spectral':
-                             from matplotlib import cm
-                             spectral_cmap = cm.get_cmap('Spectral_r')
-                             depth_frame = (spectral_cmap(depth) * 255).astype(np.uint8)
-                             depth_frame = depth_frame[:, :, :3]
-                             depth_frame = cv2.cvtColor(depth_frame, cv2.COLOR_RGB2BGR)
-                         else:
-                             depth_frame = (depth * 255).astype(np.uint8)
-                             depth_frame = cv2.cvtColor(depth_frame, cv2.COLOR_GRAY2BGR)
-
-                     out.write(depth_frame)
-
-                 except Exception as e:
-                     print(f"Error processing frame {frame_count}: {str(e)}")
-                     black_frame = np.zeros((height, width, 3), dtype=np.uint8)
-                     out.write(black_frame)
-
-                 progress((frame_count / total_frames), f"Processing progress: {frame_count}/{total_frames} frames")
-
-         except Exception as e:
-             print(f"Unexpected error during video processing: {str(e)}")
-         finally:
-             cap.release()
-             out.release()
-
-         print(f"Video processing completed! Output saved to: {output_path}")
-         return output_path
-
-     except Exception as e:
-         print(f"Video processing failed: {str(e)}")
-         return None
-
-
- print("Loading model...")
- model = load_model()
- print("Model loaded successfully!")
-
-
- def predict_depth(input_file, colormap_choice):
-     """Main depth prediction function for both images and videos"""
-     try:
-         if input_file is None:
-             return None, gr.update(visible=False)
-
-         if is_video_file(input_file):
-             output_path = process_video(input_file, colormap_choice)
-             if output_path:
-                 return output_path, gr.update(visible=True, value=output_path)
-             else:
-                 return None, gr.update(visible=False)
-         else:
-             if isinstance(input_file, str):
-                 input_image = Image.open(input_file)
-             else:
-                 input_image = input_file
-
-             image_tensor, original_size = preprocess_image(input_image)
-
-             if torch.cuda.is_available():
-                 image_tensor = image_tensor.cuda()
-
-             with torch.no_grad():
-                 prediction = model(image_tensor)
-                 disparity_tensor = prediction['out']
-                 depth_tensor = normalize_depth(disparity_tensor)
-
-             depth = postprocess_depth(depth_tensor, original_size)
-             depth_colored = create_colored_depth_map(depth, colormap_choice.lower())
-
-             result = Image.fromarray(depth_colored)
-
-             temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
-             result.save(temp_file.name)
-
-             return result, gr.update(visible=True, value=temp_file.name)
-
-     except Exception as e:
-         print(f"Error during inference: {str(e)}")
-         return None, gr.update(visible=False)
-
-
- def capture_and_predict(camera_image, colormap_choice):
-     """Capture image from camera and predict depth"""
-     return predict_depth(camera_image, colormap_choice)
-
-
- with gr.Blocks(title="Depth Anything AC - Depth Estimation Demo", theme=gr.themes.Soft(), css="""
- .image-container {
-     display: flex !important;
-     align-items: flex-start !important;
-     justify-content: center !important;
- }
- .gradio-image {
-     vertical-align: top !important;
- }
- """) as demo:
-     gr.Markdown("""
-     # 🌊 Depth Anything AC - Depth Estimation Demo
-
-     Upload an image or use your camera to generate corresponding depth maps! Different colors in the depth map represent different distances, allowing you to see the three-dimensional structure of the image.
-
-     ## How to Use
-     1. **Upload Mode**: Click the upload area to select an image or video file
-     2. **Camera Mode**: Use your camera to capture a live image
-     3. Choose your preferred colormap style
-     4. Click the "Generate Depth Map" button
-     5. View the results and download
-     """)
-
-     with gr.Row():
-         input_source = gr.Radio(
-             choices=["Upload Image", "Upload Video", "Use Camera"],
-             value="Upload Image",
-             label="Input Source"
-         )
-         colormap_choice = gr.Dropdown(
-             choices=["Spectral", "Inferno", "Gray"],
-             value="Spectral",
-             label="Colormap Style"
-         )
-         submit_btn = gr.Button(
-             "🎯 Generate Depth Map",
-             variant="primary",
-             size="lg"
-         )
-
-     with gr.Row():
-         gr.HTML("<h3 style='text-align: center; margin: 10px;'>📷 Input Image</h3>")
-         gr.HTML("<h3 style='text-align: center; margin: 10px;'>🌊 Depth Map Result</h3>")
-
-     with gr.Row(equal_height=True):
-         with gr.Column(scale=1):
-             # Image input component for preview and examples
-             upload_image = gr.Image(
-                 type="pil",
-                 height=450,
-                 visible=True,
-                 show_label=False,
-                 container=False,
-                 label="Upload Image"
-             )
-
-             # File component for video uploads
-             upload_file = gr.File(
-                 file_types=["video"],
-                 height=200,
-                 visible=False,
-                 show_label=False,
-                 container=False,
-                 label="Upload Video"
-             )
-
-             # Camera component
-             camera_image = gr.Image(
-                 type="pil",
-                 sources=["webcam"],
-                 height=450,
-                 visible=False,
-                 show_label=False,
-                 container=False
-             )
-
-         with gr.Column(scale=1):
-             output_file = gr.File(
-                 height=450,
-                 show_label=False,
-                 container=False,
-                 visible=False
-             )
-
-             output_image = gr.Image(
-                 type="pil",
-                 height=450,
-                 show_label=False,
-                 container=False,
-                 visible=True
-             )
-
-             download_btn = gr.DownloadButton(
-                 label="📥 Download Result",
-                 variant="secondary",
-                 size="sm",
-                 visible=False
-             )
-
-     def switch_input_source(source):
-         if source == "Upload Image":
-             return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
-         elif source == "Upload Video":
-             return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
-         else:  # Use Camera
-             return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
-
-     input_source.change(
-         fn=switch_input_source,
-         inputs=[input_source],
-         outputs=[upload_image, upload_file, camera_image]
-     )
-
-     def handle_prediction(input_source, upload_img, upload_file_path, camera_img, colormap):
-         if input_source == "Upload Image":
-             if upload_img is None:
-                 return None, None, gr.update(visible=False), gr.update(visible=False)
-
-             result, download_update = predict_depth(upload_img, colormap)
-             return result, None, gr.update(visible=True), download_update
-
-         elif input_source == "Upload Video":
-             if upload_file_path is None:
-                 return None, None, gr.update(visible=False), gr.update(visible=False)
-
-             result, download_update = predict_depth(upload_file_path, colormap)
-
-             if isinstance(result, str) and is_video_file(result):
-                 return None, result, gr.update(visible=False), download_update
-             else:
-                 return result, None, gr.update(visible=True), download_update
-         else:  # Use Camera
-             result, download_update = predict_depth(camera_img, colormap)
-             return result, None, gr.update(visible=True), download_update
-
-     # Separate image and video examples
-     image_examples = []
-     video_examples = []
-     if os.path.exists("toyset"):
-         for img_file in ["1.png", "2.png", "good.png"]:
-             if os.path.exists(f"toyset/{img_file}"):
-                 image_examples.append([f"toyset/{img_file}", "Spectral"])
-
-         for vid_file in ["fog_2_processed_1s-6s_1.0x.mp4", "snow_processed_1s-6s_1.0x.mp4"]:
-             if os.path.exists(f"toyset/{vid_file}"):
-                 video_examples.append([f"toyset/{vid_file}", "Spectral"])
-
-     # Function to handle video example selection and auto-switch mode
-     def handle_video_example(video_path, colormap):
-         # Auto-switch to video mode and return the necessary updates
-         return (
-             "Upload Video",  # input_source
-             gr.update(visible=False),  # upload_image
-             gr.update(visible=True, value=video_path),  # upload_file
-             gr.update(visible=False)  # camera_image
-         )
-
-     # Function to handle image example selection and auto-switch mode
-     def handle_image_example(image, colormap):
-         # Auto-switch to image mode and process the image
-         result = predict_depth(image, colormap)
-         output_image = result[0] if result[0] is not None else None
-         return (
-             "Upload Image",  # input_source
-             gr.update(visible=True, value=image),  # upload_image
-             gr.update(visible=False),  # upload_file
-             gr.update(visible=False),  # camera_image
-             output_image  # output_image
-         )
-
-     if image_examples:
-         gr.Examples(
-             examples=image_examples,
-             inputs=[upload_image, colormap_choice],
-             outputs=[input_source, upload_image, upload_file, camera_image, output_image],
-             fn=handle_image_example,
-             cache_examples=False,
-             label="Try these example images"
-         )
-
-     if video_examples:
-         gr.Examples(
-             examples=video_examples,
-             inputs=[upload_file, colormap_choice],
-             outputs=[input_source, upload_image, upload_file, camera_image],
-             fn=handle_video_example,
-             cache_examples=False,
-             label="Try these example videos"
-         )
-
-     submit_btn.click(
-         fn=handle_prediction,
-         inputs=[input_source, upload_image, upload_file, camera_image, colormap_choice],
-         outputs=[output_image, output_file, output_image, download_btn],
-         show_progress=True
-     )
-
-     gr.Markdown("""
-     ## 📝 Colormap Description
-     - **Spectral**: Rainbow spectrum, with clear contrast between near and far
-     - **Inferno**: Fire spectrum, warm tones
-     - **Gray**: Classic grayscale depth representation
-
-     ## 📷 Camera Usage Tips
-     - Ensure camera access is allowed when prompted
-     - Click the camera button to capture the current frame
-     - The captured image will be used as input for depth estimation
-
-     ## 🎬 Video Processing Tips
-     - Supports multiple video formats (MP4, AVI, MOV, etc.)
-     - Video processing may take some time, please be patient
-     - Processing progress will be displayed in real-time
-     - The output video will maintain the same frame rate as the input
-     """)
-
-
- if __name__ == "__main__":
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=False,
-         show_error=True
+ import gradio as gr
+ import os
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ import tempfile
+ import io
+ from tqdm import tqdm
+
+ from depth_anything.dpt import DepthAnything_AC
+
+
+ def normalize_depth(disparity_tensor):
+     """Standard normalization method to convert disparity to depth"""
+     eps = 1e-6
+     disparity_min = disparity_tensor.min()
+     disparity_max = disparity_tensor.max()
+     normalized_disparity = (disparity_tensor - disparity_min) / (disparity_max - disparity_min + eps)
+     return normalized_disparity
+
+
+ def load_model(model_path='checkpoints/depth_anything_AC_vits.pth', encoder='vits'):
+     """Load trained depth estimation model"""
+     model_configs = {
+         'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024], 'version': 'v2'},
+         'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768], 'version': 'v2'},
+         'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384], 'version': 'v2'}
+     }
+
+     model = DepthAnything_AC(model_configs[encoder])
+
+     if os.path.exists(model_path):
+         checkpoint = torch.load(model_path, map_location='cpu')
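+         # strict=False loads whatever keys match and skips the rest instead of raising on mismatch.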
+         model.load_state_dict(checkpoint, strict=False)
+     else:
+         print(f"Warning: Model file {model_path} not found")
+
+     model.eval()
+     if torch.cuda.is_available():
+         model.cuda()
+
+     return model
+
+
+ def preprocess_image(image, target_size=518):
+     """Preprocess input image"""
+     if isinstance(image, Image.Image):
+         image = np.array(image)
+
+     if len(image.shape) == 3 and image.shape[2] == 3:
+         pass
+     elif len(image.shape) == 3 and image.shape[2] == 4:
+         image = image[:, :, :3]
+
+     image = image.astype(np.float32) / 255.0
+     h, w = image.shape[:2]
+     scale = target_size / min(h, w)
+     new_h, new_w = int(h * scale), int(w * scale)
+
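+     # Round spatial dims up to multiples of 14 (the ViT patch size) so patch embedding divides evenly.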
+     new_h = ((new_h + 13) // 14) * 14
+     new_w = ((new_w + 13) // 14) * 14
+     image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
+
+     mean = np.array([0.485, 0.456, 0.406])
+     std = np.array([0.229, 0.224, 0.225])
+     image = (image - mean) / std
+
+     image = torch.from_numpy(image.transpose(2, 0, 1)).float()
+     image = image.unsqueeze(0)
+
+     return image, (h, w)
+
+
+ def preprocess_image_from_array(image_array, target_size=518):
+     """Preprocess input image from numpy array (for video frames)"""
+     if len(image_array.shape) == 3 and image_array.shape[2] == 3:
+         # Convert BGR to RGB if needed
+         image = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
+     else:
+         image = image_array.astype(np.float32) / 255.0
+
+     h, w = image.shape[:2]
+     scale = target_size / min(h, w)
+     new_h, new_w = int(h * scale), int(w * scale)
+
+     new_h = ((new_h + 13) // 14) * 14
+     new_w = ((new_w + 13) // 14) * 14
+     image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
+
+     mean = np.array([0.485, 0.456, 0.406])
+     std = np.array([0.229, 0.224, 0.225])
+     image = (image - mean) / std
+
+     image = torch.from_numpy(image.transpose(2, 0, 1)).float()
+     image = image.unsqueeze(0)
+
+     return image, (h, w)
+
+
+ def postprocess_depth(depth_tensor, original_size):
+     """Post-process depth map"""
+     if depth_tensor.dim() == 3:
+         depth_tensor = depth_tensor.unsqueeze(1)
+     elif depth_tensor.dim() == 2:
+         depth_tensor = depth_tensor.unsqueeze(0).unsqueeze(1)
+
+     h, w = original_size
+     depth = F.interpolate(depth_tensor, size=(h, w), mode='bilinear', align_corners=True)
+     depth = depth.squeeze().cpu().numpy()
+
+     return depth
+
+
+ def create_colored_depth_map(depth, colormap='spectral'):
+     """Create colored depth map"""
+     if colormap == 'inferno':
+         depth_colored = cv2.applyColorMap((depth * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
+         depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
+     elif colormap == 'spectral':
+         from matplotlib import cm
+         spectral_cmap = cm.get_cmap('Spectral_r')
+         depth_colored = (spectral_cmap(depth) * 255).astype(np.uint8)
+         depth_colored = depth_colored[:, :, :3]
+     else:
+         depth_colored = (depth * 255).astype(np.uint8)
+         depth_colored = np.stack([depth_colored] * 3, axis=2)
+
+     return depth_colored
+
+
+ def is_video_file(filepath):
+     """Check if the given file is a video file based on its extension"""
+     video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v']
+     _, ext = os.path.splitext(filepath.lower())
+     return ext in video_extensions
+
+
+ print("Loading model...")
+ model = load_model()
+ print("Model loaded successfully!")
+
+
+ def predict_depth(input_image, colormap_choice):
+     """Main depth prediction function for images"""
+     try:
+         image_tensor, original_size = preprocess_image(input_image)
+
+         if torch.cuda.is_available():
+             image_tensor = image_tensor.cuda()
+
+         with torch.no_grad():
+             prediction = model(image_tensor)
+             disparity_tensor = prediction['out']
+             depth_tensor = normalize_depth(disparity_tensor)
+
+         depth = postprocess_depth(depth_tensor, original_size)
+
+         depth_colored = create_colored_depth_map(depth, colormap_choice.lower())
+
+         return Image.fromarray(depth_colored)
+
+     except Exception as e:
+         print(f"Error during image inference: {str(e)}")
+         return None
+
+
+ def predict_video_depth(input_video, colormap_choice, progress=gr.Progress()):
+     """Main depth prediction function for videos"""
+     if input_video is None:
+         return None
+
+     try:
+         print(f"Starting video processing: {input_video}")
+
+         # Open video file
+         cap = cv2.VideoCapture(input_video)
+         if not cap.isOpened():
+             print(f"Error: Cannot open video file: {input_video}")
+             return None
+
+         # Get video properties
+         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+         input_fps = cap.get(cv2.CAP_PROP_FPS)
+         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+         print(f"Video properties: {total_frames} frames, {input_fps} FPS, {width}x{height}")
+
+         # Create temporary output video file
+         with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
+             output_path = tmp_file.name
+
+         # Set video encoder
+         fourcc = cv2.VideoWriter.fourcc(*'mp4v')
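+         # 'mp4v' works with most OpenCV builds; browser-friendly H.264 ('avc1') may need extra codecs.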
+         out = cv2.VideoWriter(output_path, fourcc, input_fps, (width, height))
+
+         if not out.isOpened():
+             print(f"Error: Cannot create output video: {output_path}")
+             cap.release()
+             return None
+
+         frame_count = 0
+
+         # Process each frame
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+
+             frame_count += 1
+             progress_percent = frame_count / total_frames
+             progress(progress_percent, desc=f"Processing frame {frame_count}/{total_frames}")
+
+             try:
+                 # Preprocess current frame
+                 image_tensor, original_size = preprocess_image_from_array(frame)
+                 if torch.cuda.is_available():
+                     image_tensor = image_tensor.cuda()
+
+                 # Perform depth estimation
+                 with torch.no_grad():
+                     prediction = model(image_tensor)
+                     disparity_tensor = prediction['out']
+                     depth_tensor = normalize_depth(disparity_tensor)
+
+                 # Postprocess depth map
+                 depth = postprocess_depth(depth_tensor, original_size)
+
+                 # Handle failed processing
+                 if depth is None:
+                     if depth_tensor.dim() == 1:
+                         h, w = original_size
+                         expected_size = h * w
+                         if depth_tensor.shape[0] == expected_size:
+                             depth_tensor = depth_tensor.view(1, 1, h, w)
+                         else:
+                             import math
+                             side_length = int(math.sqrt(depth_tensor.shape[0]))
+                             if side_length * side_length == depth_tensor.shape[0]:
+                                 depth_tensor = depth_tensor.view(1, 1, side_length, side_length)
+                     depth = postprocess_depth(depth_tensor, original_size)
+
+                 # Generate colored depth map
+                 if depth is None:
+                     print(f"Warning: Failed to process frame {frame_count}, using black frame")
+                     depth_frame = np.zeros((height, width, 3), dtype=np.uint8)
+                 else:
+                     if colormap_choice.lower() == 'inferno':
+                         depth_frame = cv2.applyColorMap((depth * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
+                     elif colormap_choice.lower() == 'spectral':
+                         from matplotlib import cm
+                         spectral_cmap = cm.get_cmap('Spectral_r')
+                         depth_frame = (spectral_cmap(depth) * 255).astype(np.uint8)
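+                         # Matplotlib colormaps return RGBA, hence the RGBA2BGR conversion below.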
+                         depth_frame = cv2.cvtColor(depth_frame, cv2.COLOR_RGBA2BGR)
+                     else:  # gray
+                         depth_frame = (depth * 255).astype(np.uint8)
+                         depth_frame = cv2.cvtColor(depth_frame, cv2.COLOR_GRAY2BGR)
+
+                 # Write to output video
+                 out.write(depth_frame)
+
+             except Exception as e:
+                 print(f"Error processing frame {frame_count}: {str(e)}")
+                 # Write black frame
+                 black_frame = np.zeros((height, width, 3), dtype=np.uint8)
+                 out.write(black_frame)
+
+         # Release resources
+         cap.release()
+         out.release()
+
+         print(f"Video processing completed! Output saved to: {output_path}")
+         return output_path
+
+     except Exception as e:
+         print(f"Error during video inference: {str(e)}")
+         return None
+
+
+ with gr.Blocks(title="Depth Anything AC - Depth Estimation Demo", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # 🌊 Depth Anything AC - Depth Estimation Demo
+
+     Upload an image or video and the model will generate the corresponding depth map! Different colors in the depth map represent different distances, allowing you to see the three-dimensional structure of the scene.
+
+     ## How to Use
+     1. Choose the image or video tab
+     2. Upload your file
+     3. Select your preferred colormap style
+     4. Click the "Generate Depth Map" button
+     5. View the results and download
+     """)
+
+     with gr.Tabs():
+         # Image processing tab
+         with gr.TabItem("📷 Image Depth Estimation"):
+             with gr.Row():
+                 with gr.Column():
+                     input_image = gr.Image(
+                         label="Upload Image",
+                         type="pil",
+                         height=400
+                     )
+
+                     image_colormap_choice = gr.Dropdown(
+                         choices=["Spectral", "Inferno", "Gray"],
+                         value="Spectral",
+                         label="Colormap"
+                     )
+
+                     image_submit_btn = gr.Button(
+                         "🎯 Generate Image Depth Map",
+                         variant="primary",
+                         size="lg"
+                     )
+
+                 with gr.Column():
+                     output_image = gr.Image(
+                         label="Depth Map Result",
+                         type="pil",
+                         height=400
+                     )
+
+             gr.Examples(
+                 examples=[
+                     ["toyset/1.png", "Spectral"],
+                     ["toyset/2.png", "Spectral"],
+                     ["toyset/good.png", "Spectral"],
+                 ] if os.path.exists("toyset") else [],
+                 inputs=[input_image, image_colormap_choice],
+                 outputs=output_image,
+                 fn=predict_depth,
+                 cache_examples=False,
+                 label="Try these example images"
+             )
+
+         # Video processing tab
+         with gr.TabItem("🎬 Video Depth Estimation"):
+             with gr.Row():
+                 with gr.Column():
+                     input_video = gr.Video(
+                         label="Upload Video",
+                         height=400
+                     )
+
+                     video_colormap_choice = gr.Dropdown(
+                         choices=["Spectral", "Inferno", "Gray"],
+                         value="Spectral",
+                         label="Colormap"
+                     )
+
+                     video_submit_btn = gr.Button(
+                         "🎯 Generate Video Depth Map",
+                         variant="primary",
+                         size="lg"
+                     )
+
+                 with gr.Column():
+                     output_video = gr.Video(
+                         label="Depth Map Video Result",
+                         height=400
+                     )
+
+             gr.Examples(
+                 examples=[
+                     ["toyset/fog.mp4", "Spectral"],
+                     ["toyset/snow.mp4", "Spectral"],
+                 ] if os.path.exists("toyset/fog.mp4") and os.path.exists("toyset/snow.mp4") else [],
+                 inputs=[input_video, video_colormap_choice],
+                 outputs=output_video,
+                 fn=predict_video_depth,
+                 cache_examples=False,
+                 label="Try these example videos"
+             )
+
+     # Event bindings
+     image_submit_btn.click(
+         fn=predict_depth,
+         inputs=[input_image, image_colormap_choice],
+         outputs=output_image,
+         show_progress=True
+     )
+
+     video_submit_btn.click(
+         fn=predict_video_depth,
+         inputs=[input_video, video_colormap_choice],
+         outputs=output_video,
+         show_progress=True
+     )
+
+     gr.Markdown("""
+     ## 📝 Notes
+     - **Spectral**: Rainbow spectrum with distinct near-far contrast
+     - **Inferno**: Flame spectrum with warm tones
+     - **Gray**: Classic grayscale rendering
+
+     ## 💡 Tips
+     - Image processing is fast and suits quick previews of single images
+     - Video processing can take a while, please be patient
+     - A GPU is recommended for faster processing
+     """)
+
+
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         show_error=True
  )