"""SavtaDepth Gradio web app: RGB image in, grayscale depth map out.

Pulls the production model and data via DVC (when a ``.dvc`` directory is
present), builds a fastai U-Net learner, and serves a Gradio interface with
optional Hugging Face dataset flagging.
"""

import os
import shutil
import subprocess
from pathlib import Path

import torch
from fastai.vision.all import *
import gradio as gr

#######################
# Hugging Face flags  #
#######################

# Token is optional; without it (or without the saver class below) the
# flag button is hidden entirely.
HF_TOKEN = os.getenv("HF_TOKEN")

try:
    from gradio.flagging import HuggingFaceDatasetSaver  # type: ignore

    hf_writer: gr.FlaggingCallback | None = HuggingFaceDatasetSaver(
        repo_id="savtadepth-flags-V2", token=HF_TOKEN
    )
    allow_flagging: str | bool = "manual"
except (ImportError, AttributeError):
    # Saver unavailable in this gradio version: disable flagging.
    hf_writer = None
    allow_flagging = "never"  # hide flag button if callback unavailable

############
#   DVC    #
############

PROD_MODEL_PATH = "src/models"
TRAIN_PATH = "src/data/processed/train/bathroom"
TEST_PATH = "src/data/processed/test/bathroom"

if Path(".dvc").is_dir():
    print("Running DVC")
    # List-form subprocess.run avoids shell interpolation of the paths
    # (os.system built a shell string) and still surfaces the exit code.
    result = subprocess.run(["dvc", "pull", PROD_MODEL_PATH, TRAIN_PATH, TEST_PATH])
    if result.returncode != 0:
        raise SystemExit("dvc pull failed")
    # Remove DVC metadata so the app does not re-trigger pulls on restart;
    # shutil.rmtree is portable where `rm -rf` is not.
    shutil.rmtree(".dvc", ignore_errors=True)

#######################
#   Data & Learner    #
#######################


class ImageImageDataLoaders(DataLoaders):
    """Create DataLoaders for image→image tasks."""

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_label_func(
        cls,
        path: Path,
        filenames,
        label_func,
        valid_pct: float = 0.2,
        seed: int | None = None,
        item_transforms=None,
        batch_transforms=None,
        **kwargs,
    ):
        """Build DataLoaders mapping RGB inputs to single-channel targets.

        ``label_func`` maps each input filename to its target image path.
        ``valid_pct``/``seed`` control the random train/valid split.
        """
        dblock = DataBlock(
            blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
            get_y=label_func,
            splitter=RandomSplitter(valid_pct, seed=seed),
            item_tfms=item_transforms,
            batch_tfms=batch_transforms,
        )
        return cls.from_dblock(dblock, filenames, path=path, **kwargs)


def get_y_fn(x: Path) -> Path:
    """Map an RGB image path (``*.jpg``) to its paired depth map (``*_depth.png``)."""
    return Path(str(x).replace(".jpg", "_depth.png"))


def create_data(data_path: Path):
    """Create train/valid DataLoaders from the processed training images."""
    fnames = get_files(data_path / "train", extensions=".jpg")
    return ImageImageDataLoaders.from_label_func(
        data_path / "train",
        seed=42,
        bs=4,
        num_workers=0,
        filenames=fnames,
        label_func=get_y_fn,
    )


data = create_data(Path("src/data/processed"))
learner = unet_learner(
    data,
    resnet34,
    metrics=rmse,
    wd=1e-2,
    n_out=3,
    loss_func=MSELossFlat(),
    path="src/",
)
learner.load("model")

#####################
#  Inference Logic  #
#####################


def predict_depth(input_img: PILImage) -> PILImageBW:
    """Predict a grayscale ("L" mode) depth map for a single RGB image."""
    depth, *_ = learner.predict(input_img)
    return PILImageBW.create(depth).convert("L")


#####################
#     Gradio UI     #
#####################

title = "📷 SavtaDepth WebApp"
description_md = """
Upload an RGB image on the left and get a grayscale depth map on the right.
"""
# NOTE(review): the link markup appears to have been lost in this text
# ("DAGsHubGoogle" has no separator/anchors) — confirm against the repo.
footer_html = """
Project on DAGsHubGoogle Colab Demo
"""
examples = [["examples/00008.jpg"], ["examples/00045.jpg"]]

input_component = gr.Image(width=640, height=480, label="Input RGB")
output_component = gr.Image(label="Predicted Depth", image_mode="L")

# Pass flagging arguments only when the HF saver is actually available;
# some Gradio versions reject flagging_callback=None.
if hf_writer is not None:
    flag_kwargs = dict(
        allow_flagging=allow_flagging,
        flagging_options=["incorrect", "worst", "ambiguous"],
        flagging_callback=hf_writer,
    )
else:
    flag_kwargs = dict(allow_flagging=allow_flagging)

with gr.Blocks(title=title, theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {title}")
    gr.HTML(description_md)
    gr.Interface(
        fn=predict_depth,
        inputs=input_component,
        outputs=output_component,
        examples=examples,
        cache_examples=False,
        **flag_kwargs,
    )
    gr.HTML(footer_html)

if __name__ == "__main__":
    demo.queue().launch()