Spaces:
Sleeping
Sleeping
import os
import shutil
import subprocess
import sys

import gradio as gr
# HTML shown under the demo title: project badges followed by a short usage hint.
desc = """
<p align="center">
<a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
</a>
<a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
<p align="justify">
Marigold is the new state-of-the-art depth estimator for images in the wild. Upload your image into the pane on the left side, or explore examples listed at the bottom.
</p>
"""
def download_code():
    """Clone the Marigold repository into ``./Marigold`` if it is not there yet.

    The clone is skipped when ``Marigold/run.py`` already exists, because
    ``git clone`` refuses to write into a non-empty directory and the script
    is the only file this app invokes from the checkout.

    Failures are reported on stderr instead of being raised so that the app
    can still start, matching the previous best-effort behaviour.
    """
    if os.path.isfile(os.path.join("Marigold", "run.py")):
        return

    try:
        # Argument list instead of a shell string: no shell is involved.
        result = subprocess.run(
            ["git", "clone", "https://github.com/prs-eth/Marigold.git"],
            check=False,
        )
    except OSError as err:
        # Raised e.g. when the git executable cannot be started.
        print(f"Could not run git clone: {err}", file=sys.stderr)
        return

    if result.returncode != 0:
        print(
            f"git clone exited with status {result.returncode}",
            file=sys.stderr,
        )
def find_first_png(directory):
    """Return the path of the first ``.png`` entry in *directory*, or ``None``.

    Entries are examined in sorted name order so the result does not depend
    on the arbitrary order of ``os.listdir``. The extension match is
    case-insensitive.

    Args:
        directory: Path of the directory to scan.

    Returns:
        ``os.path.join(directory, name)`` for the first matching entry, or
        ``None`` when there is no match or *directory* does not exist.
    """
    try:
        names = sorted(os.listdir(directory))
    except (FileNotFoundError, NotADirectoryError):
        # A missing output folder means "nothing produced"; let the caller
        # handle that through the None return value.
        return None

    for name in names:
        if name.lower().endswith(".png"):
            return os.path.join(directory, name)
    return None
def marigold_process(path_input, path_out_png=None, path_out_obj=None, path_out_2_png=None):
    """Run Marigold depth estimation on one image and return the result paths.

    The Gradio interface in this file has one image input plus two hidden
    file inputs, and its example rows carry precomputed results in those two
    slots. When both are supplied they are returned unchanged and no
    inference is run.

    Args:
        path_input: Path of the input image file.
        path_out_png: Optional precomputed colored depth image.
        path_out_obj: Optional precomputed second output.
        path_out_2_png: Optional third precomputed value, kept for backward
            compatibility with callers that pass all three.

    Returns:
        ``(colored_depth_path, bw_depth_path)``. If all three optional
        arguments are given, the three values are returned as before.

    Raises:
        RuntimeError: If ``run.py`` could not be started or did not produce
            the expected PNG files.
    """
    if path_out_png is not None and path_out_obj is not None:
        if path_out_2_png is not None:
            # Legacy three-value form, preserved for existing callers.
            return path_out_png, path_out_obj, path_out_2_png
        return path_out_png, path_out_obj

    path_input_dir = path_input + ".input"
    path_output_dir = path_input + ".output"
    os.makedirs(path_input_dir, exist_ok=True)
    os.makedirs(path_output_dir, exist_ok=True)
    shutil.copy(path_input, path_input_dir)

    # Argument list + cwd instead of a shell string, so quotes or shell
    # metacharacters in the paths cannot alter the command.
    command = [
        "python3",
        "run.py",
        "--input_rgb_dir",
        path_input_dir,
        "--output_dir",
        path_output_dir,
        "--n_infer",
        "10",
        "--denoise_steps",
        "10",
    ]
    try:
        result = subprocess.run(command, cwd="Marigold", check=False)
    except OSError as err:
        raise RuntimeError("Processing failed") from err
    finally:
        # The staged copy of the input is only needed while run.py executes.
        shutil.rmtree(path_input_dir, ignore_errors=True)

    path_out_colored = find_first_png(path_output_dir + "/depth_colored")
    path_out_bw = find_first_png(path_output_dir + "/depth_bw")
    # Explicit raise rather than assert: asserts are stripped under -O.
    if path_out_colored is None or path_out_bw is None:
        raise RuntimeError(
            f"Processing failed (run.py exit status {result.returncode})"
        )

    return path_out_colored, path_out_bw
# Gradio UI definition: one visible image input, two hidden file inputs that
# the example rows fill in, and two image outputs.
iface = gr.Interface(
    title="Marigold Depth Estimation",
    description=desc,
    thumbnail="marigold_logo_square.jpg",
    fn=marigold_process,
    inputs=[
        gr.Image(label="Input Image", type="filepath"),
        # Hidden slots; only the example rows provide values for them.
        gr.File(label="Predicted depth (red-near, blue-far)", visible=False),
        gr.File(label="Predicted depth (16-bit PNG)", visible=False),
    ],
    outputs=[
        gr.Image(label="Predicted depth (red-near, blue-far)", type="pil"),
        gr.Image(
            label="Predicted depth (16-bit PNG)",
            type="pil",
            elem_classes="imgdownload",
        ),
    ],
    allow_flagging="never",
    # Each row is (input image, colored depth, 16-bit depth), resolved
    # relative to the directory containing this file.
    examples=[
        [os.path.join(os.path.dirname(__file__), rel_path) for rel_path in row]
        for row in (
            ("files/bee.jpg", "files/bee_vis.png", "files/bee_pred.png"),
            ("files/cat.jpg", "files/cat_vis.png", "files/cat_pred.png"),
            ("files/swings.jpg", "files/swings_vis.png", "files/swings_pred.png"),
        )
    ],
    css="""
.viewport {
aspect-ratio: 4/3;
}
.imgdownload {
height: 32px;
}
""",
    cache_examples=True,
)
# Script entry point: fetch the Marigold code, then serve the interface.
if __name__ == "__main__":
    download_code()
    queued_iface = iface.queue()
    # Listen on all interfaces, port 7860.
    queued_iface.launch(server_name="0.0.0.0", server_port=7860)