zino36 commited on
Commit
6a56c0e
·
verified ·
1 Parent(s): aebd06c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -19
app.py CHANGED
@@ -62,51 +62,46 @@ Please refer to our [paper](https://arxiv.org/abs/2406.09414), [project page](ht
62
 
63
  @spaces.GPU
64
  def predict_depth(image):
65
- #return model.infer_image(image)
66
  return pipe(Image.fromarray(image))["depth"]
67
 
68
- with gr.Blocks(css=css) as demo:
69
  gr.Markdown(title)
70
  gr.Markdown(description)
71
  gr.Markdown("### Depth Prediction demo")
72
 
73
  with gr.Row():
74
  input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
75
- depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
76
  submit = gr.Button(value="Compute Depth")
77
- gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
78
- raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download",)
79
 
80
  cmap = matplotlib.colormaps.get_cmap('Spectral_r')
81
 
82
  def on_submit(image):
83
- original_image = image.copy()
84
 
85
- h, w = image.shape[:2]
86
-
87
- depth = predict_depth(image[:, :, ::-1])
88
-
89
- raw_depth = Image.fromarray(depth)
90
  tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
91
  raw_depth.save(tmp_raw_depth.name)
92
 
93
- depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
94
- depth = depth.astype(np.uint8)
95
- colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
 
96
 
97
- gray_depth = Image.fromarray(depth)
98
  tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
99
  gray_depth.save(tmp_gray_depth.name)
100
 
101
- return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
102
 
103
  submit.click(on_submit, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file])
104
 
105
- example_files = os.listdir('assets/examples')
106
- example_files.sort()
107
  example_files = [os.path.join('assets/examples', filename) for filename in example_files]
108
  examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
109
 
110
-
111
  if __name__ == '__main__':
112
  demo.queue().launch(share=True)
 
62
 
63
  @spaces.GPU
64
  def predict_depth(image):
 
65
  return pipe(Image.fromarray(image))["depth"]
66
 
67
+ with gr.Blocks(css="") as demo:
68
  gr.Markdown(title)
69
  gr.Markdown(description)
70
  gr.Markdown("### Depth Prediction demo")
71
 
72
  with gr.Row():
73
  input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
74
+ depth_image_slider = gr.Image(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
75
  submit = gr.Button(value="Compute Depth")
76
+ gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
77
+ raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
78
 
79
  cmap = matplotlib.colormaps.get_cmap('Spectral_r')
80
 
81
  def on_submit(image):
82
+ depth = predict_depth(image)
83
 
84
+ # Convert depth to images and save
85
+ raw_depth = Image.fromarray(depth.astype('uint16'))
 
 
 
86
  tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
87
  raw_depth.save(tmp_raw_depth.name)
88
 
89
+ # Normalize depth for color mapping
90
+ normalized_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
91
+ normalized_depth = normalized_depth.astype(np.uint8)
92
+ colored_depth = (cmap(normalized_depth)[:, :, :3] * 255).astype(np.uint8)
93
 
94
+ gray_depth = Image.fromarray(normalized_depth)
95
  tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
96
  gray_depth.save(tmp_gray_depth.name)
97
 
98
+ return [(Image.fromarray(colored_depth), tmp_gray_depth.name, tmp_raw_depth.name)]
99
 
100
  submit.click(on_submit, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file])
101
 
102
+ example_files = sorted(os.listdir('assets/examples'))
 
103
  example_files = [os.path.join('assets/examples', filename) for filename in example_files]
104
  examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
105
 
 
106
  if __name__ == '__main__':
107
  demo.queue().launch(share=True)