Spaces:
Sleeping
Sleeping
| """ | |
| app.py | |
| An interactive demo of text-guided shape generation. | |
| """ | |
| from pathlib import Path | |
| from typing import Literal | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| from salad.utils.spaghetti_util import ( | |
| get_mesh_from_spaghetti, | |
| generate_zc_from_sj_gaus, | |
| load_mesher, | |
| load_spaghetti, | |
| ) | |
| import hydra | |
| from omegaconf import OmegaConf | |
| import torch | |
| from pytorch_lightning import seed_everything | |
def load_model(
    model_class: Literal["phase1", "phase2", "lang_phase1", "lang_phase2"],
    device,
):
    """Build a SALAD model from its Hydra config and load its frozen weights.

    The config is read from ``checkpoints/<model_class>/hparams.yaml`` (relative
    to this file) and instantiated with Hydra. The weights in
    ``state_only.ckpt`` from the same directory are then loaded, mapped onto
    *device*.

    Args:
        model_class: Name of the checkpoint sub-directory to load.
        device: Target device for the loaded weights and the returned model.

    Returns:
        The model in eval mode, on *device*, with ``requires_grad`` switched
        off for every parameter.
    """
    model_dir = Path(__file__).parent / "checkpoints" / model_class
    config = OmegaConf.load(model_dir / "hparams.yaml")
    net = hydra.utils.instantiate(config)

    # NOTE: torch.load unpickles the file. That is acceptable for checkpoints
    # shipped with the app, but this must never be pointed at untrusted paths.
    state = torch.load(
        model_dir / "state_only.ckpt",
        map_location=device,
    )
    net.load_state_dict(state)

    # Inference only: disable dropout/batch-norm updates and gradient tracking.
    net.eval()
    for param in net.parameters():
        param.requires_grad_(False)
    return net.to(device)
# Seed applied right before every sampling run so a given prompt always
# produces the same shape.
_SEED = 63

# Loaded models keyed by device string, so requests do not reload every
# checkpoint from disk.
_PIPELINE_CACHE: dict = {}


def _load_pipeline(device: torch.device):
    """Return ``(spaghetti, mesher, phase1_model, phase2_model)`` for *device*.

    Everything is loaded on the first call for a device and reused afterwards.
    """
    key = str(device)
    if key not in _PIPELINE_CACHE:
        spaghetti = load_spaghetti(device)
        mesher = load_mesher(device)
        phase1_model = load_model("lang_phase1", device)
        # NOTE(review): phase 2 uses the "phase2" checkpoint rather than
        # "lang_phase2". Its `sample` call below receives only the phase-1
        # extrinsics (no text), so this looks intentional — confirm.
        phase2_model = load_model("phase2", device)
        # Called once before `sampling_gaussians`. Presumably it prepares
        # state that phase-1 sampling relies on — confirm that it need not be
        # re-run per request.
        phase1_model._build_dataset("val")
        _PIPELINE_CACHE[key] = (spaghetti, mesher, phase1_model, phase2_model)
    return _PIPELINE_CACHE[key]


def _mesh_figure(vertices, faces) -> go.Figure:
    """Build a Plotly figure that shows the mesh in gray with all axes hidden."""
    return go.Figure(
        data=[
            go.Mesh3d(
                # Plotly's default camera is z-up: draw the mesh's y axis
                # vertically, and negate its z axis to flip front-back.
                x=vertices[:, 0],
                y=-vertices[:, 2],
                z=vertices[:, 1],
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                color="gray",
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )


def run_inference(prompt: str) -> go.Figure:
    """Generate a shape from a text prompt and return it as a Plotly figure.

    This is the entry point called by the Gradio interface.

    Args:
        prompt: Text description of the shape, e.g. "an office chair".

    Returns:
        A Plotly figure containing the generated mesh.
    """
    # Fall back to CPU instead of failing outright on GPU-less hosts.
    # NOTE(review): the SPAGHETTI mesher may require CUDA — confirm CPU works.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    spaghetti, mesher, phase1_model, phase2_model = _load_pipeline(device)

    # Seed *after* loading. Model construction may consume RNG, and with
    # caching that happens only on the first request. Seeding earlier would
    # make the first and later requests for the same prompt diverge.
    seed_everything(_SEED)

    # Phase 1 samples extrinsics from the prompt; phase 2 samples intrinsics
    # conditioned on those extrinsics.
    extrinsics = phase1_model.sampling_gaussians([prompt])
    intrinsics = phase2_model.sample(extrinsics)

    # Decode the single sample of the batch into a mesh.
    zcs = generate_zc_from_sj_gaus(spaghetti, intrinsics, extrinsics)
    vertices, faces = get_mesh_from_spaghetti(
        spaghetti,
        mesher,
        zcs[0],
        res=256,
    )
    return _mesh_figure(vertices, faces)
if __name__ == "__main__":
    # Sample prompts offered under the input box.
    example_prompts = [
        "an office chair",
        "a chair with armrests",
        "a chair without armrests",
    ]

    # Web UI: a single text box in, a single Plotly figure out.
    demo = gr.Interface(
        fn=run_inference,
        inputs="text",
        outputs=gr.Plot(),
        title="SALAD: Text-Guided Shape Generation",
        description="Describe a chair",
        examples=example_prompts,
    )

    # Queue incoming requests (at most 30 waiting at once), then start serving.
    demo.queue(max_size=30)
    demo.launch()