# et-viewer / app.py
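# Gradio demo for "E.T. the Exceptional Trajectories": given a text prompt and a
# character position, a diffusion model (Diffuser) generates a camera trajectory
# for the animated character, and the result is shown in an embedded Rerun viewer.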
import spaces
from functools import partial
from typing import Any, Callable, Dict
import ast
import clip
import gradio as gr
from gradio_rerun import Rerun
import numpy as np
import trimesh
import rerun as rr
import torch
from utils.common_viz import init, get_batch
from utils.random_utils import set_random_seed
from utils.rerun import log_sample
from src.diffuser import Diffuser
from src.datasets.multimodal_dataset import MultimodalDataset
# ------------------------------------------------------------------------------------- #
batch_size, num_cams, num_verts = None, None, None
SAMPLE_IDS = [
    "2011_KAeAqaA0Llg_00005_00001",
    "2011_F_EuMeT2wBo_00014_00001",
    "2011_MCkKihQrNA4_00014_00000",
]
LABEL_TO_IDS = {
    "right": 0,
    "static": 1,
    "complex": 2,
}
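# Example prompts: either free-form sentences or the structured
# "Movement / Easing / Frames / Camera Angle / Shot Type" format below.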
EXAMPLES = [
    "While the character moves right, the camera trucks right.",
    "While the character moves right, the camera performs a push in.",
    "While the character moves right, the camera performs a pull out.",
    "Movement: shortArcShotRight Easing: easeInOutQuad Frames: 30 Camera Angle: birdsEyeView Shot Type: mediumShot Subject Index: 0",
    "Movement: fullZoomIn Easing: easeInOutSine Frames: 30 Camera Angle: highAngle Shot Type: closeUp",
    "Movement: pedestalDown Easing: easeOutExpo Frames: 30 Camera Angle: mediumAngle Shot Type: longShot",  # noqa
    "Movement: dollyIn Easing: easeOutBounce Frames: 30 Camera Angle: mediumAngle Shot Type: longShot",  # noqa
]
DEFAULT_TEXT = [
    "While the character moves right, the camera [...].",
    "Movement: dollyIn Easing: easeOutBounce Frames: 30 [...].",
    "Movement: shortArcShotRight Easing: easeInOutQuad [...].",
    "Movement: fullZoomIn Easing: easeInOutSine [...].",
]
HEADER = """
<div align="center">
<h1 style='text-align: center'>E.T. the Exceptional Trajectories</h1>
<a href="https://robincourant.github.io/info/"><strong>Robin Courant</strong></a>
<a href="https://nicolas-dufour.github.io/"><strong>Nicolas Dufour</strong></a>
<a href="https://triocrossing.github.io/"><strong>Xi Wang</strong></a>
<a href="http://people.irisa.fr/Marc.Christie/"><strong>Marc Christie</strong></a>
<a href="https://vicky.kalogeiton.info/"><strong>Vicky Kalogeiton</strong></a>
</div>
<div align="center">
<a href="https://www.lix.polytechnique.fr/vista/projects/2024_et_courant/" class="button"><b>[Webpage]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://github.com/robincourant/DIRECTOR" class="button"><b>[DIRECTOR]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://github.com/robincourant/CLaTr" class="button"><b>[CLaTr]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://github.com/robincourant/the-exceptional-trajectories" class="button"><b>[Data]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
</div>
<br/>
"""
# ------------------------------------------------------------------------------------- #
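# Per-frame vertex normals for the character mesh, computed on CPU with trimesh
# (one Trimesh object per animation frame).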
def get_normals(vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
    num_frames, num_faces = vertices.shape[0], faces.shape[-2]
    faces = faces.expand(num_frames, num_faces, 3)
    normals = [
        trimesh.Trimesh(vertices=v, faces=f, process=False).vertex_normals
        for v, f in zip(vertices, faces)
    ]
    normals = torch.from_numpy(np.stack(normals))
    return normals
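# Single generation step: encode the prompt with CLIP, run the diffusion sampler,
# extract the camera trajectory and character mesh from the output, and write a
# Rerun recording (.rrd) that the Rerun output component displays.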
@spaces.GPU
def generate(
    prompt: str,
    seed: int,
    guidance_weight: float,
    character_position: list,
    # ----------------------- #
    dataset: MultimodalDataset,
    device: torch.device,
    diffuser: Diffuser,
    clip_model: clip.model.CLIP,
) -> Dict[str, Any]:
    diffuser.to(device)
    clip_model.to(device)

    # Set arguments
    set_random_seed(seed)
    diffuser.gen_seeds = np.array([seed])
    diffuser.guidance_weight = guidance_weight

    # Inference
    sample_id = SAMPLE_IDS[0]  # Default to the first sample ID
    seq_feat = diffuser.net.model.clip_sequential
    batch = get_batch(prompt, sample_id, clip_model, dataset, seq_feat, device)

    # The character position arrives from the UI textbox as a string such as
    # "[0.0, 0.0, 0.0]"; parse it into a list before building the tensor.
    if isinstance(character_position, str):
        character_position = ast.literal_eval(character_position)
    batch["character_position"] = torch.tensor(character_position, device=device)
    with torch.no_grad():
        out = diffuser.predict_step(batch, 0)

    # Run visualization
    padding_mask = out["padding_mask"][0].to(bool).cpu()
    padded_traj = out["gen_samples"][0].cpu()
    traj = padded_traj[padding_mask]
    char_traj = out["char_feat"][0].cpu()
    padded_vertices = out["char_raw"]["char_vertices"][0]
    vertices = padded_vertices[padding_mask]
    faces = out["char_raw"]["char_faces"][0]
    normals = get_normals(vertices, faces)
    fx, fy, cx, cy = out["intrinsics"][0].cpu().numpy()
    K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
    caption = out["caption_raw"][0]

    rr.init(f"{sample_id}")
    rr.save(".tmp_gr.rrd")
    log_sample(
        root_name="world",
        traj=traj.numpy(),
        char_traj=char_traj.numpy(),
        K=K,
        vertices=vertices.numpy(),
        faces=faces.numpy(),
        normals=normals.numpy(),
        caption=caption,
        mesh_masks=None,
    )
    return "./.tmp_gr.rrd"
# ------------------------------------------------------------------------------------- #
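# Gradio UI: character position, text prompt, seed, and guidance weight feed the
# generation callback; the returned .rrd path is rendered by the embedded Rerun viewer.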
def launch_app(gen_fn: Callable):
    theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")

    with gr.Blocks(theme=theme) as demo:
        gr.Markdown(HEADER)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Column(scale=2):
                    char_position = gr.Textbox(
                        placeholder="Enter character position as [x, y, z]",
                        show_label=True,
                        label="Character Position (3D vector)",
                        value="[0.0, 0.0, 0.0]",
                    )
                    text = gr.Textbox(
                        placeholder="Type the camera motion you want to generate",
                        show_label=True,
                        label="Text prompt",
                        value=DEFAULT_TEXT[0],
                    )
                    seed = gr.Number(value=33, label="Seed")
                    guidance = gr.Slider(0, 10, value=1.4, label="Guidance", step=0.1)
                with gr.Column(scale=1):
                    btn = gr.Button("Generate", variant="primary")
            with gr.Column(scale=2):
                examples = gr.Examples(
                    examples=[[x, None, None] for x in EXAMPLES],
                    inputs=[text],
                )

        with gr.Row():
            output = Rerun()

        def load_example(example_id):
            processed_example = examples.non_none_processed_examples[example_id]
            return gr.utils.resolve_singleton(processed_example)

        inputs = [text, seed, guidance, char_position]
        examples.dataset.click(
            load_example,
            inputs=[examples.dataset],
            outputs=examples.inputs_with_examples,
            show_progress=False,
            postprocess=False,
            queue=False,
        ).then(fn=gen_fn, inputs=inputs, outputs=[output])
        btn.click(fn=gen_fn, inputs=inputs, outputs=[output])
        text.submit(fn=gen_fn, inputs=inputs, outputs=[output])

    demo.queue().launch(share=False)
# ------------------------------------------------------------------------------------- #
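# Load the diffuser, CLIP model, and dataset once at startup, then bind them to the
# callback with functools.partial so the UI only supplies prompt, seed, guidance,
# and character position.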
diffuser, clip_model, dataset, device = init("config")
generate_sample = partial(
    generate,
    dataset=dataset,
    device=device,
    diffuser=diffuser,
    clip_model=clip_model,
)
launch_app(generate_sample)