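"""Gradio demo for StableNormal: surface-normal estimation from a single image.

Built for Hugging Face Spaces (ZeroGPU); the model weights are pulled via
torch.hub on startup. Assumed runtime dependencies: torch, gradio, spaces,
gradio_imageslider, and Pillow.
"""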
from __future__ import annotations

import functools
import os
import tempfile
from pathlib import Path

import gradio as gr
import spaces
import torch
from gradio.utils import get_cache_folder
from gradio_imageslider import ImageSlider
from PIL import Image

# Default number of diffusion inference steps, exposed as "sharpness" in the UI
DEFAULT_SHARPNESS = 2


class Examples(gr.helpers.Examples):
    """gr.Examples subclass that caches computed outputs in a per-tab directory."""

    def __init__(self, *args, directory_name=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if directory_name is not None:
            self.cached_folder = get_cache_folder() / directory_name
            self.cached_file = Path(self.cached_folder) / "log.csv"
        self.create()


def load_predictor():
    """Load the StableNormal predictor from torch.hub."""
    return torch.hub.load("hugoycj/StableNormal", "StableNormal", trust_repo=True)
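

# Standalone usage sketch (argument names mirror the process_image call below):
#   predictor = load_predictor()
#   normal = predictor(Image.open("photo.png"), num_inference_steps=DEFAULT_SHARPNESS,
#                      match_input_resolution=False, data_type="object")
#   normal.save("photo_normal.png")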


def process_image(
    predictor,
    path_input: str,
    sharpness: int = DEFAULT_SHARPNESS,
    data_type: str = "object",
) -> tuple:
    """Estimate a normal map for one image; yields an (input, output path) pair."""
    if path_input is None:
        raise gr.Error("Please upload an image or select one from the gallery.")

    name_base = os.path.splitext(os.path.basename(path_input))[0]
    out_path = os.path.join(tempfile.mkdtemp(), f"{name_base}_normal.png")

    # Run the predictor; "sharpness" maps directly to diffusion inference steps
    input_image = Image.open(path_input)
    normal_image = predictor(
        input_image,
        num_inference_steps=sharpness,
        match_input_resolution=False,
        data_type=data_type,
    )
    normal_image.save(out_path)

    # Yielding (rather than returning) lets Gradio stream the pair to the ImageSlider
    yield [input_image, out_path]
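

def check_input(image, _sharpness) -> None:
    """Pre-flight check shared by all three submit buttons.

    Raising gr.Error (rather than returning it) is what actually aborts the
    event chain, so the chained .success() processing step is skipped.
    """
    if image is None:
        raise gr.Error("Please upload an image")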


def create_demo():
    # Load the model once and share it across all three tabs
    predictor = load_predictor()

    # Per-tab processing functions; spaces.GPU requests a ZeroGPU slot per call.
    # Note: the Human tab reuses the "object" data type (there is no dedicated
    # human mode in this demo).
    process_object = spaces.GPU(functools.partial(process_image, predictor, data_type="object"))
    process_scene = spaces.GPU(functools.partial(process_image, predictor, data_type="indoor"))
    process_human = spaces.GPU(functools.partial(process_image, predictor, data_type="object"))

    # Markdown header shown above the tabs
    HEADER_MD = """
# 🎪 StableNormal V2 beta: The Not-So-Stable but Sharp Edition! 🎢

### ✨ What's Cooking in Our Beta Kitchen? ✨
- **Zoom Zoom**: 2x faster - because waiting is boring!
- **Sharp as a Tack**: Better quality for those pixel-perfect folks
- **Your Way**: Tweak the sharpness slider and watch the magic happen
- **Pick Your Fighter**: Objects, Scenes, or Humans - we've got you covered!

### 🎯 Pro Tips
- Start with a lower sharpness for a quick, stable result
- Want more detail? Crank it up, but watch out for those floating bits!
- The sweet spot is usually around 2-3 for most images 😉
- If you get a flat result, try:
  * A different sharpness
  * Another image crop
  * Another mode

<p align="center">
  <a title="Website" href="https://stable-x.github.io/StableNormal/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
    <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
  </a>
  <a title="arXiv" href="https://arxiv.org/abs/2406.16864" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
    <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
  </a>
  <a title="Github" href="https://github.com/Stable-X/StableNormal" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
    <img src="https://img.shields.io/github/stars/Stable-X/StableNormal?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
  </a>
  <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
    <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
  </a>
</p>
"""

    # Build the interface
    demo = gr.Blocks(
        title="Stable Normal Estimation",
        css="""
            .slider .inner { width: 5px; background: #FFF; }
            .viewport { aspect-ratio: 4/3; }
            .tabs button.selected { font-size: 20px !important; color: crimson !important; }
            h1, h2, h3 { text-align: center; display: block; }
            .md_feedback li { margin-bottom: 0px !important; }
        """,
    )

    with demo:
        gr.Markdown(HEADER_MD)

        with gr.Tabs() as tabs:
            # Object Tab
            with gr.Tab("Object"):
                with gr.Row():
                    with gr.Column():
                        object_input = gr.Image(label="Input Object Image", type="filepath")
                        object_sharpness = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=DEFAULT_SHARPNESS,
                            step=1,
                            label="Sharpness (inference steps)",
                            info="Higher values produce sharper results but take longer",
                        )
                        with gr.Row():
                            object_submit_btn = gr.Button("Compute Normal", variant="primary")
                            object_reset_btn = gr.Button("Reset")
                    with gr.Column():
                        object_output_slider = ImageSlider(
                            label="Normal outputs",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )

                # Collect example images only if the folder exists
                # (os.listdir raises on a missing directory)
                object_dir = os.path.join("files", "object")
                Examples(
                    fn=process_object,
                    examples=sorted(
                        os.path.join(object_dir, name) for name in os.listdir(object_dir)
                    ) if os.path.isdir(object_dir) else [],
                    inputs=[object_input],
                    outputs=[object_output_slider],
                    cache_examples=True,
                    directory_name="examples_object",
                    examples_per_page=50,
                )

            # Scene Tab
            with gr.Tab("Scene"):
                with gr.Row():
                    with gr.Column():
                        scene_input = gr.Image(label="Input Scene Image", type="filepath")
                        scene_sharpness = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=DEFAULT_SHARPNESS,
                            step=1,
                            label="Sharpness (inference steps)",
                            info="Higher values produce sharper results but take longer",
                        )
                        with gr.Row():
                            scene_submit_btn = gr.Button("Compute Normal", variant="primary")
                            scene_reset_btn = gr.Button("Reset")
                    with gr.Column():
                        scene_output_slider = ImageSlider(
                            label="Normal outputs",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )

                scene_dir = os.path.join("files", "scene")
                Examples(
                    fn=process_scene,
                    examples=sorted(
                        os.path.join(scene_dir, name) for name in os.listdir(scene_dir)
                    ) if os.path.isdir(scene_dir) else [],
                    inputs=[scene_input],
                    outputs=[scene_output_slider],
                    cache_examples=True,
                    directory_name="examples_scene",
                    examples_per_page=50,
                )

            # Human Tab
            with gr.Tab("Human"):
                with gr.Row():
                    with gr.Column():
                        human_input = gr.Image(label="Input Human Image", type="filepath")
                        human_sharpness = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=DEFAULT_SHARPNESS,
                            step=1,
                            label="Sharpness (inference steps)",
                            info="Higher values produce sharper results but take longer",
                        )
                        with gr.Row():
                            human_submit_btn = gr.Button("Compute Normal", variant="primary")
                            human_reset_btn = gr.Button("Reset")
                    with gr.Column():
                        human_output_slider = ImageSlider(
                            label="Normal outputs",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )

                human_dir = os.path.join("files", "human")
                Examples(
                    fn=process_human,
                    examples=sorted(
                        os.path.join(human_dir, name) for name in os.listdir(human_dir)
                    ) if os.path.isdir(human_dir) else [],
                    inputs=[human_input],
                    outputs=[human_output_slider],
                    cache_examples=True,
                    directory_name="examples_human",
                    examples_per_page=50,
                )

        # Event Handlers for Object Tab
        object_submit_btn.click(
            fn=check_input,
            inputs=[object_input, object_sharpness],
            outputs=None,
            queue=False,
        ).success(
            fn=process_object,
            inputs=[object_input, object_sharpness],
            outputs=[object_output_slider],
        )
        object_reset_btn.click(
            fn=lambda: (None, DEFAULT_SHARPNESS, None),
            inputs=[],
            outputs=[object_input, object_sharpness, object_output_slider],
            queue=False,
        )

        # Event Handlers for Scene Tab
        scene_submit_btn.click(
            fn=check_input,
            inputs=[scene_input, scene_sharpness],
            outputs=None,
            queue=False,
        ).success(
            fn=process_scene,
            inputs=[scene_input, scene_sharpness],
            outputs=[scene_output_slider],
        )
        scene_reset_btn.click(
            fn=lambda: (None, DEFAULT_SHARPNESS, None),
            inputs=[],
            outputs=[scene_input, scene_sharpness, scene_output_slider],
            queue=False,
        )

        # Event Handlers for Human Tab
        human_submit_btn.click(
            fn=check_input,
            inputs=[human_input, human_sharpness],
            outputs=None,
            queue=False,
        ).success(
            fn=process_human,
            inputs=[human_input, human_sharpness],
            outputs=[human_output_slider],
        )
        human_reset_btn.click(
            fn=lambda: (None, DEFAULT_SHARPNESS, None),
            inputs=[],
            outputs=[human_input, human_sharpness, human_output_slider],
            queue=False,
        )

    return demo


def main():
    demo = create_demo()
    demo.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


if __name__ == "__main__":
    main()