ghost233lism committed on
Commit
b6ee1cf
·
verified ·
1 Parent(s): a63396a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +300 -60
app.py CHANGED
@@ -20,6 +20,15 @@ def normalize_depth(disparity_tensor):
20
  return normalized_disparity
21
 
22
 
 
 
 
 
 
 
 
 
 
23
  def load_model(model_path='checkpoints/depth_anything_AC_vits.pth', encoder='vits'):
24
  """Load trained depth estimation model"""
25
  model_configs = {
@@ -44,16 +53,26 @@ def load_model(model_path='checkpoints/depth_anything_AC_vits.pth', encoder='vit
44
 
45
 
46
  def preprocess_image(image, target_size=518):
47
- """Preprocess input image"""
48
- if isinstance(image, Image.Image):
 
 
 
 
 
49
  image = np.array(image)
 
 
 
 
 
 
50
 
51
  if len(image.shape) == 3 and image.shape[2] == 3:
52
  pass
53
  elif len(image.shape) == 3 and image.shape[2] == 4:
54
  image = image[:, :, :3]
55
 
56
- image = image.astype(np.float32) / 255.0
57
  h, w = image.shape[:2]
58
  scale = target_size / min(h, w)
59
  new_h, new_w = int(h * scale), int(w * scale)
@@ -103,100 +122,321 @@ def create_colored_depth_map(depth, colormap='spectral'):
103
  return depth_colored
104
 
105
 
106
- print("Loading model...")
107
- model = load_model()
108
- print("Model loaded successfully!")
109
-
110
-
111
- def predict_depth(input_image, colormap_choice):
112
- """Main depth prediction function"""
113
  try:
114
- image_tensor, original_size = preprocess_image(input_image)
 
 
 
 
 
 
 
 
 
115
 
116
- if torch.cuda.is_available():
117
- image_tensor = image_tensor.cuda()
118
 
119
- with torch.no_grad():
120
- prediction = model(image_tensor)
121
- disparity_tensor = prediction['out']
122
- depth_tensor = normalize_depth(disparity_tensor)
123
 
124
- depth = postprocess_depth(depth_tensor, original_size)
 
125
 
126
- depth_colored = create_colored_depth_map(depth, colormap_choice.lower())
 
 
127
 
128
- return Image.fromarray(depth_colored)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  except Exception as e:
131
  print(f"Error during inference: {str(e)}")
132
  return None
133
 
134
 
135
- with gr.Blocks(title="Depth Anything AC - Depth Estimation Demo", theme=gr.themes.Soft()) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  gr.Markdown("""
137
  # 🌊 Depth Anything AC - Depth Estimation Demo
138
 
139
- Upload an image and AI will generate the corresponding depth map! Different colors in the depth map represent different distances, allowing you to see the three-dimensional structure of the image.
140
 
141
  ## How to Use
142
- 1. Click the upload area to select an image
143
- 2. Choose your preferred colormap style
144
- 3. Click the "Generate Depth Map" button
145
- 4. View the results and download
 
146
  """)
147
 
148
  with gr.Row():
149
- with gr.Column():
150
- input_image = gr.Image(
151
- label="Upload Image",
152
- type="pil",
153
- height=400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  )
155
 
156
- colormap_choice = gr.Dropdown(
157
- choices=["Spectral", "Inferno", "Gray"],
158
- value="Spectral",
159
- label="Colormap"
 
 
 
 
160
  )
161
 
162
- submit_btn = gr.Button(
163
- "🎯 Generate Depth Map",
164
- variant="primary",
165
- size="lg"
 
 
166
  )
167
 
168
- with gr.Column():
169
  output_image = gr.Image(
170
- label="Depth Map Result",
171
  type="pil",
172
- height=400
 
 
 
 
 
 
 
 
 
 
173
  )
174
 
175
- gr.Examples(
176
- examples=[
177
- ["toyset/1.png", "Spectral"],
178
- ["toyset/2.png", "Spectral"],
179
- ["toyset/good.png", "Spectral"],
180
- ] if os.path.exists("toyset") else [],
181
- inputs=[input_image, colormap_choice],
182
- outputs=output_image,
183
- fn=predict_depth,
184
- cache_examples=False,
185
- label="Try these example images"
186
  )
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  submit_btn.click(
189
- fn=predict_depth,
190
- inputs=[input_image, colormap_choice],
191
- outputs=output_image,
192
  show_progress=True
193
  )
194
 
195
  gr.Markdown("""
196
- ## πŸ“ Notes
197
- - **Spectral**: Rainbow spectrum with distinct near-far contrast
198
- - **Inferno**: Flame spectrum with warm tones
199
- - **Gray**: Grayscale with classic effect
 
 
 
 
 
 
 
 
 
 
 
200
  """)
201
 
202
 
 
20
  return normalized_disparity
21
 
22
 
23
def is_video_file(filepath):
    """Report whether ``filepath`` names a video file, judged by extension only.

    Args:
        filepath: A path given as ``str`` or ``os.PathLike``. Any other value
            (``None``, an image object coming from the webcam component, a
            numpy array, ...) is reported as "not a video" instead of raising.

    Returns:
        bool: ``True`` when the lower-cased extension is one of the known
        video extensions, ``False`` otherwise. The file is never opened.
    """
    # predict_depth() forwards whatever the UI produced; webcam captures are
    # image objects rather than paths and previously crashed on ``.lower()``.
    if not isinstance(filepath, (str, os.PathLike)):
        return False

    video_extensions = ('.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v')
    _, ext = os.path.splitext(os.fspath(filepath).lower())
    return ext in video_extensions
30
+
31
+
32
  def load_model(model_path='checkpoints/depth_anything_AC_vits.pth', encoder='vits'):
33
  """Load trained depth estimation model"""
34
  model_configs = {
 
53
 
54
 
55
  def preprocess_image(image, target_size=518):
56
+ """Preprocess input image (supports both PIL Image and numpy array)"""
57
+ if isinstance(image, str):
58
+ raw_image = cv2.imread(image)
59
+ if raw_image is None:
60
+ raise ValueError(f"Cannot read image: {image}")
61
+ image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
62
+ elif isinstance(image, Image.Image):
63
  image = np.array(image)
64
+ image = image.astype(np.float32) / 255.0
65
+ elif isinstance(image, np.ndarray):
66
+ if image.dtype == np.uint8:
67
+ image = image.astype(np.float32) / 255.0
68
+ else:
69
+ raise ValueError(f"Unsupported image type: {type(image)}")
70
 
71
  if len(image.shape) == 3 and image.shape[2] == 3:
72
  pass
73
  elif len(image.shape) == 3 and image.shape[2] == 4:
74
  image = image[:, :, :3]
75
 
 
76
  h, w = image.shape[:2]
77
  scale = target_size / min(h, w)
78
  new_h, new_w = int(h * scale), int(w * scale)
 
122
  return depth_colored
123
 
124
 
125
def _load_spectral_cmap():
    """Return matplotlib's reversed Spectral colormap on old and new releases."""
    import matplotlib
    try:
        # Registry API; ``cm.get_cmap`` was removed in matplotlib 3.9.
        return matplotlib.colormaps['Spectral_r']
    except AttributeError:
        from matplotlib import cm
        return cm.get_cmap('Spectral_r')


def _depth_to_bgr_frame(depth, colormap_name, spectral_cmap=None):
    """Convert one depth map into a 3-channel uint8 BGR frame for the writer.

    ``depth`` is multiplied by 255 and cast to uint8, so it is presumably
    already normalised to [0, 1] by postprocess_depth -- TODO confirm.
    """
    if colormap_name == 'inferno':
        return cv2.applyColorMap((depth * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
    if colormap_name == 'spectral':
        rgba = (spectral_cmap(depth) * 255).astype(np.uint8)
        # Drop alpha, then swap to OpenCV's BGR channel order.
        return cv2.cvtColor(rgba[:, :, :3], cv2.COLOR_RGB2BGR)
    gray = (depth * 255).astype(np.uint8)
    return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)


def process_video(video_path, colormap_choice, progress=gr.Progress()):
    """Run depth estimation on every frame of a video and write a new video.

    Args:
        video_path: Path of the input video, opened with ``cv2.VideoCapture``.
        colormap_choice: Colormap name, case-insensitive. ``'inferno'`` and
            ``'spectral'`` are colorised; anything else is written as grayscale.
        progress: Gradio progress tracker, called once per processed frame.

    Returns:
        Path of a temporary ``.mp4`` file holding the depth video, or ``None``
        when the input cannot be opened or the output cannot be created.
        Frames that fail individually are written as black frames so the
        output keeps the input's length.
    """
    try:
        print(f"Processing video: {video_path}")

        colormap_name = colormap_choice.lower()
        # Resolve the colormap once instead of on every frame; matplotlib is
        # only needed for the spectral style.
        spectral_cmap = _load_spectral_cmap() if colormap_name == 'spectral' else None

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Cannot open video file: {video_path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        input_fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        print(f"Video properties: {total_frames} frames, {input_fps} FPS, {width}x{height}")

        # Some containers report 0 (or NaN) FPS; a writer cannot use that.
        if not input_fps > 0:
            input_fps = 30.0

        temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        output_path = temp_output.name
        temp_output.close()

        fourcc = cv2.VideoWriter.fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, input_fps, (width, height))

        if not out.isOpened():
            cap.release()
            out.release()
            # Do not leave an empty temp file behind when nothing was written.
            try:
                os.remove(output_path)
            except OSError:
                pass
            raise ValueError("Cannot create output video file")

        frame_count = 0

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1

                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                try:
                    image_tensor, original_size = preprocess_image(frame_rgb)

                    if torch.cuda.is_available():
                        image_tensor = image_tensor.cuda()

                    with torch.no_grad():
                        prediction = model(image_tensor)
                        disparity_tensor = prediction['out']
                        depth_tensor = normalize_depth(disparity_tensor)

                    depth = postprocess_depth(depth_tensor, original_size)

                    if depth is None:
                        # Fallback: a flattened prediction is reshaped back to
                        # an image (exact size first, then a perfect square)
                        # before post-processing is retried.
                        if depth_tensor.dim() == 1:
                            h, w = original_size
                            if depth_tensor.shape[0] == h * w:
                                depth_tensor = depth_tensor.view(1, 1, h, w)
                            else:
                                side_length = int(depth_tensor.shape[0] ** 0.5)
                                if side_length * side_length == depth_tensor.shape[0]:
                                    depth_tensor = depth_tensor.view(1, 1, side_length, side_length)
                        depth = postprocess_depth(depth_tensor, original_size)

                    if depth is None:
                        print(f"Warning: Frame {frame_count} processing failed, using black frame")
                        depth_frame = np.zeros((height, width, 3), dtype=np.uint8)
                    else:
                        depth_frame = _depth_to_bgr_frame(depth, colormap_name, spectral_cmap)
                        # The writer needs frames of exactly the size it was
                        # opened with; resize defensively on a mismatch.
                        if depth_frame.shape[:2] != (height, width):
                            depth_frame = cv2.resize(depth_frame, (width, height))

                    out.write(depth_frame)

                except Exception as e:
                    # Best effort: keep the output in step with the input.
                    print(f"Error processing frame {frame_count}: {str(e)}")
                    black_frame = np.zeros((height, width, 3), dtype=np.uint8)
                    out.write(black_frame)

                if total_frames > 0:
                    # The reported count can be too low; clamp to 1.0.
                    progress(
                        min(frame_count / total_frames, 1.0),
                        f"Processing progress: {frame_count}/{total_frames} frames",
                    )
                else:
                    # Unknown length: show the description without a fraction.
                    progress(None, f"Processing progress: {frame_count} frames")

        except Exception as e:
            print(f"Unexpected error during video processing: {str(e)}")
        finally:
            cap.release()
            out.release()

        print(f"Video processing completed! Output saved to: {output_path}")
        return output_path

    except Exception as e:
        print(f"Video processing failed: {str(e)}")
        return None
227
+
228
+
229
+ print("Loading model...")
230
+ model = load_model()
231
+ print("Model loaded successfully!")
232
+
233
+
234
def predict_depth(input_file, colormap_choice):
    """Predict depth for an uploaded file path or an in-memory image.

    Args:
        input_file: A file path (image or video), an already-loaded image
            such as a webcam capture, or ``None``.
        colormap_choice: Colormap name; lower-cased before being passed to
            ``create_colored_depth_map`` and forwarded as-is to
            ``process_video``.

    Returns:
        A 2-tuple ``(result, download_update)`` on every path:
        ``result`` is the output video path (str) for videos, a PIL image for
        images, or ``None`` when there is nothing to show; ``download_update``
        is a ``gr.update`` for the download control, hidden when there is no
        result.
    """
    try:
        if input_file is None:
            return None, gr.update(visible=False)

        # Only real paths can be videos; webcam captures are image objects
        # and must not be passed to the extension check.
        if isinstance(input_file, str) and is_video_file(input_file):
            output_path = process_video(input_file, colormap_choice)
            if output_path:
                return output_path, gr.update(visible=True, value=output_path)
            return None, gr.update(visible=False)

        if isinstance(input_file, str):
            input_image = Image.open(input_file)
        else:
            input_image = input_file

        # Grayscale, palette and alpha images become plain 3-channel RGB so
        # preprocessing always sees an HxWx3 array.
        if isinstance(input_image, Image.Image) and input_image.mode != "RGB":
            input_image = input_image.convert("RGB")

        image_tensor, original_size = preprocess_image(input_image)

        if torch.cuda.is_available():
            image_tensor = image_tensor.cuda()

        with torch.no_grad():
            prediction = model(image_tensor)
            disparity_tensor = prediction['out']
            depth_tensor = normalize_depth(disparity_tensor)

        depth = postprocess_depth(depth_tensor, original_size)
        depth_colored = create_colored_depth_map(depth, colormap_choice.lower())

        result = Image.fromarray(depth_colored)

        # Reserve a name, close the handle, then save: avoids leaking the
        # descriptor and writing to a file that is still open.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            download_path = temp_file.name
        result.save(download_path)

        return result, gr.update(visible=True, value=download_path)

    except Exception as e:
        print(f"Error during inference: {str(e)}")
        # Keep the 2-tuple shape so callers can always unpack the result.
        return None, gr.update(visible=False)
275
 
276
 
277
def capture_and_predict(camera_image, colormap_choice):
    """Predict depth for a frame captured from the camera.

    Thin forwarder: both arguments go to ``predict_depth`` unchanged and its
    return value is handed back as-is.
    """
    prediction = predict_depth(camera_image, colormap_choice)
    return prediction
280
+
281
+
282
# Gradio UI: input-source switch, colormap picker, upload/webcam inputs on the
# left and image/file outputs plus a download button on the right.
with gr.Blocks(title="Depth Anything AC - Depth Estimation Demo", theme=gr.themes.Soft(), css="""
    .image-container {
        display: flex !important;
        align-items: flex-start !important;
        justify-content: center !important;
    }
    .gradio-image {
        vertical-align: top !important;
    }
    """) as demo:
    gr.Markdown("""
    # 🌊 Depth Anything AC - Depth Estimation Demo

    Upload an image or use your camera to generate corresponding depth maps! Different colors in the depth map represent different distances, allowing you to see the three-dimensional structure of the image.

    ## How to Use
    1. **Upload Mode**: Click the upload area to select an image or video file
    2. **Camera Mode**: Use your camera to capture a live image
    3. Choose your preferred colormap style
    4. Click the "Generate Depth Map" button
    5. View the results and download
    """)

    with gr.Row():
        input_source = gr.Radio(
            choices=["Upload Image", "Use Camera"],
            value="Upload Image",
            label="Input Source"
        )
        colormap_choice = gr.Dropdown(
            choices=["Spectral", "Inferno", "Gray"],
            value="Spectral",
            label="Colormap Style"
        )
        submit_btn = gr.Button(
            "🎯 Generate Depth Map",
            variant="primary",
            size="lg"
        )

    with gr.Row():
        gr.HTML("<h3 style='text-align: center; margin: 10px;'>📷 Input Image</h3>")
        gr.HTML("<h3 style='text-align: center; margin: 10px;'>🌊 Depth Map Result</h3>")

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            upload_file = gr.File(
                file_types=["image", "video"],
                height=450,
                visible=True,
                show_label=False,
                container=False,
                label="Upload Image or Video"
            )

            # Camera component (hidden until "Use Camera" is selected)
            camera_image = gr.Image(
                type="pil",
                sources=["webcam"],
                height=450,
                visible=False,
                show_label=False,
                container=False
            )

        with gr.Column(scale=1):
            # Shown only for video results.
            output_file = gr.File(
                height=450,
                show_label=False,
                container=False,
                visible=False
            )

            # Shown for image results (and as the empty placeholder).
            output_image = gr.Image(
                type="pil",
                height=450,
                show_label=False,
                container=False,
                visible=True
            )

            download_btn = gr.DownloadButton(
                label="📥 Download Result",
                variant="secondary",
                size="sm",
                visible=False
            )

    def switch_input_source(source):
        """Show the upload widget or the webcam widget, never both."""
        use_upload = source == "Upload Image"
        return gr.update(visible=use_upload), gr.update(visible=not use_upload)

    input_source.change(
        fn=switch_input_source,
        inputs=[input_source],
        outputs=[upload_file, camera_image]
    )

    def handle_prediction(input_source, upload_file_path, camera_img, colormap):
        """Run prediction and route the result to the matching output widget.

        Returns three updates, in the order (output_image, output_file,
        download_btn). Each output component appears exactly once, and value
        and visibility travel together in a single update.
        """
        show_empty_image = gr.update(value=None, visible=True)
        hide_file = gr.update(value=None, visible=False)
        hide_download = gr.update(visible=False)

        chosen_input = upload_file_path if input_source == "Upload Image" else camera_img
        if chosen_input is None:
            return show_empty_image, hide_file, hide_download

        outcome = predict_depth(chosen_input, colormap)
        # Tolerate a bare None from predict_depth's error path.
        if not isinstance(outcome, tuple):
            return show_empty_image, hide_file, hide_download
        result, download_update = outcome

        if isinstance(result, str) and is_video_file(result):
            # Video: hide the image pane and reveal the file pane.
            return (
                gr.update(value=None, visible=False),
                gr.update(value=result, visible=True),
                download_update,
            )
        return gr.update(value=result, visible=True), hide_file, download_update

    def run_example(file_path, colormap):
        """Example entry point; examples always come from the upload widget."""
        return handle_prediction("Upload Image", file_path, None, colormap)

    example_files = []
    if os.path.exists("toyset"):
        for img_file in ["1.png", "2.png", "good.png"]:
            if os.path.exists(f"toyset/{img_file}"):
                example_files.append([f"toyset/{img_file}", "Spectral"])

        for vid_file in ["fog_2_processed_1s-6s_1.0x.mp4", "snow_processed_1s-6s_1.0x.mp4"]:
            if os.path.exists(f"toyset/{vid_file}"):
                example_files.append([f"toyset/{vid_file}", "Spectral"])

    if example_files:
        gr.Examples(
            examples=example_files,
            inputs=[upload_file, colormap_choice],
            outputs=[output_image, output_file, download_btn],
            fn=run_example,
            cache_examples=False,
            label="Try these example files"
        )

    submit_btn.click(
        fn=handle_prediction,
        inputs=[input_source, upload_file, camera_image, colormap_choice],
        outputs=[output_image, output_file, download_btn],
        show_progress=True
    )

    gr.Markdown("""
    ## 📝 Colormap Description
    - **Spectral**: Rainbow spectrum, with clear contrast between near and far
    - **Inferno**: Fire spectrum, warm tones
    - **Gray**: Classic grayscale depth representation

    ## 📷 Camera Usage Tips
    - Ensure camera access is allowed when prompted
    - Click the camera button to capture the current frame
    - The captured image will be used as input for depth estimation

    ## 🎬 Video Processing Tips
    - Supports multiple video formats (MP4, AVI, MOV, etc.)
    - Video processing may take some time, please be patient
    - Processing progress will be displayed in real-time
    - The output video will maintain the same frame rate as the input
    """)
441
 
442