diff --git a/README.md b/README.md
index f8d425e235252183851c832664b0c503d8ef7e74..e8fbf66b67eebc5701d46b1f139bd267e04cb80e 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,226 @@
----
-title: Worldmem
-emoji: 🐢
-colorFrom: indigo
-colorTo: yellow
-sdk: gradio
-sdk_version: 5.23.3
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion
+
+#### [[Project Website]](https://boyuan.space/diffusion-forcing) [[Paper]](https://arxiv.org/abs/2407.01392)
+
+[Boyuan Chen1](https://boyuan.space/), [Diego Martí Monsó2](https://www.linkedin.com/in/diego-marti/?originalSubdomain=de), [ Yilun Du1](https://yilundu.github.io/), [Max Simchowitz1](https://msimchowitz.github.io/), [Russ Tedrake1](https://groups.csail.mit.edu/locomotion/russt.html), [Vincent Sitzmann1](https://www.vincentsitzmann.com/)
+1MIT 2Technical University of Munich
+
+This is the v1.5 code base for our paper [Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion](https://boyuan.space/diffusion-forcing). The **main** branch contains our latest reimplementation with temporal attention (recommended) while the **paper** branch contains RNN code used by original paper for reproduction purpose.
+
+Diffusion Forcing v2 is coming very soon! There is a stronger technique to achieve infinite, consistent video generation uniquely enabled by diffusion forcing. We are actively investigating that so please stay tuned. We will also release latent diffusion code by then that allows you to scale up to higher resolution / longer videos!
+
+
+
+```
+@misc{chen2024diffusionforcingnexttokenprediction,
+ title={Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion},
+ author={Boyuan Chen and Diego Marti Monso and Yilun Du and Max Simchowitz and Russ Tedrake and Vincent Sitzmann},
+ year={2024},
+ eprint={2407.01392},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG},
+ url={https://arxiv.org/abs/2407.01392},
+}
+```
+
+# Project Instructions
+
+## Setup
+
+If you want to use our latest improved implementation for video and planning with temporal attention instead of RNN, stay on this branch. If you are instead interested in reproducing claims by orignal paper, switch to the branch used by original paper via `git checkout paper`.
+
+Run `conda create python=3.10 -n diffusion-forcing` to create environment.
+Run `conda activate diffusion-forcing` to activate this environment.
+
+Install dependencies for time series, video and robotics:
+
+```
+pip install -r requirements.txt
+```
+
+[Sign up](https://wandb.ai/site) a wandb account for cloud logging and checkpointing. In command line, run `wandb login` to login.
+
+Then modify the wandb entity in `configurations/config.yaml` to your wandb account.
+
+Optionally, if you want to do maze planning, install the following complicated dependencies due to outdated dependencies of d4rl. This involves first installing mujoco 210 and then run
+
+```
+pip install -r extra_requirements.txt
+```
+
+## Quick start with pretrained checkpoints
+
+Since dataset is huge, we provide a mini subset and pre-trained checkpoints for you to quickly test out our model! To do so, download mini dataset and checkpoints from [here](https://drive.google.com/file/d/1xAOQxWcLzcFyD4zc0_rC9jGXe_uaHb7b/view?usp=sharing) to project root and extract with `tar -xzvf quickstart_atten.tar.gz`. Files shall appear in `data` and `outputs/xxx.ckpt`. Make sure you also git pull upstream to use latest version of code if you forked before ckpt release!
+
+Then run the following commands and go to the wandb panel to see the results.
+
+### Video Prediction:
+
+Our visualization is side by side, with prediction on the left and ground truth on the right. However, ground truth is expected to not align with prediction since the sequence is highly stochastic. Ground truth is provided to provide an idea about quality only.
+
+Autoregressively generate minecraft video with 1x the length it's trained on:
+`python -m main +name=sample_minecraft_pretrained load=outputs/minecraft.ckpt experiment.tasks=[validation]`
+
+To let the model roll out **longer than it's trained on**, simply append `dataset.validation_multiplier=8` to the above commands, and it will rollout `8x` longer than maximum sequence length it's trained on.
+
+The above checkpoint is trained for 100K steps with small number of frames. We've already verified diffusion forcing works in latent diffusion setting and can be extended to many more tokens without sacrificing compositionally (with some addition techniques outside this repo)! Stay tuned for our next project!
+
+### Maze Planning:
+
+The maze planning setting is changed a bit as we gain more insighs, please see corresponding paragraphs in training section for details. We haven't reimplemented MCTG yet, but you can already see nice visualizations on wandb log.
+
+Medium Maze
+
+`python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_medium dataset.action_mean=[] dataset.action_std=[] dataset.observation_mean=[3.5092521,3.4765592] dataset.observation_std=[1.3371079,1.52102] load=outputs/maze2d_medium_x.ckpt experiment.tasks=[validation] algorithm.guidance_scale=3 +name=maze2d_medium_x_sampling`
+
+Large Maze
+
+`python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_large dataset.observation_mean=[3.7296331,5.3047247] dataset.observation_std=[1.8070312,2.5687592] dataset.action_mean=[] dataset.action_std=[] load=outputs/maze2d_large_x.ckpt experiment.tasks=[validation] algorithm.guidance_scale=2 +name=maze2d_large_x_sampling`
+
+We also explored a couple more settings but haven't reimplemented everything in original paper yet. If you are interestted in those checkpoints, see the source code of this README file for ckpt loading instructions that's commented out.
+
+
+
+## Training
+
+### Video
+
+Video prediction requires downloading giant datasets. First, if you downloaded the mini subset following `Quick start with pretrained checkpoints` section, delete the mini subset folders `data/minecraft` and `data/dmlab` because we have to download the whole dataset this time. We've coded in python that it will download the dataset for you it doesn't already exist. Due to the slowness of the [source](https://github.com/wilson1yan/teco), this may take a couple days. If you prefer to do it yourself via bash script, please refer to the bash scripts in original [TECO dataset](https://github.com/wilson1yan/teco) and use `dmlab.sh` and `minecraft.sh` in their Dataset section of README, any maybe split bash script into parallel scripts.
+
+Then just run the corresponding commands:
+
+#### Minecraft
+
+`python -m main +name=your_experiment_name algorithm=df_video dataset=video_minecraft`
+
+#### DMLab
+
+`python -m main +name=your_experiment_name algorithm=df_video dataset=video_dmlab algorithm.weight_decay=1e-3 algorithm.diffusion.architecture.network_size=48 algorithm.diffusion.architecture.attn_dim_head=32 algorithm.diffusion.architecture.attn_resolutions=[8,16,32,64] algorithm.diffusion.beta_schedule=cosine`
+
+#### No causal masking
+
+Simply append `algorithm.causal=False` to your command.
+
+#### Play with sampling
+
+Please take a look at "Load a checkpoint to eval" paragraph to understand how to use load checkpoint with `load=`. Then, run the exact training command with `experiment.tasks=[validation] load={wandb_run_id}` to load a checkpoint and experiment with sampling.
+
+To see how you can roll out longer than the sequence is trained on, you can find instructions in `quick start with pretrained checkpoints` section. Keep in mind that rolling out infinitely without sliding window is a property of original RNN implementation on `paper` branch, and this version has to use sliding window since it's temporal attention.
+
+By default, we run autoregressive sampling with stablization. To sample next 2 tokens jointly, you can append the following to the above command: `algorithm.scheduling_matrix=full_sequence algorithm.chunk_size=2`.
+
+## Maze Planning
+
+For those who only wish to reproduce the original paper instead of transformer architecture, please checkout`paper` branch of the code instead.
+
+**Medium Maze**
+
+`python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_medium dataset.action_mean=[] dataset.action_std=[] dataset.observation_mean=[3.5092521,3.4765592] dataset.observation_std=[1.3371079,1.52102] +name=maze2d_medium_x`
+
+**Large Maze**
+
+`python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_large dataset.observation_mean=[3.7296331,5.3047247] dataset.observation_std=[1.8070312,2.5687592] dataset.action_mean=[] dataset.action_std=[] +name=maze2d_large_x`
+
+**Run planning after model is trained**
+
+Please take a look at "Load a checkpoint to eval" paragraph to understand how to use load checkpoint with `load=`. To sample, simply append `load={wandb_id_of_above_runs} experiment.tasks=[validation] algorithm.guidance_scale=2 +name=maze2d_sampling` to above command after trained. Feel free to tune the `guidance_scale` from 1 - 5.
+
+This version of maze planning uses a different version of diffusion forcing from original paper - while doing the follow up to diffusion forcing, we realized that training with independent noise actually constructed a smooth interpolation between causal and non-causal models too, since we can just masked out future by complete noise (fully causal) or some noise (interpolation). The best thing is, you can still account for causal uncertainty via pyramoid sampling in this setting, by masking out tokens at different noise levels, and you can still have flexible horizon because you can tell the model that padded entries are pure noise, a unique ability of diffusion forcing.
+
+We also reflected a bit about the environment and concluded that the original metric isn't necessarily a good metric, because maze planning should reward those who can plan the fastest route to goal, not a slow walking agent that goes there at the end of episode. The dataset never contains data of staying at the goal, so agents are supposed to walk away after reaching the goal. I think [Diffuser](https://arxiv.org/abs/2205.09991) had an unfair advantage of just generating slow plans, that happend to let the agent stay in the neighbour hood of goal for longer and got very high reward, exploiting flaws in the environment design (a good design would involve penalty of longer time taken to reach goal). So, in this version of code, we just optimize for flexible horizon planning that tries to reach goal asap, and the planner will automatically come back to goal if it left the goal since staying is never in dataset. You can see new metrics we designed in wandb logging interface.
+
+## Timeseries and Robotics
+
+Please checkout `paper` branch for the code used by original paper. If I have time later, I will reimplement these two domains with transformer as well to complete this branch.
+
+# Change Log
+
+| Data | Notes |
+| --------- | :---------------------------------------------------------------------------------------------: |
+| Jul/30/24 | Upgrade RNN to temporal attention, move orignal code to 'paper' branch |
+| Jul/03/24 | Initial release of the code. Email me if you have questions or find any errors in this version. |
+
+# Infra instructions
+
+This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
+
+All experiments can be launched via `python -m main +name=xxxx {options}` where you can fine more details later in this article.
+
+The code base will automatically use cuda or your Macbook M1 GPU when available.
+
+For slurm clusters e.g. mit supercloud, you can run `python -m main cluster=mit_supercloud {options}` on login node.
+It will automatically generate slurm scripts and run them for you on a compute node. Even if compute nodes are offline,
+the script will still automatically sync wandb logging to cloud with <1min latency. It's also easy to add your own slurm
+by following the `Add slurm clusters` section.
+
+## Modify for your own project
+
+First, create a new repository with this template. Make sure the new repository has the name you want to use for wandb
+logging.
+
+Add your method and baselines in `algorithms` following the `algorithms/README.md` as well as the example code in
+`algorithms/diffusion_forcing/df_video.py`. For pytorch experiments, write your algorithm as a [pytorch lightning](https://github.com/Lightning-AI/lightning)
+`pl.LightningModule` which has extensive
+[documentation](https://lightning.ai/docs/pytorch/stable/). For a quick start, read "Define a LightningModule" in this [link](https://lightning.ai/docs/pytorch/stable/starter/introduction.html). Finally, add a yaml config file to `configurations/algorithm` imitating that of `configurations/algorithm/df_video.yaml`, for each algorithm you added.
+
+Add your dataset in `datasets` following the `datasets/README.md` as well as the example code in
+`datasets/video`. Finally, add a yaml config file to `configurations/dataset` imitating that of
+`configurations/dataset/video_dmlab.yaml`, for each dataset you added.
+
+Add your experiment in `experiments` following the `experiments/README.md` or following the example code in
+`experiments/exp_video.py`. Then register your experiment in `experiments/__init__.py`.
+Finally, add a yaml config file to `configurations/experiment` imitating that of
+`configurations/experiment/exp_video.yaml`, for each experiment you added.
+
+Modify `configurations/config.yaml` to set `algorithm` to the yaml file you want to use in `configurations/algorithm`;
+set `experiment` to the yaml file you want to use in `configurations/experiment`; set `dataset` to the yaml file you
+want to use in `configurations/dataset`, or to `null` if no dataset is needed; Notice the fields should not contain the
+`.yaml` suffix.
+
+You are all set!
+
+`cd` into your project root. Now you can launch your new experiment with `python main.py +name=`. You can run baselines or
+different datasets by add arguments like `algorithm=xxx` or `dataset=xxx`. You can also override any `yaml` configurations by following the next section.
+
+One special note, if your want to define a new task for your experiment, (e.g. other than `training` and `test`) you can define it as a method in your experiment class and use `experiment.tasks=[task_name]` to run it. Let's say you have a `generate_dataset` task before the task `training` and you implemented it in experiment class, you can then run `python -m main +name xxxx experiment.tasks=[generate_dataset,training]` to execute it before training.
+
+## Pass in arguments
+
+We use [hydra](https://hydra.cc) instead of `argparse` to configure arguments at every code level. You can both write a static config in `configuration` folder or, at runtime,
+[override part of yur static config](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/) with command line arguments.
+
+For example, arguments `algorithm=example_classifier experiment.lr=1e-3` will override the `lr` variable in `configurations/experiment/example_classifier.yaml`. The argument `wandb.mode` will override the `mode` under `wandb` namesspace in the file `configurations/config.yaml`.
+
+All static config and runtime override will be logged to cloud automatically.
+
+## Resume a checkpoint & logging
+
+For machine learning experiments, all checkpoints and logs are logged to cloud automatically so you can resume them on another server. Simply append `resume={wandb_run_id}` to your command line arguments to resume it. The run_id can be founded in a url of a wandb run in wandb dashboard. By default, latest checkpoint in a run is stored indefinitely and earlier checkpoints in the run will be deleted after 5 days to save your storage.
+
+On the other hand, sometimes you may want to start a new run with different run id but still load a prior ckpt. This can be done by setting the `load={wandb_run_id / ckpt path}` flag.
+
+## Load a checkpoint to eval
+
+The argument `experiment.tasks=[task_name1,task_name2]` (note the `[]` brackets here needed) allows to select a sequence of tasks to execute, such as `training`, `validation` and `test`. Therefore, for testing a machine learning ckpt, you may run `python -m main load={your_wandb_run_id} experiment.tasks=[test]`.
+
+More generally, the task names are the corresponding method names of your experiment class. For `BaseLightningExperiment`, we already defined three methods `training`, `validation` and `test` for you, but you can also define your own tasks by creating methods to your experiment class under intended task names.
+
+## Debug
+
+We provide a useful debug flag which you can enable by `python main.py debug=True`. This will enable numerical error tracking as well as setting `cfg.debug` to `True` for your experiments, algorithms and datasets class. However, this debug flag will make ML code very slow as it automatically tracks all parameter / gradients!
+
+## Add slurm clusters
+
+It's very easy to add your own slurm clusters via adding a yaml file in `configurations/cluster`. You can take a look
+at `configurations/cluster/mit_vision.yaml` for example.
diff --git a/__pycache__/app.cpython-310.pyc b/__pycache__/app.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05f92083da5270e34b53c7e0d5c7a51e04ec8114
Binary files /dev/null and b/__pycache__/app.cpython-310.pyc differ
diff --git a/algorithms/README.md b/algorithms/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dad42ab7d3ee7065f3e2d4d89ff3c1287dd122d3
--- /dev/null
+++ b/algorithms/README.md
@@ -0,0 +1,21 @@
+# algorithms
+
+`algorithms` folder is designed to contain implementation of algorithms or models.
+Content in `algorithms` can be loosely grouped components (e.g. models) or an algorithm has already has all
+components chained together (e.g. Lightning Module, RL algo).
+You should create a folder name after your own algorithm or baselines in it.
+
+Two example can be found in `examples` subfolder.
+
+The `common` subfolder is designed to contain general purpose classes that's useful for many projects, e.g MLP.
+
+You should not run any `.py` file from algorithms folder.
+Instead, you write unit tests / debug python files in `debug` and launch script in `experiments`.
+
+You are discouraged from putting visualization utilities in algorithms, as those should go to `utils` in project root.
+
+Each algorithm class takes in a DictConfig file `cfg` in its `__init__`, which allows you to pass in arguments via configuration file in `configurations/algorithm` or [command line override](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/).
+
+---
+
+This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
diff --git a/algorithms/__init__.py b/algorithms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/algorithms/__pycache__/__init__.cpython-310.pyc b/algorithms/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54139dc2cb06c4686faa42d7c6f995cd1fd331d3
Binary files /dev/null and b/algorithms/__pycache__/__init__.cpython-310.pyc differ
diff --git a/algorithms/common/README.md b/algorithms/common/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b689b4ae4c42569b9eaca4540a413ca500ce63ad
--- /dev/null
+++ b/algorithms/common/README.md
@@ -0,0 +1,5 @@
+THis folder contains models / algorithms that are considered general for many algorithms.
+
+---
+
+This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
diff --git a/algorithms/common/__init__.py b/algorithms/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/algorithms/common/__pycache__/__init__.cpython-310.pyc b/algorithms/common/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5f90e1aae68f2b0efd33d28419e5160be6d3734
Binary files /dev/null and b/algorithms/common/__pycache__/__init__.cpython-310.pyc differ
diff --git a/algorithms/common/__pycache__/base_pytorch_algo.cpython-310.pyc b/algorithms/common/__pycache__/base_pytorch_algo.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..00dad5ace62164daeffb744867557c33bae0e10f
Binary files /dev/null and b/algorithms/common/__pycache__/base_pytorch_algo.cpython-310.pyc differ
diff --git a/algorithms/common/base_algo.py b/algorithms/common/base_algo.py
new file mode 100644
index 0000000000000000000000000000000000000000..753c7b43bd9b080b5d9eac4209e5ccdf519af24f
--- /dev/null
+++ b/algorithms/common/base_algo.py
@@ -0,0 +1,22 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from omegaconf import DictConfig
+
+
+class BaseAlgo(ABC):
+ """
+ A base class for generic algorithms.
+ """
+
+ def __init__(self, cfg: DictConfig):
+ super().__init__()
+ self.cfg = cfg
+ self.debug = self.cfg.debug
+
+ @abstractmethod
+ def run(*args: Any, **kwargs: Any) -> Any:
+ """
+ Run the algorithm.
+ """
+ raise NotImplementedError
diff --git a/algorithms/common/base_pytorch_algo.py b/algorithms/common/base_pytorch_algo.py
new file mode 100644
index 0000000000000000000000000000000000000000..d058b97c25654fd1bc1bd763e3cc08434a38963c
--- /dev/null
+++ b/algorithms/common/base_pytorch_algo.py
@@ -0,0 +1,253 @@
+from abc import ABC, abstractmethod
+import warnings
+from typing import Any, Union, Sequence, Optional
+
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from omegaconf import DictConfig
+import lightning.pytorch as pl
+import torch
+import numpy as np
+from PIL import Image
+import wandb
+import einops
+
+
+class BasePytorchAlgo(pl.LightningModule, ABC):
+ """
+ A base class for Pytorch algorithms using Pytorch Lightning.
+ See https://lightning.ai/docs/pytorch/stable/starter/introduction.html for more details.
+ """
+
+ def __init__(self, cfg: DictConfig):
+ super().__init__()
+ self.cfg = cfg
+ self.debug = self.cfg.debug
+ self._build_model()
+
+ @abstractmethod
+ def _build_model(self):
+ """
+ Create all pytorch nn.Modules here.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+ r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
+ logger.
+
+ Args:
+ batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
+ batch_idx: The index of this batch.
+ dataloader_idx: (only if multiple dataloaders used) The index of the dataloader that produced this batch.
+
+ Return:
+ Any of these options:
+ - :class:`~torch.Tensor` - The loss tensor
+ - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
+ - ``None`` - Skip to the next batch. This is only supported for automatic optimization.
+ This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
+
+ In this step you'd normally do the forward pass and calculate the loss for a batch.
+ You can also do fancier things like multiple forward passes or something model specific.
+
+ Example::
+
+ def training_step(self, batch, batch_idx):
+ x, y, z = batch
+ out = self.encoder(x)
+ loss = self.loss(out, x)
+ return loss
+
+ To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
+
+ .. code-block:: python
+
+ def __init__(self):
+ super().__init__()
+ self.automatic_optimization = False
+
+
+ # Multiple optimizers (e.g.: GANs)
+ def training_step(self, batch, batch_idx):
+ opt1, opt2 = self.optimizers()
+
+ # do training_step with encoder
+ ...
+ opt1.step()
+ # do training_step with decoder
+ ...
+ opt2.step()
+
+ Note:
+ When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
+ normalized by ``accumulate_grad_batches`` internally.
+
+ """
+ return super().training_step(*args, **kwargs)
+
+ def configure_optimizers(self):
+ """
+ Return an optimizer. If you need to use more than one optimizer, refer to pytorch lightning documentation:
+ https://lightning.ai/docs/pytorch/stable/common/optimization.html
+ """
+ parameters = self.parameters()
+ return torch.optim.Adam(parameters, lr=self.cfg.lr)
+
+ def log_video(
+ self,
+ key: str,
+ video: Union[np.ndarray, torch.Tensor],
+ mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
+ std: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
+ fps: int = 5,
+ format: str = "mp4",
+ ):
+ """
+ Log video to wandb. WandbLogger in pytorch lightning does not support video logging yet, so we call wandb directly.
+
+ Args:
+ video: a numpy array or tensor, either in form (time, channel, height, width) or in the form
+ (batch, time, channel, height, width). The content must be be in 0-255 if under dtype uint8
+ or [0, 1] otherwise.
+ mean: optional, the mean to unnormalize video tensor, assuming unnormalized data is in [0, 1].
+ std: optional, the std to unnormalize video tensor, assuming unnormalized data is in [0, 1].
+ key: the name of the video.
+ fps: the frame rate of the video.
+ format: the format of the video. Can be either "mp4" or "gif".
+ """
+
+ if isinstance(video, torch.Tensor):
+ video = video.detach().cpu().numpy()
+
+ expand_shape = [1] * (len(video.shape) - 2) + [3, 1, 1]
+ if std is not None:
+ if isinstance(std, (float, int)):
+ std = [std] * 3
+ if isinstance(std, torch.Tensor):
+ std = std.detach().cpu().numpy()
+ std = np.array(std).reshape(*expand_shape)
+ video = video * std
+ if mean is not None:
+ if isinstance(mean, (float, int)):
+ mean = [mean] * 3
+ if isinstance(mean, torch.Tensor):
+ mean = mean.detach().cpu().numpy()
+ mean = np.array(mean).reshape(*expand_shape)
+ video = video + mean
+
+ if video.dtype != np.uint8:
+ video = np.clip(video, a_min=0, a_max=1) * 255
+ video = video.astype(np.uint8)
+
+ self.logger.experiment.log(
+ {
+ key: wandb.Video(video, fps=fps, format=format),
+ },
+ step=self.global_step,
+ )
+
+ def log_image(
+ self,
+ key: str,
+ image: Union[np.ndarray, torch.Tensor, Image.Image, Sequence[Image.Image]],
+ mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
+ std: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
+ **kwargs: Any,
+ ):
+ """
+ Log image(s) using WandbLogger.
+ Args:
+ key: the name of the video.
+ image: a single image or a batch of images. If a batch of images, the shape should be (batch, channel, height, width).
+ mean: optional, the mean to unnormalize image tensor, assuming unnormalized data is in [0, 1].
+ std: optional, the std to unnormalize tensor, assuming unnormalized data is in [0, 1].
+ kwargs: optional, WandbLogger log_image kwargs, such as captions=xxx.
+ """
+ if isinstance(image, Image.Image):
+ image = [image]
+ elif len(image) and not isinstance(image[0], Image.Image):
+ if isinstance(image, torch.Tensor):
+ image = image.detach().cpu().numpy()
+
+ if len(image.shape) == 3:
+ image = image[None]
+
+ if image.shape[1] == 3:
+ if image.shape[-1] == 3:
+ warnings.warn(f"Two channels in shape {image.shape} have size 3, assuming channel first.")
+ image = einops.rearrange(image, "b c h w -> b h w c")
+
+ if std is not None:
+ if isinstance(std, (float, int)):
+ std = [std] * 3
+ if isinstance(std, torch.Tensor):
+ std = std.detach().cpu().numpy()
+ std = np.array(std)[None, None, None]
+ image = image * std
+ if mean is not None:
+ if isinstance(mean, (float, int)):
+ mean = [mean] * 3
+ if isinstance(mean, torch.Tensor):
+ mean = mean.detach().cpu().numpy()
+ mean = np.array(mean)[None, None, None]
+ image = image + mean
+
+ if image.dtype != np.uint8:
+ image = np.clip(image, a_min=0.0, a_max=1.0) * 255
+ image = image.astype(np.uint8)
+ image = [img for img in image]
+
+ self.logger.log_image(key=key, images=image, **kwargs)
+
+ def log_gradient_stats(self):
+ """Log gradient statistics such as the mean or std of norm."""
+
+ with torch.no_grad():
+ grad_norms = []
+ gpr = [] # gradient-to-parameter ratio
+ for param in self.parameters():
+ if param.grad is not None:
+ grad_norms.append(torch.norm(param.grad).item())
+ gpr.append(torch.norm(param.grad) / torch.norm(param))
+ if len(grad_norms) == 0:
+ return
+ grad_norms = torch.tensor(grad_norms)
+ gpr = torch.tensor(gpr)
+ self.log_dict(
+ {
+ "train/grad_norm/min": grad_norms.min(),
+ "train/grad_norm/max": grad_norms.max(),
+ "train/grad_norm/std": grad_norms.std(),
+ "train/grad_norm/mean": grad_norms.mean(),
+ "train/grad_norm/median": torch.median(grad_norms),
+ "train/gpr/min": gpr.min(),
+ "train/gpr/max": gpr.max(),
+ "train/gpr/std": gpr.std(),
+ "train/gpr/mean": gpr.mean(),
+ "train/gpr/median": torch.median(gpr),
+ }
+ )
+
+ def register_data_mean_std(
+ self, mean: Union[str, float, Sequence], std: Union[str, float, Sequence], namespace: str = "data"
+ ):
+ """
+ Register mean and std of data as tensor buffer.
+
+ Args:
+ mean: the mean of data.
+ std: the std of data.
+ namespace: the namespace of the registered buffer.
+ """
+ for k, v in [("mean", mean), ("std", std)]:
+ if isinstance(v, str):
+ if v.endswith(".npy"):
+ v = torch.from_numpy(np.load(v))
+ elif v.endswith(".pt"):
+ v = torch.load(v)
+ else:
+ raise ValueError(f"Unsupported file type {v.split('.')[-1]}.")
+ else:
+ v = torch.tensor(v)
+ self.register_buffer(f"{namespace}_{k}", v.float().to(self.device))
diff --git a/algorithms/common/metrics/__init__.py b/algorithms/common/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..61c0d28943dd29d7aeb4b1121939b573f7989e9b
--- /dev/null
+++ b/algorithms/common/metrics/__init__.py
@@ -0,0 +1,3 @@
+from .fid import FrechetInceptionDistance
+from .lpips import LearnedPerceptualImagePatchSimilarity
+from .fvd import FrechetVideoDistance
diff --git a/algorithms/common/metrics/__pycache__/__init__.cpython-310.pyc b/algorithms/common/metrics/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d65237655955a23f5c74c789547b29e46a76ce0c
Binary files /dev/null and b/algorithms/common/metrics/__pycache__/__init__.cpython-310.pyc differ
diff --git a/algorithms/common/metrics/__pycache__/fid.cpython-310.pyc b/algorithms/common/metrics/__pycache__/fid.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0f737b27768ffa7dafd9342aa6d095e00bd7894
Binary files /dev/null and b/algorithms/common/metrics/__pycache__/fid.cpython-310.pyc differ
diff --git a/algorithms/common/metrics/__pycache__/fvd.cpython-310.pyc b/algorithms/common/metrics/__pycache__/fvd.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19d17eff832cb05b6aa75eb3b21a44adfe202290
Binary files /dev/null and b/algorithms/common/metrics/__pycache__/fvd.cpython-310.pyc differ
diff --git a/algorithms/common/metrics/__pycache__/lpips.cpython-310.pyc b/algorithms/common/metrics/__pycache__/lpips.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5bdc2627e9d37294013bf0b21c75e839b5fdf0f
Binary files /dev/null and b/algorithms/common/metrics/__pycache__/lpips.cpython-310.pyc differ
diff --git a/algorithms/common/metrics/fid.py b/algorithms/common/metrics/fid.py
new file mode 100644
index 0000000000000000000000000000000000000000..428a621a58807767650101026576335090d10fc0
--- /dev/null
+++ b/algorithms/common/metrics/fid.py
@@ -0,0 +1 @@
+from torchmetrics.image.fid import FrechetInceptionDistance
diff --git a/algorithms/common/metrics/fvd.py b/algorithms/common/metrics/fvd.py
new file mode 100644
index 0000000000000000000000000000000000000000..a502055eff0b19ab8724d1d5cbee38ab85a8ee7c
--- /dev/null
+++ b/algorithms/common/metrics/fvd.py
@@ -0,0 +1,158 @@
+"""
+Adopted from https://github.com/cvpr2022-stylegan-v/stylegan-v
+Verified to be the same as tf version by https://github.com/universome/fvd-comparison
+"""
+
+import io
+import re
+import requests
+import html
+import hashlib
+import urllib
+import urllib.request
+from typing import Any, List, Tuple, Union, Dict
+import scipy
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def open_url(
+ url: str,
+ num_attempts: int = 10,
+ verbose: bool = True,
+ return_filename: bool = False,
+) -> Any:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+
+ # Doesn't look like an URL scheme so interpret it as a local filename.
+ if not re.match("^[a-z]+://", url):
+ return url if return_filename else open(url, "rb")
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+ # Some internet resources suggest using urllib.request.url2pathname() but
+ # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith("file://"):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r"^/[a-zA-Z]:", filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb")
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
+ links = [
+ html.unescape(link)
+ for link in content_str.split('"')
+ if "export=download" in link
+ ]
+ if len(links) == 1:
+ url = requests.compat.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError(
+ "Google Drive download quota exceeded -- please try again later"
+ )
+
+ match = re.search(
+ r'filename="([^"]*)"',
+ res.headers.get("Content-Disposition", ""),
+ )
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
+
+
+def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
+ mu_gen, sigma_gen = compute_stats(feats_fake)
+ mu_real, sigma_real = compute_stats(feats_real)
+
+ m = np.square(mu_gen - mu_real).sum()
+ s, _ = scipy.linalg.sqrtm(
+ np.dot(sigma_gen, sigma_real), disp=False
+ ) # pylint: disable=no-member
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
+
+ return float(fid)
+
+
+def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ mu = feats.mean(axis=0) # [d]
+ sigma = np.cov(feats, rowvar=False) # [d, d]
+
+ return mu, sigma
+
+
+class FrechetVideoDistance(nn.Module):
+ def __init__(self):
+ super().__init__()
+ detector_url = (
+ "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1"
+ )
+ # Return raw features before the softmax layer.
+ self.detector_kwargs = dict(rescale=False, resize=True, return_features=True)
+ with open_url(detector_url, verbose=False) as f:
+ self.detector = torch.jit.load(f).eval()
+
+ @torch.no_grad()
+ def compute(self, videos_fake: torch.Tensor, videos_real: torch.Tensor):
+ """
+ :param videos_fake: predicted video tensor of shape (frame, batch, channel, height, width)
+ :param videos_real: ground-truth observation tensor of shape (frame, batch, channel, height, width)
+ :return:
+ """
+ n_frames, batch_size, c, h, w = videos_fake.shape
+ if n_frames < 2:
+ raise ValueError("Video must have more than 1 frame for FVD")
+
+ videos_fake = videos_fake.permute(1, 2, 0, 3, 4).contiguous()
+ videos_real = videos_real.permute(1, 2, 0, 3, 4).contiguous()
+
+ # detector takes in tensors of shape [batch_size, c, video_len, h, w] with range -1 to 1
+ feats_fake = self.detector(videos_fake, **self.detector_kwargs).cpu().numpy()
+ feats_real = self.detector(videos_real, **self.detector_kwargs).cpu().numpy()
+
+ return compute_fvd(feats_fake, feats_real)
diff --git a/algorithms/common/metrics/lpips.py b/algorithms/common/metrics/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..34fc01b7c3375c3efd2b4b3929866104471022eb
--- /dev/null
+++ b/algorithms/common/metrics/lpips.py
@@ -0,0 +1 @@
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
diff --git a/algorithms/common/models/__init__.py b/algorithms/common/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/algorithms/common/models/cnn.py b/algorithms/common/models/cnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1720814f03c70bab96faf4f2382ebb39b16bf83
--- /dev/null
+++ b/algorithms/common/models/cnn.py
@@ -0,0 +1,141 @@
+import math
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+def is_square_of_two(num):
+ if num <= 0:
+ return False
+ return num & (num - 1) == 0
+
+class CnnEncoder(nn.Module):
+ """
+ Simple cnn encoder that encodes a 64x64 image to embeddings
+ """
+ def __init__(self, embedding_size, activation_function='relu'):
+ super().__init__()
+ self.act_fn = getattr(F, activation_function)
+ self.embedding_size = embedding_size
+ self.fc = nn.Linear(1024, self.embedding_size)
+ self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
+ self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
+ self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
+ self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
+ self.modules = [self.conv1, self.conv2, self.conv3, self.conv4]
+
+ def forward(self, observation):
+ batch_size = observation.shape[0]
+ hidden = self.act_fn(self.conv1(observation))
+ hidden = self.act_fn(self.conv2(hidden))
+ hidden = self.act_fn(self.conv3(hidden))
+ hidden = self.act_fn(self.conv4(hidden))
+ hidden = self.fc(hidden.view(batch_size, 1024))
+ return hidden
+
+
+class CnnDecoder(nn.Module):
+ """
+ Simple Cnn decoder that decodes an embedding to 64x64 images
+ """
+ def __init__(self, embedding_size, activation_function='relu'):
+ super().__init__()
+ self.act_fn = getattr(F, activation_function)
+ self.embedding_size = embedding_size
+ self.fc = nn.Linear(embedding_size, 128)
+ self.conv1 = nn.ConvTranspose2d(128, 128, 5, stride=2)
+ self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
+ self.conv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
+ self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2)
+ self.modules = [self.conv1, self.conv2, self.conv3, self.conv4]
+
+ def forward(self, embedding):
+ batch_size = embedding.shape[0]
+ hidden = self.fc(embedding)
+ hidden = hidden.view(batch_size, 128, 1, 1)
+ hidden = self.act_fn(self.conv1(hidden))
+ hidden = self.act_fn(self.conv2(hidden))
+ hidden = self.act_fn(self.conv3(hidden))
+ observation = self.conv4(hidden)
+ return observation
+
+
+class FullyConvEncoder(nn.Module):
+ """
+ Simple fully convolutional encoder, with 2D input and 2D output
+ """
+ def __init__(self,
+ input_shape=(3, 64, 64),
+ embedding_shape=(8, 16, 16),
+ activation_function='relu',
+ init_channels=16,
+ ):
+ super().__init__()
+
+ assert len(input_shape) == 3, "input_shape must be a tuple of length 3"
+ assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
+ assert input_shape[1] == input_shape[2] and is_square_of_two(input_shape[1]), "input_shape must be square"
+ assert embedding_shape[1] == embedding_shape[2], "embedding_shape must be square"
+ assert input_shape[1] % embedding_shape[1] == 0, "input_shape must be divisible by embedding_shape"
+ assert is_square_of_two(init_channels), "init_channels must be a square of 2"
+
+ depth = int(math.sqrt(input_shape[1] / embedding_shape[1])) + 1
+ channels_per_layer = [init_channels * (2 ** i) for i in range(depth)]
+ self.act_fn = getattr(F, activation_function)
+
+ self.downs = nn.ModuleList([])
+ self.downs.append(nn.Conv2d(input_shape[0], channels_per_layer[0], kernel_size=3, stride=1, padding=1))
+
+ for i in range(1, depth):
+ self.downs.append(nn.Conv2d(channels_per_layer[i-1], channels_per_layer[i],
+ kernel_size=3, stride=2, padding=1))
+
+ # Bottleneck layer
+ self.downs.append(nn.Conv2d(channels_per_layer[-1], embedding_shape[0], kernel_size=1, stride=1, padding=0))
+
+ def forward(self, observation):
+ hidden = observation
+ for layer in self.downs:
+ hidden = self.act_fn(layer(hidden))
+ return hidden
+
+
+class FullyConvDecoder(nn.Module):
+ """
+ Simple fully convolutional decoder, with 2D input and 2D output
+ """
+ def __init__(self,
+ embedding_shape=(8, 16, 16),
+ output_shape=(3, 64, 64),
+ activation_function='relu',
+ init_channels=16,
+ ):
+ super().__init__()
+
+ assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
+ assert len(output_shape) == 3, "output_shape must be a tuple of length 3"
+ assert output_shape[1] == output_shape[2] and is_square_of_two(output_shape[1]), "output_shape must be square"
+ assert embedding_shape[1] == embedding_shape[2], "input_shape must be square"
+ assert output_shape[1] % embedding_shape[1] == 0, "output_shape must be divisible by input_shape"
+ assert is_square_of_two(init_channels), "init_channels must be a square of 2"
+
+ depth = int(math.sqrt(output_shape[1] / embedding_shape[1])) + 1
+ channels_per_layer = [init_channels * (2 ** i) for i in range(depth)]
+ self.act_fn = getattr(F, activation_function)
+
+ self.ups = nn.ModuleList([])
+ self.ups.append(nn.ConvTranspose2d(embedding_shape[0], channels_per_layer[-1],
+ kernel_size=1, stride=1, padding=0))
+
+ for i in range(1, depth):
+ self.ups.append(nn.ConvTranspose2d(channels_per_layer[-i], channels_per_layer[-i-1],
+ kernel_size=3, stride=2, padding=1, output_padding=1))
+
+ self.output_layer = nn.ConvTranspose2d(channels_per_layer[0], output_shape[0],
+ kernel_size=3, stride=1, padding=1)
+
+ def forward(self, embedding):
+ hidden = embedding
+ for layer in self.ups:
+ hidden = self.act_fn(layer(hidden))
+
+ return self.output_layer(hidden)
diff --git a/algorithms/common/models/mlp.py b/algorithms/common/models/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3273eacbe8a153e10e3cf0ede6ba3145e6f81c4
--- /dev/null
+++ b/algorithms/common/models/mlp.py
@@ -0,0 +1,22 @@
+from typing import Type, Optional
+
+import torch
+from torch import nn as nn
+
+
+class SimpleMlp(nn.Module):
+ """
+ A class for very simple multi layer perceptron
+ """
+ def __init__(self, in_dim=2, out_dim=1, hidden_dim=64, n_layers=2,
+ activation: Type[nn.Module] = nn.ReLU, output_activation: Optional[Type[nn.Module]] = None):
+ super(SimpleMlp, self).__init__()
+ layers = [nn.Linear(in_dim, hidden_dim), activation()]
+ layers.extend([nn.Linear(hidden_dim, hidden_dim), activation()] * (n_layers - 2))
+ layers.append(nn.Linear(hidden_dim, out_dim))
+ if output_activation:
+ layers.append(output_activation())
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.net(x)
diff --git a/algorithms/worldmem/__init__.py b/algorithms/worldmem/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e9dfe8b751f5c2aa77b5b91fbfb909f0c35faf5
--- /dev/null
+++ b/algorithms/worldmem/__init__.py
@@ -0,0 +1,2 @@
+from .df_video import WorldMemMinecraft
+from .pose_prediction import PosePrediction
\ No newline at end of file
diff --git a/algorithms/worldmem/__pycache__/__init__.cpython-310.pyc b/algorithms/worldmem/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16e4822597c28afb9553714b21b76aad30f2d36d
Binary files /dev/null and b/algorithms/worldmem/__pycache__/__init__.cpython-310.pyc differ
diff --git a/algorithms/worldmem/__pycache__/df_base.cpython-310.pyc b/algorithms/worldmem/__pycache__/df_base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb1f65ea50097779f1052f008a20b11959758010
Binary files /dev/null and b/algorithms/worldmem/__pycache__/df_base.cpython-310.pyc differ
diff --git a/algorithms/worldmem/__pycache__/df_video.cpython-310.pyc b/algorithms/worldmem/__pycache__/df_video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d7da4ba5ff180cd26e92ceb7023be7e818ed934
Binary files /dev/null and b/algorithms/worldmem/__pycache__/df_video.cpython-310.pyc differ
diff --git a/algorithms/worldmem/__pycache__/pose_prediction.cpython-310.pyc b/algorithms/worldmem/__pycache__/pose_prediction.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6d7c55d9f62c98c32aa71329545dadfafecc7cb
Binary files /dev/null and b/algorithms/worldmem/__pycache__/pose_prediction.cpython-310.pyc differ
diff --git a/algorithms/worldmem/df_base.py b/algorithms/worldmem/df_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..685a964854ff17a206c4e0c5d11fad02886b2c3b
--- /dev/null
+++ b/algorithms/worldmem/df_base.py
@@ -0,0 +1,307 @@
+"""
+This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
+template [repo](https://github.com/buoyancy99/research-template).
+By its MIT license, you must keep the above sentence in `README.md`
+and the `LICENSE` file to credit the author.
+"""
+
+from typing import Optional
+from tqdm import tqdm
+from omegaconf import DictConfig
+import numpy as np
+import torch
+import torch.nn.functional as F
+from typing import Any
+from einops import rearrange
+
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+from algorithms.common.base_pytorch_algo import BasePytorchAlgo
+from .models.diffusion import Diffusion
+
+
+class DiffusionForcingBase(BasePytorchAlgo):
+ def __init__(self, cfg: DictConfig):
+ self.cfg = cfg
+ self.x_shape = cfg.x_shape
+ self.frame_stack = cfg.frame_stack
+ self.x_stacked_shape = list(self.x_shape)
+ self.x_stacked_shape[0] *= cfg.frame_stack
+ self.guidance_scale = cfg.guidance_scale
+ self.context_frames = cfg.context_frames
+ self.chunk_size = cfg.chunk_size
+ self.action_cond_dim = cfg.action_cond_dim
+ self.causal = cfg.causal
+
+ self.uncertainty_scale = cfg.uncertainty_scale
+ self.timesteps = cfg.diffusion.timesteps
+ self.sampling_timesteps = cfg.diffusion.sampling_timesteps
+ self.clip_noise = cfg.diffusion.clip_noise
+
+ self.cfg.diffusion.cum_snr_decay = self.cfg.diffusion.cum_snr_decay ** (self.frame_stack * cfg.frame_skip)
+
+ self.validation_step_outputs = []
+ super().__init__(cfg)
+
+ def _build_model(self):
+ self.diffusion_model = Diffusion(
+ x_shape=self.x_stacked_shape,
+ action_cond_dim=self.action_cond_dim,
+ is_causal=self.causal,
+ cfg=self.cfg.diffusion,
+ )
+ self.register_data_mean_std(self.cfg.data_mean, self.cfg.data_std)
+
+ def configure_optimizers(self):
+ params = tuple(self.diffusion_model.parameters())
+ optimizer_dynamics = torch.optim.AdamW(
+ params, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay, betas=self.cfg.optimizer_beta
+ )
+ return optimizer_dynamics
+
+ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
+ # update params
+ optimizer.step(closure=optimizer_closure)
+
+ # manually warm up lr without a scheduler
+ if self.trainer.global_step < self.cfg.warmup_steps:
+ lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.cfg.warmup_steps)
+ for pg in optimizer.param_groups:
+ pg["lr"] = lr_scale * self.cfg.lr
+
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+ xs, conditions, masks = self._preprocess_batch(batch)
+
+ rand_length = torch.randint(3,xs.shape[0]-2, (1,))[0].item()
+ xs = torch.cat([xs[:rand_length], xs[rand_length-3:rand_length-1]])
+ conditions = torch.cat([conditions[:rand_length], conditions[rand_length-3:rand_length-1]])
+ masks = torch.cat([masks[:rand_length], masks[rand_length-3:rand_length-1]])
+ noise_levels=self._generate_noise_levels(xs)
+ noise_levels[:rand_length] = 15 # stable_noise_levels
+ noise_levels[rand_length+1:] = 15 # stable_noise_levels
+
+ xs_pred, loss = self.diffusion_model(xs, conditions, noise_levels=noise_levels)
+ loss = self.reweight_loss(loss, masks)
+
+ # log the loss
+ if batch_idx % 20 == 0:
+ self.log("training/loss", loss)
+
+ xs = self._unstack_and_unnormalize(xs)
+ xs_pred = self._unstack_and_unnormalize(xs_pred)
+
+ output_dict = {
+ "loss": loss,
+ "xs_pred": xs_pred,
+ "xs": xs,
+ }
+
+ return output_dict
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
+ xs, conditions, masks = self._preprocess_batch(batch)
+ n_frames, batch_size, *_ = xs.shape
+ xs_pred = []
+ curr_frame = 0
+
+ # context
+ n_context_frames = self.context_frames // self.frame_stack
+ xs_pred = xs[:n_context_frames].clone()
+ curr_frame += n_context_frames
+
+ if self.condtion_similar_length:
+ n_frames -= self.condtion_similar_length
+
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
+ while curr_frame < n_frames:
+ if self.chunk_size > 0:
+ horizon = min(n_frames - curr_frame, self.chunk_size)
+ else:
+ horizon = n_frames - curr_frame
+ assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
+ scheduling_matrix = self._generate_scheduling_matrix(horizon)
+
+ chunk = torch.randn((horizon, batch_size, *self.x_stacked_shape), device=self.device)
+ chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
+ xs_pred = torch.cat([xs_pred, chunk], 0)
+
+ # sliding window: only input the last n_tokens frames
+ start_frame = max(0, curr_frame + horizon - self.n_tokens)
+
+ pbar.set_postfix(
+ {
+ "start": start_frame,
+ "end": curr_frame + horizon,
+ }
+ )
+
+ if self.condtion_similar_length:
+ xs_pred = torch.cat([xs_pred, xs[curr_frame-self.condtion_similar_length:curr_frame].clone()], 0)
+
+ for m in range(scheduling_matrix.shape[0] - 1):
+
+ from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[
+ :, None
+ ].repeat(batch_size, axis=1)
+ to_noise_levels = np.concatenate(
+ (
+ np.zeros((curr_frame,), dtype=np.int64),
+ scheduling_matrix[m + 1],
+ )
+ )[
+ :, None
+ ].repeat(batch_size, axis=1)
+
+ if self.condtion_similar_length:
+ from_noise_levels = np.concatenate([from_noise_levels, np.array([[0,0,0,0]*self.condtion_similar_length])], axis=0)
+ to_noise_levels = np.concatenate([to_noise_levels, np.array([[0,0,0,0]*self.condtion_similar_length])], axis=0)
+
+ from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
+ to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
+
+ # update xs_pred by DDIM or DDPM sampling
+ # input frames within the sliding window
+
+ try:
+ input_condition = conditions[start_frame : curr_frame + horizon].clone()
+ except:
+ import pdb;pdb.set_trace()
+ if self.condtion_similar_length:
+ input_condition = torch.cat([conditions[start_frame : curr_frame + horizon], conditions[-self.condtion_similar_length:]], dim=0)
+ xs_pred[start_frame:] = self.diffusion_model.sample_step(
+ xs_pred[start_frame:],
+ input_condition,
+ from_noise_levels[start_frame:],
+ to_noise_levels[start_frame:],
+ )
+
+ if self.condtion_similar_length:
+ xs_pred = xs_pred[:-self.condtion_similar_length]
+
+ curr_frame += horizon
+ pbar.update(horizon)
+
+ if self.condtion_similar_length:
+ xs = xs[:-self.condtion_similar_length]
+ # FIXME: loss
+ loss = F.mse_loss(xs_pred, xs, reduction="none")
+ loss = self.reweight_loss(loss, masks)
+ self.validation_step_outputs.append((xs_pred.detach().cpu(), xs.detach().cpu()))
+
+ return loss
+
+ def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+ return self.validation_step(*args, **kwargs, namespace="test")
+
+ def test_epoch_end(self) -> None:
+ self.on_validation_epoch_end(namespace="test")
+
+ def _generate_noise_levels(self, xs: torch.Tensor, masks: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Generate noise levels for training.
+ """
+ num_frames, batch_size, *_ = xs.shape
+ match self.cfg.noise_level:
+ case "random_all": # entirely random noise levels
+ noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
+ case "same":
+ noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
+ noise_levels[1:] = noise_levels[0]
+
+ if masks is not None:
+ # for frames that are not available, treat as full noise
+ discard = torch.all(~rearrange(masks.bool(), "(t fs) b -> t b fs", fs=self.frame_stack), -1)
+ noise_levels = torch.where(discard, torch.full_like(noise_levels, self.timesteps - 1), noise_levels)
+
+ return noise_levels
+
+ def _generate_scheduling_matrix(self, horizon: int):
+ match self.cfg.scheduling_matrix:
+ case "pyramid":
+ return self._generate_pyramid_scheduling_matrix(horizon, self.uncertainty_scale)
+ case "full_sequence":
+ return np.arange(self.sampling_timesteps, -1, -1)[:, None].repeat(horizon, axis=1)
+ case "autoregressive":
+ return self._generate_pyramid_scheduling_matrix(horizon, self.sampling_timesteps)
+ case "trapezoid":
+ return self._generate_trapezoid_scheduling_matrix(horizon, self.uncertainty_scale)
+
+ def _generate_pyramid_scheduling_matrix(self, horizon: int, uncertainty_scale: float):
+ height = self.sampling_timesteps + int((horizon - 1) * uncertainty_scale) + 1
+ scheduling_matrix = np.zeros((height, horizon), dtype=np.int64)
+ for m in range(height):
+ for t in range(horizon):
+ scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
+
+ return np.clip(scheduling_matrix, 0, self.sampling_timesteps)
+
+ def _generate_trapezoid_scheduling_matrix(self, horizon: int, uncertainty_scale: float):
+ height = self.sampling_timesteps + int((horizon + 1) // 2 * uncertainty_scale)
+ scheduling_matrix = np.zeros((height, horizon), dtype=np.int64)
+ for m in range(height):
+ for t in range((horizon + 1) // 2):
+ scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
+ scheduling_matrix[m, -t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
+
+ return np.clip(scheduling_matrix, 0, self.sampling_timesteps)
+
+ def reweight_loss(self, loss, weight=None):
+ # Note there is another part of loss reweighting (fused_snr) inside the Diffusion class!
+ loss = rearrange(loss, "t b (fs c) ... -> t b fs c ...", fs=self.frame_stack)
+ if weight is not None:
+ expand_dim = len(loss.shape) - len(weight.shape) - 1
+ weight = rearrange(
+ weight,
+ "(t fs) b ... -> t b fs ..." + " 1" * expand_dim,
+ fs=self.frame_stack,
+ )
+ loss = loss * weight
+
+ return loss.mean()
+
+ def _preprocess_batch(self, batch):
+ xs = batch[0]
+ batch_size, n_frames = xs.shape[:2]
+
+ if n_frames % self.frame_stack != 0:
+ raise ValueError("Number of frames must be divisible by frame stack size")
+ if self.context_frames % self.frame_stack != 0:
+ raise ValueError("Number of context frames must be divisible by frame stack size")
+
+ masks = torch.ones(n_frames, batch_size).to(xs.device)
+ n_frames = n_frames // self.frame_stack
+
+ if self.action_cond_dim:
+ conditions = batch[1]
+ conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1)
+ conditions = rearrange(conditions, "b (t fs) d -> t b (fs d)", fs=self.frame_stack).contiguous()
+
+ # f, _, _ = conditions.shape
+ # predefined_1 = torch.tensor([0,0,0,1]).to(conditions.device)
+ # predefined_2 = torch.tensor([0,0,1,0]).to(conditions.device)
+ # conditions[:f//2] = predefined_1
+ # conditions[f//2:] = predefined_2
+ else:
+ conditions = [None for _ in range(n_frames)]
+
+ xs = self._normalize_x(xs)
+ xs = rearrange(xs, "b (t fs) c ... -> t b (fs c) ...", fs=self.frame_stack).contiguous()
+
+ return xs, conditions, masks
+
+ def _normalize_x(self, xs):
+ shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape)
+ mean = self.data_mean.reshape(shape)
+ std = self.data_std.reshape(shape)
+ return (xs - mean) / std
+
+ def _unnormalize_x(self, xs):
+ shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape)
+ mean = self.data_mean.reshape(shape)
+ std = self.data_std.reshape(shape)
+ return xs * std + mean
+
+ def _unstack_and_unnormalize(self, xs):
+ xs = rearrange(xs, "t b (fs c) ... -> (t fs) b c ...", fs=self.frame_stack)
+ return self._unnormalize_x(xs)
diff --git a/algorithms/worldmem/df_video.py b/algorithms/worldmem/df_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfcea8eb4ee074d01f6f8ce63766018b0ace8214
--- /dev/null
+++ b/algorithms/worldmem/df_video.py
@@ -0,0 +1,908 @@
+import random
+import math
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from torchvision.transforms import InterpolationMode
+from PIL import Image
+from packaging import version as pver
+from einops import rearrange
+from tqdm import tqdm
+from omegaconf import DictConfig
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from algorithms.common.metrics import (
+ LearnedPerceptualImagePatchSimilarity,
+)
+from utils.logging_utils import log_video, get_validation_metrics_for_videos
+from .df_base import DiffusionForcingBase
+from .models.vae import VAE_models
+from .models.diffusion import Diffusion
+from .models.pose_prediction import PosePredictionNet
+
+
+# Utility Functions
+def euler_to_rotation_matrix(pitch, yaw):
+ """
+ Convert pitch and yaw angles (in radians) to a 3x3 rotation matrix.
+ Supports batch input.
+
+ Args:
+ pitch (torch.Tensor): Pitch angles in radians.
+ yaw (torch.Tensor): Yaw angles in radians.
+
+ Returns:
+ torch.Tensor: Rotation matrix of shape (batch_size, 3, 3).
+ """
+ cos_pitch, sin_pitch = torch.cos(pitch), torch.sin(pitch)
+ cos_yaw, sin_yaw = torch.cos(yaw), torch.sin(yaw)
+
+ R_pitch = torch.stack([
+ torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),
+ torch.zeros_like(pitch), cos_pitch, -sin_pitch,
+ torch.zeros_like(pitch), sin_pitch, cos_pitch
+ ], dim=-1).reshape(-1, 3, 3)
+
+ R_yaw = torch.stack([
+ cos_yaw, torch.zeros_like(yaw), sin_yaw,
+ torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
+ -sin_yaw, torch.zeros_like(yaw), cos_yaw
+ ], dim=-1).reshape(-1, 3, 3)
+
+ return torch.matmul(R_yaw, R_pitch)
+
+
+def euler_to_camera_to_world_matrix(pose):
+ """
+ Convert (x, y, z, pitch, yaw) to a 4x4 camera-to-world transformation matrix using torch.
+ Supports both (5,) and (f, b, 5) shaped inputs.
+
+ Args:
+ pose (torch.Tensor): Pose tensor of shape (5,) or (f, b, 5).
+
+ Returns:
+ torch.Tensor: Camera-to-world transformation matrix of shape (4, 4).
+ """
+
+ origin_dim = pose.ndim
+ if origin_dim == 1:
+ pose = pose.unsqueeze(0).unsqueeze(0) # Convert (5,) -> (1, 1, 5)
+ elif origin_dim == 2:
+ pose = pose.unsqueeze(0)
+
+ x, y, z, pitch, yaw = pose[..., 0], pose[..., 1], pose[..., 2], pose[..., 3], pose[..., 4]
+ pitch, yaw = torch.deg2rad(pitch), torch.deg2rad(yaw)
+
+ # Compute rotation matrix (batch mode)
+ R = euler_to_rotation_matrix(pitch, yaw) # Shape (f*b, 3, 3)
+
+ # Create the 4x4 transformation matrix
+ eye = torch.eye(4, dtype=torch.float32, device=pose.device)
+ camera_to_world = eye.repeat(R.shape[0], 1, 1) # Shape (f*b, 4, 4)
+
+ # Assign rotation
+ camera_to_world[:, :3, :3] = R
+
+ # Assign translation
+ camera_to_world[:, :3, 3] = torch.stack([x.reshape(-1), y.reshape(-1), z.reshape(-1)], dim=-1)
+
+ # Reshape back to (f, b, 4, 4) if needed
+ if origin_dim == 3:
+ return camera_to_world.view(pose.shape[0], pose.shape[1], 4, 4)
+ elif origin_dim == 2:
+ return camera_to_world.view(pose.shape[0], 4, 4)
+ else:
+ return camera_to_world.squeeze(0).squeeze(0) # Convert (1,1,4,4) -> (4,4)
+
+def is_inside_fov_3d_hv(points, center, center_pitch, center_yaw, fov_half_h, fov_half_v):
+ """
+ Check whether points are within a given 3D field of view (FOV)
+ with separately defined horizontal and vertical ranges.
+
+ The center view direction is specified by pitch and yaw (in degrees).
+
+ :param points: (N, B, 3) Sample point coordinates
+ :param center: (3,) Center coordinates of the FOV
+ :param center_pitch: Pitch angle of the center view (in degrees)
+ :param center_yaw: Yaw angle of the center view (in degrees)
+ :param fov_half_h: Horizontal half-FOV angle (in degrees)
+ :param fov_half_v: Vertical half-FOV angle (in degrees)
+ :return: Boolean tensor (N, B), indicating whether each point is inside the FOV
+ """
+ # Compute vectors relative to the center
+ vectors = points - center # shape (N, B, 3)
+ x = vectors[..., 0]
+ y = vectors[..., 1]
+ z = vectors[..., 2]
+
+ # Compute horizontal angle (yaw): measured with respect to the z-axis as the forward direction,
+ # and the x-axis as left-right, resulting in a range of -180 to 180 degrees.
+ azimuth = torch.atan2(x, z) * (180 / math.pi)
+
+ # Compute vertical angle (pitch): measured with respect to the horizontal plane,
+ # resulting in a range of -90 to 90 degrees.
+ elevation = torch.atan2(y, torch.sqrt(x**2 + z**2)) * (180 / math.pi)
+
+ # Compute the angular difference from the center view (handling circular angle wrap-around)
+ diff_azimuth = (azimuth - center_yaw).abs() % 360
+ diff_elevation = (elevation - center_pitch).abs() % 360
+
+ # Adjust values greater than 180 degrees to the shorter angular difference
+ diff_azimuth = torch.where(diff_azimuth > 180, 360 - diff_azimuth, diff_azimuth)
+ diff_elevation = torch.where(diff_elevation > 180, 360 - diff_elevation, diff_elevation)
+
+ # Check if both horizontal and vertical angles are within their respective FOV limits
+ return (diff_azimuth < fov_half_h) & (diff_elevation < fov_half_v)
+
+def generate_points_in_sphere(n_points, radius):
+ # Sample three independent uniform distributions
+ samples_r = torch.rand(n_points) # For radius distribution
+ samples_phi = torch.rand(n_points) # For azimuthal angle phi
+ samples_u = torch.rand(n_points) # For polar angle theta
+
+ # Apply cube root to ensure uniform volumetric distribution
+ r = radius * torch.pow(samples_r, 1/3)
+ # Azimuthal angle phi uniformly distributed in [0, 2π]
+ phi = 2 * math.pi * samples_phi
+ # Convert u to theta to ensure cos(theta) is uniformly distributed
+ theta = torch.acos(1 - 2 * samples_u)
+
+ # Convert spherical coordinates to Cartesian coordinates
+ x = r * torch.sin(theta) * torch.cos(phi)
+ y = r * torch.sin(theta) * torch.sin(phi)
+ z = r * torch.cos(theta)
+
+ points = torch.stack((x, y, z), dim=1)
+ return points
+
+def tensor_max_with_number(tensor, number):
+ number_tensor = torch.tensor(number, dtype=tensor.dtype, device=tensor.device)
+ result = torch.max(tensor, number_tensor)
+ return result
+
+def custom_meshgrid(*args):
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
+ return torch.meshgrid(*args)
+ else:
+ return torch.meshgrid(*args, indexing='ij')
+
+def camera_to_world_to_world_to_camera(camera_to_world: torch.Tensor) -> torch.Tensor:
+ """
+ Convert Camera-to-World matrices to World-to-Camera matrices for a tensor with shape (f, b, 4, 4).
+
+ Args:
+ camera_to_world (torch.Tensor): A tensor of shape (f, b, 4, 4), where:
+ f = number of frames,
+ b = batch size.
+
+ Returns:
+ torch.Tensor: A tensor of shape (f, b, 4, 4) representing the World-to-Camera matrices.
+ """
+ # Ensure input is a 4D tensor
+ assert camera_to_world.ndim == 4 and camera_to_world.shape[2:] == (4, 4), \
+ "Input must be of shape (f, b, 4, 4)"
+
+ # Extract the rotation (R) and translation (T) parts
+ R = camera_to_world[:, :, :3, :3] # Shape: (f, b, 3, 3)
+ T = camera_to_world[:, :, :3, 3] # Shape: (f, b, 3)
+
+ # Initialize an identity matrix for the output
+ world_to_camera = torch.eye(4, device=camera_to_world.device).unsqueeze(0).unsqueeze(0)
+ world_to_camera = world_to_camera.repeat(camera_to_world.size(0), camera_to_world.size(1), 1, 1) # Shape: (f, b, 4, 4)
+
+ # Compute the rotation (transpose of R)
+ world_to_camera[:, :, :3, :3] = R.transpose(2, 3)
+
+ # Compute the translation (-R^T * T)
+ world_to_camera[:, :, :3, 3] = -torch.matmul(R.transpose(2, 3), T.unsqueeze(-1)).squeeze(-1)
+
+ return world_to_camera.to(camera_to_world.dtype)
+
+def convert_to_plucker(poses, curr_frame, focal_length, image_width, image_height):
+
+ intrinsic = np.asarray([focal_length * image_width,
+ focal_length * image_height,
+ 0.5 * image_width,
+ 0.5 * image_height], dtype=np.float32)
+
+ c2ws = get_relative_pose(poses, zero_first_frame_scale=curr_frame)
+ c2ws = rearrange(c2ws, "t b m n -> b t m n")
+
+ K = torch.as_tensor(intrinsic, device=poses.device, dtype=poses.dtype).repeat(c2ws.shape[0],c2ws.shape[1],1) # [B, F, 4]
+ plucker_embedding = ray_condition(K, c2ws, image_height, image_width, device=c2ws.device)
+ plucker_embedding = rearrange(plucker_embedding, "b t h w d -> t b h w d").contiguous()
+
+ return plucker_embedding
+
+
+def get_relative_pose(abs_c2ws, zero_first_frame_scale):
+ abs_w2cs = camera_to_world_to_world_to_camera(abs_c2ws)
+ target_cam_c2w = torch.tensor([
+ [1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]
+ ]).to(abs_c2ws.device).to(abs_c2ws.dtype)
+ abs2rel = target_cam_c2w @ abs_w2cs[zero_first_frame_scale]
+ ret_poses = [abs2rel @ abs_c2w for abs_c2w in abs_c2ws]
+ ret_poses = torch.stack(ret_poses)
+ return ret_poses
+
+def ray_condition(K, c2w, H, W, device):
+ # c2w: B, V, 4, 4
+ # K: B, V, 4
+
+ B = K.shape[0]
+
+ j, i = custom_meshgrid(
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
+ )
+ i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
+ j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
+
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
+
+ zs = torch.ones_like(i, device=device, dtype=c2w.dtype) # [B, HxW]
+ xs = -(i - cx) / fx * zs
+ ys = -(j - cy) / fy * zs
+
+ zs = zs.expand_as(ys)
+
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
+
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
+ rays_o = c2w[..., :3, 3] # B, V, 3
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
+ # c2w @ dirctions
+ rays_dxo = torch.linalg.cross(rays_o, rays_d)
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
+
+ return plucker
+
+def random_transform(tensor):
+ """
+ Apply the same random translation, rotation, and scaling to all frames in the batch.
+
+ Args:
+ tensor (torch.Tensor): Input tensor of shape (F, B, 3, H, W).
+
+ Returns:
+ torch.Tensor: Transformed tensor of shape (F, B, 3, H, W).
+ """
+ if tensor.ndim != 5:
+ raise ValueError("Input tensor must have shape (F, B, 3, H, W)")
+
+ F, B, C, H, W = tensor.shape
+
+ # Generate random transformation parameters
+ max_translate = 0.2 # Translate up to 20% of width/height
+ max_rotate = 30 # Rotate up to 30 degrees
+ max_scale = 0.2 # Scale change by up to +/- 20%
+
+ translate_x = random.uniform(-max_translate, max_translate) * W
+ translate_y = random.uniform(-max_translate, max_translate) * H
+ rotate_angle = random.uniform(-max_rotate, max_rotate)
+ scale_factor = 1 + random.uniform(-max_scale, max_scale)
+
+ # Apply the same transformation to all frames and batches
+
+ tensor = tensor.reshape(F*B, C, H, W)
+ transformed_tensor = TF.affine(
+ tensor,
+ angle=rotate_angle,
+ translate=(translate_x, translate_y),
+ scale=scale_factor,
+ shear=(0, 0),
+ interpolation=InterpolationMode.BILINEAR,
+ fill=0
+ )
+
+ transformed_tensor = transformed_tensor.reshape(F, B, C, H, W)
+ return transformed_tensor
+
+def save_tensor_as_png(tensor, file_path):
+ """
+ Save a 3*H*W tensor as a PNG image.
+
+ Args:
+ tensor (torch.Tensor): Input tensor of shape (3, H, W).
+ file_path (str): Path to save the PNG file.
+ """
+ if tensor.ndim != 3 or tensor.shape[0] != 3:
+ raise ValueError("Input tensor must have shape (3, H, W)")
+
+ # Convert tensor to PIL Image
+ image = TF.to_pil_image(tensor)
+
+ # Save image
+ image.save(file_path)
+
+class WorldMemMinecraft(DiffusionForcingBase):
+ """
+ Video generation for MineCraft with memory.
+ """
+
+ def __init__(self, cfg: DictConfig):
+ """
+ Initialize the WorldMemMinecraft class with the given configuration.
+
+ Args:
+ cfg (DictConfig): Configuration object.
+ """
+ # self.metrics = cfg.metrics
+ self.n_tokens = cfg.n_frames // cfg.frame_stack # number of max tokens for the model
+ self.n_frames = cfg.n_frames
+ if hasattr(cfg, "n_tokens"):
+ self.n_tokens = cfg.n_tokens // cfg.frame_stack
+ self.condition_similar_length = cfg.condition_similar_length
+ self.pose_cond_dim = cfg.pose_cond_dim
+
+ self.use_plucker = cfg.use_plucker
+ self.relative_embedding = cfg.relative_embedding
+ self.cond_only_on_qk = cfg.cond_only_on_qk
+ self.use_reference_attention = cfg.use_reference_attention
+ self.add_frame_timestep_embedder = cfg.add_frame_timestep_embedder
+ self.ref_mode = getattr(cfg, "ref_mode", 'sequential')
+ self.log_curve = getattr(cfg, "log_curve", False)
+ self.focal_length = cfg.focal_length
+ self.log_video = cfg.log_video
+ self.self_consistency_eval = getattr(cfg, "self_consistency_eval", False)
+
+ self.is_interactive = cfg.get("is_interactive", False)
+ if self.is_interactive:
+ self.frames = None
+ self.poses = None
+ self.memory_c2w = None
+ self.frame_idx = None
+
+ super().__init__(cfg)
+
+ def _build_model(self):
+
+ self.diffusion_model = Diffusion(
+ reference_length=self.condition_similar_length,
+ x_shape=self.x_stacked_shape,
+ action_cond_dim=self.action_cond_dim,
+ pose_cond_dim=self.pose_cond_dim,
+ is_causal=self.causal,
+ cfg=self.cfg.diffusion,
+ is_dit=True,
+ use_plucker=self.use_plucker,
+ relative_embedding=self.relative_embedding,
+ cond_only_on_qk=self.cond_only_on_qk,
+ use_reference_attention=self.use_reference_attention,
+ add_frame_timestep_embedder=self.add_frame_timestep_embedder,
+ ref_mode=self.ref_mode
+ )
+
+ self.register_data_mean_std(self.cfg.data_mean, self.cfg.data_std)
+ self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity()
+
+ vae = VAE_models["vit-l-20-shallow-encoder"]()
+ self.vae = vae.eval()
+
+ self.pose_prediction_model = PosePredictionNet()
+
+ def _generate_noise_levels(self, xs: torch.Tensor, masks = None) -> torch.Tensor:
+ """
+ Generate noise levels for training.
+ """
+ num_frames, batch_size, *_ = xs.shape
+ match self.cfg.noise_level:
+ case "random_all": # entirely random noise levels
+ noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
+ case "same":
+ noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
+ noise_levels[1:] = noise_levels[0]
+
+ if masks is not None:
+ # for frames that are not available, treat as full noise
+ discard = torch.all(~rearrange(masks.bool(), "(t fs) b -> t b fs", fs=self.frame_stack), -1)
+ noise_levels = torch.where(discard, torch.full_like(noise_levels, self.timesteps - 1), noise_levels)
+
+ return noise_levels
+
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+ """
+ Perform a single training step.
+
+ This function processes the input batch,
+ encodes the input frames, generates noise levels, and computes the loss using the diffusion model.
+
+ Args:
+ batch: Input batch of data containing frames, conditions, poses, etc.
+ batch_idx: Index of the current batch.
+
+ Returns:
+ dict: A dictionary containing the training loss.
+ """
+ xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
+
+ if self.use_plucker:
+ if self.relative_embedding:
+ input_pose_condition = []
+ frame_idx_list = []
+ for i in range(self.n_frames):
+ input_pose_condition.append(
+ convert_to_plucker(
+ torch.cat([c2w_mat[i:i + 1], c2w_mat[-self.condition_similar_length:]]).clone(),
+ 0,
+ focal_length=self.focal_length,
+ image_height=xs.shape[-2],image_width=xs.shape[-1]
+ ).to(xs.dtype)
+ )
+ frame_idx_list.append(
+ torch.cat([
+ frame_idx[i:i + 1] - frame_idx[i:i + 1],
+ frame_idx[-self.condition_similar_length:] - frame_idx[i:i + 1]
+ ]).clone()
+ )
+ input_pose_condition = torch.cat(input_pose_condition)
+ frame_idx_list = torch.cat(frame_idx_list)
+ else:
+ input_pose_condition = convert_to_plucker(
+ c2w_mat, 0, focal_length=self.focal_length
+ ).to(xs.dtype)
+ frame_idx_list = frame_idx
+ else:
+ input_pose_condition = pose_conditions.to(xs.dtype)
+ frame_idx_list = None
+
+ xs = self.encode(xs)
+
+ noise_levels = self._generate_noise_levels(xs)
+
+ if self.condition_similar_length:
+ noise_levels[-self.condition_similar_length:] = self.diffusion_model.stabilization_level
+ conditions[-self.condition_similar_length:] *= 0
+
+ _, loss = self.diffusion_model(
+ xs,
+ conditions,
+ input_pose_condition,
+ noise_levels=noise_levels,
+ reference_length=self.condition_similar_length,
+ frame_idx=frame_idx_list
+ )
+
+ if self.condition_similar_length:
+ loss = loss[:-self.condition_similar_length]
+
+ loss = self.reweight_loss(loss, None)
+
+ if batch_idx % 20 == 0:
+ self.log("training/loss", loss.cpu())
+
+ return {"loss": loss}
+
+
+ def on_validation_epoch_end(self, namespace="validation") -> None:
+ if not self.validation_step_outputs:
+ return
+
+ xs_pred = []
+ xs = []
+ for pred, gt in self.validation_step_outputs:
+ xs_pred.append(pred)
+ xs.append(gt)
+
+ xs_pred = torch.cat(xs_pred, 1)
+ if gt is not None:
+ xs = torch.cat(xs, 1)
+ else:
+ xs = None
+
+ if self.logger and self.log_video:
+ log_video(
+ xs_pred,
+ xs,
+ step=None if namespace == "test" else self.global_step,
+ namespace=namespace + "_vis",
+ context_frames=self.context_frames,
+ logger=self.logger.experiment,
+ )
+
+ if xs is not None:
+ metric_dict = get_validation_metrics_for_videos(
+ xs_pred, xs,
+ lpips_model=self.validation_lpips_model)
+
+ self.log_dict(
+ {"mse": metric_dict['mse'],
+ "psnr": metric_dict['psnr'],
+ "lpips": metric_dict['lpips']},
+ sync_dist=True
+ )
+
+ if self.log_curve:
+ psnr_values = metric_dict['frame_wise_psnr'].cpu().tolist()
+ frames = list(range(len(psnr_values)))
+ line_plot = wandb.plot.line_series(
+ xs = frames,
+ ys = [psnr_values],
+ keys = ["PSNR"],
+ title = "Frame-wise PSNR",
+ xname = "Frame index"
+ )
+
+ self.logger.experiment.log({"frame_wise_psnr_plot": line_plot})
+
+ elif self.self_consistency_eval:
+ metric_dict = get_validation_metrics_for_videos(
+ xs_pred[:1],
+ xs_pred[-1:],
+ lpips_model=self.validation_lpips_model,
+ )
+ self.log_dict(
+ {"lpips": metric_dict['lpips'],
+ "mse": metric_dict['mse'],
+ "psnr": metric_dict['psnr']},
+ sync_dist=True
+ )
+
+ self.validation_step_outputs.clear()
+
+ def _preprocess_batch(self, batch):
+
+ xs, conditions, pose_conditions, frame_index = batch
+
+ if self.action_cond_dim:
+ conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1)
+ conditions = rearrange(conditions, "b t d -> t b d").contiguous()
+ else:
+ raise NotImplementedError("Only support external cond.")
+
+ pose_conditions = rearrange(pose_conditions, "b t d -> t b d").contiguous()
+ c2w_mat = euler_to_camera_to_world_matrix(pose_conditions)
+ xs = rearrange(xs, "b t c ... -> t b c ...").contiguous()
+ frame_index = rearrange(frame_index, "b t -> t b").contiguous()
+
+ return xs, conditions, pose_conditions, c2w_mat, frame_index
+
+ def encode(self, x):
+ # vae encoding
+ T = x.shape[0]
+ H, W = x.shape[-2:]
+ scaling_factor = 0.07843137255
+
+ x = rearrange(x, "t b c h w -> (t b) c h w")
+ with torch.no_grad():
+ x = self.vae.encode(x * 2 - 1).mean * scaling_factor
+ x = rearrange(x, "(t b) (h w) c -> t b c h w", t=T, h=H // self.vae.patch_size, w=W // self.vae.patch_size)
+ return x
+
+ def decode(self, x):
+ total_frames = x.shape[0]
+ scaling_factor = 0.07843137255
+ x = rearrange(x, "t b c h w -> (t b) (h w) c")
+ with torch.no_grad():
+ x = (self.vae.decode(x / scaling_factor) + 1) / 2
+ x = rearrange(x, "(t b) c h w-> t b c h w", t=total_frames)
+ return x
+
+ def _generate_condition_indices(self, curr_frame, condition_similar_length, xs_pred, pose_conditions, frame_idx):
+ """
+ Generate indices for condition similarity based on the current frame and pose conditions.
+ """
+ if curr_frame < condition_similar_length:
+ random_idx = [i for i in range(curr_frame)] + [0] * (condition_similar_length - curr_frame)
+ random_idx = np.repeat(np.array(random_idx)[:, None], xs_pred.shape[1], -1)
+ else:
+ # Generate points in a sphere and filter based on field of view
+ num_samples = 10000
+ radius = 30
+ points = generate_points_in_sphere(num_samples, radius).to(pose_conditions.device)
+ points = points[:, None].repeat(1, pose_conditions.shape[1], 1)
+ points += pose_conditions[curr_frame, :, :3][None]
+ fov_half_h = torch.tensor(105 / 2, device=pose_conditions.device)
+ fov_half_v = torch.tensor(75 / 2, device=pose_conditions.device)
+ in_fov1 = is_inside_fov_3d_hv(
+ points, pose_conditions[curr_frame, :, :3],
+ pose_conditions[curr_frame, :, -2], pose_conditions[curr_frame, :, -1],
+ fov_half_h, fov_half_v
+ )
+
+ # Compute overlap ratios and select indices
+ in_fov_list = torch.stack([
+ is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1], fov_half_h, fov_half_v)
+ for pc in pose_conditions[:curr_frame]
+ ])
+ random_idx = []
+ for _ in range(condition_similar_length):
+ overlap_ratio = ((in_fov1.bool() & in_fov_list).sum(1)) / in_fov1.sum()
+
+ # if curr_frame == 54:
+ # import pdb;pdb.set_trace()
+ confidence = overlap_ratio + (curr_frame - frame_idx[:curr_frame]) / curr_frame * (-0.2)
+
+ if len(random_idx) > 0:
+ confidence[torch.cat(random_idx)] = -1e10
+ _, r_idx = torch.topk(confidence, k=1, dim=0)
+ random_idx.append(r_idx[0])
+
+ occupied_mask = in_fov_list[r_idx[0, range(in_fov1.shape[-1])], :, range(in_fov1.shape[-1])].permute(1,0)
+
+ in_fov1 = in_fov1 & ~occupied_mask
+
+ # cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
+ # range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
+ # cos_sim = cos_sim.mean((-2,-1))
+
+ # mask_sim = cos_sim>0.9
+ # in_fov_list = in_fov_list & ~mask_sim[:,None].to(in_fov_list.device)
+
+ random_idx = torch.stack(random_idx).cpu()
+
+ print(random_idx)
+
+ return random_idx
+
+ def _prepare_conditions(self,
+ start_frame, curr_frame, horizon, conditions,
+ pose_conditions, c2w_mat, frame_idx, random_idx,
+ image_width, image_height):
+ """
+ Prepare input conditions and pose conditions for sampling.
+ """
+
+ padding = torch.zeros((len(random_idx),) + conditions.shape[1:], device=conditions.device, dtype=conditions.dtype)
+ input_condition = torch.cat([conditions[start_frame:curr_frame + horizon], padding], dim=0)
+
+ batch_size = conditions.shape[1]
+
+ if self.use_plucker:
+ if self.relative_embedding:
+ frame_idx_list = []
+ input_pose_condition = []
+ for i in range(start_frame, curr_frame + horizon):
+ input_pose_condition.append(convert_to_plucker(torch.cat([c2w_mat[i:i+1],c2w_mat[random_idx[:,range(batch_size)], range(batch_size)]]).clone(), 0, focal_length=self.focal_length,
+ image_width=image_width, image_height=image_height).to(conditions.dtype))
+ frame_idx_list.append(torch.cat([frame_idx[i:i+1]-frame_idx[i:i+1], frame_idx[random_idx[:,range(batch_size)], range(batch_size)]-frame_idx[i:i+1]]))
+ input_pose_condition = torch.cat(input_pose_condition)
+ frame_idx_list = torch.cat(frame_idx_list)
+
+ else:
+ input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[random_idx[:,range(batch_size)], range(batch_size)]], dim=0).clone()
+ input_pose_condition = convert_to_plucker(input_pose_condition, 0, focal_length=self.focal_length)
+ frame_idx_list = None
+ else:
+ input_pose_condition = torch.cat([pose_conditions[start_frame : curr_frame + horizon], pose_conditions[random_idx[:,range(batch_size)], range(batch_size)]], dim=0).clone()
+ frame_idx_list = None
+
+ return input_condition, input_pose_condition, frame_idx_list
+
+ def _prepare_noise_levels(self, scheduling_matrix, m, curr_frame, batch_size, condition_similar_length):
+ """
+ Prepare noise levels for the current sampling step.
+ """
+ from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[:, None].repeat(batch_size, axis=1)
+ to_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m + 1]))[:, None].repeat(batch_size, axis=1)
+ if condition_similar_length:
+ from_noise_levels = np.concatenate([from_noise_levels, np.zeros((condition_similar_length, from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
+ to_noise_levels = np.concatenate([to_noise_levels, np.zeros((condition_similar_length, from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
+ from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
+ to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
+ return from_noise_levels, to_noise_levels
+
+ def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
+ """
+ Perform a single validation step.
+
+ This function processes the input batch, encodes frames, generates predictions using a sliding window approach,
+ and handles condition similarity logic for sampling. The results are decoded and stored for evaluation.
+
+ Args:
+ batch: Input batch of data containing frames, conditions, poses, etc.
+ batch_idx: Index of the current batch.
+ namespace: Namespace for logging (default: "validation").
+
+ Returns:
+ None: Appends the predicted and ground truth frames to `self.validation_step_outputs`.
+ """
+ # Preprocess the input batch
+ condition_similar_length = self.condition_similar_length
+ xs_raw, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
+
+ # Encode frames in chunks if necessary
+ total_frame = xs_raw.shape[0]
+ if total_frame > 10:
+ xs = torch.cat([
+ self.encode(xs_raw[int(total_frame * i / 10):int(total_frame * (i + 1) / 10)]).cpu()
+ for i in range(10)
+ ])
+ else:
+ xs = self.encode(xs_raw).cpu()
+
+ n_frames, batch_size, *_ = xs.shape
+ curr_frame = 0
+
+ # Initialize context frames
+ n_context_frames = self.context_frames // self.frame_stack
+ xs_pred = xs[:n_context_frames].clone()
+ curr_frame += n_context_frames
+
+ # Progress bar for sampling
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
+
+ while curr_frame < n_frames:
+ # Determine the horizon for the current chunk
+ horizon = min(n_frames - curr_frame, self.chunk_size) if self.chunk_size > 0 else n_frames - curr_frame
+ assert horizon <= self.n_tokens, "Horizon exceeds the number of tokens."
+
+ # Generate scheduling matrix and initialize noise
+ scheduling_matrix = self._generate_scheduling_matrix(horizon)
+ chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:]))
+ chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise).to(xs_pred.device)
+ xs_pred = torch.cat([xs_pred, chunk], 0)
+
+ # Sliding window: only input the last `n_tokens` frames
+ start_frame = max(0, curr_frame + horizon - self.n_tokens)
+ pbar.set_postfix({"start": start_frame, "end": curr_frame + horizon})
+
+ # Handle condition similarity logic
+ if condition_similar_length:
+ random_idx = self._generate_condition_indices(
+ curr_frame, condition_similar_length, xs_pred, pose_conditions, frame_idx
+ )
+
+ xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
+
+ # Prepare input conditions and pose conditions
+ input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
+ start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
+ image_width=xs_raw.shape[-1], image_height=xs_raw.shape[-2]
+ )
+
+ # Perform sampling for each step in the scheduling matrix
+ for m in range(scheduling_matrix.shape[0] - 1):
+ from_noise_levels, to_noise_levels = self._prepare_noise_levels(
+ scheduling_matrix, m, curr_frame, batch_size, condition_similar_length
+ )
+
+ xs_pred[start_frame:] = self.diffusion_model.sample_step(
+ xs_pred[start_frame:].to(input_condition.device),
+ input_condition,
+ input_pose_condition,
+ from_noise_levels[start_frame:],
+ to_noise_levels[start_frame:],
+ current_frame=curr_frame,
+ mode="validation",
+ reference_length=condition_similar_length,
+ frame_idx=frame_idx_list
+ ).cpu()
+
+ # Remove condition similarity frames if applicable
+ if condition_similar_length:
+ xs_pred = xs_pred[:-condition_similar_length]
+
+ curr_frame += horizon
+ pbar.update(horizon)
+
+ # Decode predictions and ground truth
+ xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
+ xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))
+
+ # Store results for evaluation
+ self.validation_step_outputs.append((xs_pred, xs_decode))
+ return
+
+ @torch.no_grad()
+ def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device):
+ condition_similar_length = self.condition_similar_length
+
+ if self.frames is None:
+ first_frame_encode = self.encode(first_frame[None, None].to(device))
+ self.frames = first_frame_encode.cpu()
+ self.actions = curr_actions[None, None].to(device)
+ self.poses = first_pose[None, None].to(device)
+ new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
+ self.memory_c2w = new_c2w_mat[None, None].to(device)
+ self.frame_idx = torch.tensor([[context_frames_idx]]).to(device)
+ return first_frame
+ else:
+ last_frame = self.frames[-1].clone()
+ last_pose_condition = self.poses[-1].clone()
+ last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
+ new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
+
+ new_pose_condition_offset[:,3:] = torch.round(new_pose_condition_offset[:,3:])
+ new_pose_condition = last_pose_condition + new_pose_condition_offset
+ new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
+ new_pose_condition[:,3:] %= 360
+ print(new_pose_condition)
+ self.actions = torch.cat([self.actions, curr_actions[None, None].to(device)])
+ self.poses = torch.cat([self.poses, new_pose_condition[None].to(device)])
+ new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
+ self.memory_c2w = torch.cat([self.memory_c2w, new_c2w_mat[None].to(device)])
+ self.frame_idx = torch.cat([self.frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
+
+ conditions = self.actions.clone()
+ pose_conditions = self.poses.clone()
+ c2w_mat = self.memory_c2w .clone()
+ frame_idx = self.frame_idx.clone()
+
+
+ curr_frame = 0
+ horizon = 1
+ batch_size = 1
+ n_frames = curr_frame + horizon
+ # context
+ n_context_frames = context_frames_idx // self.frame_stack
+ xs_pred = self.frames[:n_context_frames].clone()
+ curr_frame += n_context_frames
+
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
+
+ # generation on frame
+ scheduling_matrix = self._generate_scheduling_matrix(horizon)
+ chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:])).to(xs_pred.device)
+ chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
+
+ xs_pred = torch.cat([xs_pred, chunk], 0)
+
+ # sliding window: only input the last n_tokens frames
+ start_frame = max(0, curr_frame + horizon - self.n_tokens)
+
+ pbar.set_postfix(
+ {
+ "start": start_frame,
+ "end": curr_frame + horizon,
+ }
+ )
+
+ # Handle condition similarity logic
+ if condition_similar_length:
+ random_idx = self._generate_condition_indices(
+ curr_frame, condition_similar_length, xs_pred, pose_conditions, frame_idx
+ )
+
+ # random_idx = np.unique(random_idx)[:, None]
+ # condition_similar_length = len(random_idx)
+ xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
+
+ # Prepare input conditions and pose conditions
+ input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
+ start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
+ image_width=first_frame.shape[-1], image_height=first_frame.shape[-2]
+ )
+
+ # Perform sampling for each step in the scheduling matrix
+ for m in range(scheduling_matrix.shape[0] - 1):
+ from_noise_levels, to_noise_levels = self._prepare_noise_levels(
+ scheduling_matrix, m, curr_frame, batch_size, condition_similar_length
+ )
+
+ xs_pred[start_frame:] = self.diffusion_model.sample_step(
+ xs_pred[start_frame:].to(input_condition.device),
+ input_condition,
+ input_pose_condition,
+ from_noise_levels[start_frame:],
+ to_noise_levels[start_frame:],
+ current_frame=curr_frame,
+ mode="validation",
+ reference_length=condition_similar_length,
+ frame_idx=frame_idx_list
+ ).cpu()
+
+
+ if condition_similar_length:
+ xs_pred = xs_pred[:-condition_similar_length]
+
+ curr_frame += horizon
+ pbar.update(horizon)
+
+ self.frames = torch.cat([self.frames, xs_pred[n_context_frames:]])
+
+ xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
+ return xs_pred[-1,0]
+
+
+ def reset(self):
+ self.frames = None
+ self.poses = None
+ self.memory_c2w = None
+ self.frame_idx = None
\ No newline at end of file
diff --git a/algorithms/worldmem/models/__pycache__/attention.cpython-310.pyc b/algorithms/worldmem/models/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..940763ef8e856930784f1dc280a5e4fb20992231
Binary files /dev/null and b/algorithms/worldmem/models/__pycache__/attention.cpython-310.pyc differ
diff --git a/algorithms/worldmem/models/__pycache__/cameractrl_module.cpython-310.pyc b/algorithms/worldmem/models/__pycache__/cameractrl_module.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..90403088b8b5fe52434a75b8c14b78c27ffa33b7
Binary files /dev/null and b/algorithms/worldmem/models/__pycache__/cameractrl_module.cpython-310.pyc differ
diff --git a/algorithms/worldmem/models/__pycache__/diffusion.cpython-310.pyc b/algorithms/worldmem/models/__pycache__/diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74c086bc934222a5096c31284a59ee04296539f4
Binary files /dev/null and b/algorithms/worldmem/models/__pycache__/diffusion.cpython-310.pyc differ
diff --git a/algorithms/worldmem/models/__pycache__/dit.cpython-310.pyc b/algorithms/worldmem/models/__pycache__/dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..949e56cfacce694fdc8d347b30e48efb310c01da
Binary files /dev/null and b/algorithms/worldmem/models/__pycache__/dit.cpython-310.pyc differ
diff --git a/algorithms/worldmem/models/__pycache__/my_rotary_embedding_torch.cpython-310.pyc b/algorithms/worldmem/models/__pycache__/my_rotary_embedding_torch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0194d597fe82e538c71e6cce24d7f7504706abf2
Binary files /dev/null and b/algorithms/worldmem/models/__pycache__/my_rotary_embedding_torch.cpython-310.pyc differ
diff --git a/algorithms/worldmem/models/__pycache__/pose_prediction.cpython-310.pyc b/algorithms/worldmem/models/__pycache__/pose_prediction.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..419b24794dd196f1e13d9fe6f372467d8e5f8a95
Binary files /dev/null and b/algorithms/worldmem/models/__pycache__/pose_prediction.cpython-310.pyc differ
diff --git a/algorithms/worldmem/models/__pycache__/rotary_embedding_torch.cpython-310.pyc b/algorithms/worldmem/models/__pycache__/rotary_embedding_torch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6b0a9a831d829f3f0d1bb4d4600c024b1fad80d
Binary files /dev/null and b/algorithms/worldmem/models/__pycache__/rotary_embedding_torch.cpython-310.pyc differ
diff --git a/algorithms/worldmem/models/__pycache__/utils.cpython-310.pyc b/algorithms/worldmem/models/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..096138214a43941a8ce3324ea0ab2a44fd87d5c6
Binary files /dev/null and b/algorithms/worldmem/models/__pycache__/utils.cpython-310.pyc differ
diff --git a/algorithms/worldmem/models/__pycache__/vae.cpython-310.pyc b/algorithms/worldmem/models/__pycache__/vae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0e4b44825c57bf397fb40b0211ca0f176c63495
Binary files /dev/null and b/algorithms/worldmem/models/__pycache__/vae.cpython-310.pyc differ
diff --git a/algorithms/worldmem/models/attention.py b/algorithms/worldmem/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..975eb190707a26757b89f3958f9150a707aa7430
--- /dev/null
+++ b/algorithms/worldmem/models/attention.py
@@ -0,0 +1,351 @@
+"""
+Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py
+"""
+
+from typing import Optional
+from collections import namedtuple
+import torch
+from torch import nn
+from torch.nn import functional as F
+from einops import rearrange
+from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
+import numpy as np
+
+def create_attention_bias(f1, f2, device=None, dtype=torch.float32):
+ f = f1 + f2
+ mask = torch.zeros((f, f), dtype=dtype, device=device)
+ if f1 > 0:
+ mask[:f1, :f1] = float('-inf')
+ if f2 > 0:
+ mask[f1:, f1:] = float('-inf')
+ return mask
+
+class TemporalAxialAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ heads: int,
+ dim_head: int,
+ reference_length: int,
+ rotary_emb: RotaryEmbedding,
+ is_causal: bool = True,
+ is_temporal_independent: bool = False,
+ use_domain_adapter = False
+ ):
+ super().__init__()
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.head_dim = dim_head
+ self.inner_dim = dim_head * heads
+ self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
+
+ self.use_domain_adapter = use_domain_adapter
+ if self.use_domain_adapter:
+ lora_rank = 8
+ self.lora_A = nn.Linear(dim, lora_rank, bias=False)
+ self.lora_B = nn.Linear(lora_rank, self.inner_dim * 3, bias=False)
+
+ self.to_out = nn.Linear(self.inner_dim, dim)
+
+ self.rotary_emb = rotary_emb
+ self.is_causal = is_causal
+ self.is_temporal_independent = is_temporal_independent
+
+ self.reference_length = reference_length
+
+ def forward(self, x: torch.Tensor):
+ B, T, H, W, D = x.shape
+
+ # if T>=9:
+ # try:
+ # # x = torch.cat([x[:,:-1],x[:,16-T:17-T],x[:,-1:]], dim=1)
+ # x = torch.cat([x[:,16-T:17-T],x], dim=1)
+ # except:
+ # import pdb;pdb.set_trace()
+ # print("="*50)
+ # print(x.shape)
+
+ B, T, H, W, D = x.shape
+
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+
+ if self.use_domain_adapter:
+ q_lora, k_lora, v_lora = self.lora_B(self.lora_A(x)).chunk(3, dim=-1)
+ q = q+q_lora
+ k = k+k_lora
+ v = v+v_lora
+
+ q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
+ k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
+ v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)
+
+ q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
+ k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)
+
+ q, k, v = map(lambda t: t.contiguous(), (q, k, v))
+
+ if self.is_temporal_independent:
+ attn_bias = torch.ones((T, T), dtype=q.dtype, device=q.device)
+ attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
+ attn_bias[range(T), range(T)] = 0
+ elif self.is_causal:
+ attn_bias = torch.triu(torch.ones((T, T), dtype=q.dtype, device=q.device), diagonal=1)
+ attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
+ attn_bias[(T-self.reference_length):] = float('-inf')
+ attn_bias[range(T), range(T)] = 0
+ else:
+ attn_bias = None
+
+ try:
+ x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)
+ except:
+ import pdb;pdb.set_trace()
+
+ x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
+ x = x.to(q.dtype)
+
+ # linear proj
+ x = self.to_out(x)
+
+ # if T>=10:
+ # try:
+ # # x = torch.cat([x[:,:-2],x[:,-1:]], dim=1)
+ # x = x[:,1:]
+ # except:
+ # import pdb;pdb.set_trace()
+ # print(x.shape)
+ return x
+
+class SpatialAxialAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ heads: int,
+ dim_head: int,
+ rotary_emb: RotaryEmbedding,
+ use_domain_adapter = False
+ ):
+ super().__init__()
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.head_dim = dim_head
+ self.inner_dim = dim_head * heads
+ self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
+ self.use_domain_adapter = use_domain_adapter
+ if self.use_domain_adapter:
+ lora_rank = 8
+ self.lora_A = nn.Linear(dim, lora_rank, bias=False)
+ self.lora_B = nn.Linear(lora_rank, self.inner_dim * 3, bias=False)
+
+ self.to_out = nn.Linear(self.inner_dim, dim)
+
+ self.rotary_emb = rotary_emb
+
+ def forward(self, x: torch.Tensor):
+ B, T, H, W, D = x.shape
+
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+
+ if self.use_domain_adapter:
+ q_lora, k_lora, v_lora = self.lora_B(self.lora_A(x)).chunk(3, dim=-1)
+ q = q+q_lora
+ k = k+k_lora
+ v = v+v_lora
+
+ q = rearrange(q, "B T H W (h d) -> (B T) h H W d", h=self.heads)
+ k = rearrange(k, "B T H W (h d) -> (B T) h H W d", h=self.heads)
+ v = rearrange(v, "B T H W (h d) -> (B T) h H W d", h=self.heads)
+
+ freqs = self.rotary_emb.get_axial_freqs(H, W)
+ q = apply_rotary_emb(freqs, q)
+ k = apply_rotary_emb(freqs, k)
+
+ # prepare for attn
+ q = rearrange(q, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
+ k = rearrange(k, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
+ v = rearrange(v, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
+
+ x = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=False)
+
+ x = rearrange(x, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
+ x = x.to(q.dtype)
+
+ # linear proj
+ x = self.to_out(x)
+ return x
+
+class MemTemporalAxialAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ heads: int,
+ dim_head: int,
+ rotary_emb: RotaryEmbedding,
+ is_causal: bool = True,
+ ):
+ super().__init__()
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.head_dim = dim_head
+ self.inner_dim = dim_head * heads
+ self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
+ self.to_out = nn.Linear(self.inner_dim, dim)
+
+ self.rotary_emb = rotary_emb
+ self.is_causal = is_causal
+
+ self.reference_length = 3
+
+ def forward(self, x: torch.Tensor):
+ B, T, H, W, D = x.shape
+
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+
+
+ q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
+ k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
+ v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)
+
+
+
+ # q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
+ # k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)
+
+ q, k, v = map(lambda t: t.contiguous(), (q, k, v))
+
+ # if T == 21000:
+ # # 手动计算缩放点积分数
+ # _, _, _, d_k = q.shape
+ # scores = torch.einsum("b h n d, b h m d -> b h n m", q, k) / (d_k ** 0.5) # Shape: (B, T_q, T_k)
+
+ # # 计算注意力图 (Attention Map)
+ # attention_map = F.softmax(scores, dim=-1) # Shape: (B, T_q, T_k)
+ # b_, h_, n_, m_ = attention_map.shape
+ # attention_map = attention_map.reshape(1, int(np.sqrt(b_/1)), int(np.sqrt(b_/1)), h_, n_, m_)
+ # attention_map = attention_map.mean(3)
+
+ # attn_bias = torch.zeros((T, T), dtype=q.dtype, device=q.device)
+ # T_origin = T - self.reference_length
+ # attn_bias[:T_origin, T_origin:] = 1
+ # attn_bias[range(T), range(T)] = 1
+
+ # attention_map = attention_map * attn_bias
+
+ # # print 注意力图
+ # import matplotlib.pyplot as plt
+ # fig, axes = plt.subplots(21000, 21000, figsize=(9, 9)) # 调整figsize以适配图像大小
+
+ # # 遍历3*3维度
+ # for i in range(21000):
+ # for j in range(21000):
+ # # 取出第(i, j)个子图像
+ # img = attention_map[0, :, :, i, j].cpu().numpy()
+ # axes[i, j].imshow(img, cmap='viridis') # 可以自定义cmap
+ # axes[i, j].axis('off') # 隐藏坐标轴
+
+ # # 调整子图间距
+ # plt.tight_layout()
+ # plt.savefig('attention_map.png')
+ # import pdb; pdb.set_trace()
+ # plt.close()
+
+ attn_bias = torch.zeros((T, T), dtype=q.dtype, device=q.device)
+ attn_bias = attn_bias.masked_fill(attn_bias == 0, float('-inf'))
+ T_origin = T - self.reference_length
+ attn_bias[:T_origin, T_origin:] = 0
+ attn_bias[range(T), range(T)] = 0
+
+ # if T==121000:
+ # import pdb;pdb.set_trace()
+
+ try:
+ x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)
+ except:
+ import pdb;pdb.set_trace()
+
+ x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
+ x = x.to(q.dtype)
+
+ # linear proj
+ x = self.to_out(x)
+ return x
+
+class MemFullAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ heads: int,
+ dim_head: int,
+ reference_length: int,
+ rotary_emb: RotaryEmbedding,
+ is_causal: bool = True
+ ):
+ super().__init__()
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.head_dim = dim_head
+ self.inner_dim = dim_head * heads
+ self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
+ self.to_out = nn.Linear(self.inner_dim, dim)
+
+ self.rotary_emb = rotary_emb
+ self.is_causal = is_causal
+
+ self.reference_length = reference_length
+
+ self.store = None
+
+ def forward(self, x: torch.Tensor, relative_embedding=False,
+ extra_condition=None,
+ cond_only_on_qk=False,
+ reference_length=None):
+
+ B, T, H, W, D = x.shape
+
+ if cond_only_on_qk:
+ q, k, _ = self.to_qkv(x+extra_condition).chunk(3, dim=-1)
+ _, _, v = self.to_qkv(x).chunk(3, dim=-1)
+ else:
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+
+ if relative_embedding:
+ length = reference_length+1
+ n_frames = T // length
+ x = x.reshape(B, n_frames, length, H, W, D)
+
+ x_list = []
+
+ for i in range(n_frames):
+ if i == n_frames-1:
+ q_i = rearrange(q[:, i*length:], "B T H W (h d) -> B h (T H W) d", h=self.heads)
+ k_i = rearrange(k[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
+ v_i = rearrange(v[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
+ else:
+ q_i = rearrange(q[:, i*length:i*length+1], "B T H W (h d) -> B h (T H W) d", h=self.heads)
+ k_i = rearrange(k[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
+ v_i = rearrange(v[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
+
+ q_i, k_i, v_i = map(lambda t: t.contiguous(), (q_i, k_i, v_i))
+ x_i = F.scaled_dot_product_attention(query=q_i, key=k_i, value=v_i)
+ x_i = rearrange(x_i, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W)
+ x_i = x_i.to(q.dtype)
+ x_list.append(x_i)
+
+ x = torch.cat(x_list, dim=1)
+
+
+ else:
+ T_ = T - reference_length
+ q = rearrange(q, "B T H W (h d) -> B h (T H W) d", h=self.heads)
+ k = rearrange(k[:, T_:], "B T H W (h d) -> B h (T H W) d", h=self.heads)
+ v = rearrange(v[:, T_:], "B T H W (h d) -> B h (T H W) d", h=self.heads)
+
+ q, k, v = map(lambda t: t.contiguous(), (q, k, v))
+ x = F.scaled_dot_product_attention(query=q, key=k, value=v)
+ x = rearrange(x, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W)
+ x = x.to(q.dtype)
+
+ # linear proj
+ x = self.to_out(x)
+
+ return x
diff --git a/algorithms/worldmem/models/cameractrl_module.py b/algorithms/worldmem/models/cameractrl_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bac51c45c6bcc11d46e453c7d2a8c4141337a01
--- /dev/null
+++ b/algorithms/worldmem/models/cameractrl_module.py
@@ -0,0 +1,12 @@
+import torch.nn as nn
+class SimpleCameraPoseEncoder(nn.Module):
+ def __init__(self, c_in, c_out, hidden_dim=128):
+ super(SimpleCameraPoseEncoder, self).__init__()
+ self.model = nn.Sequential(
+ nn.Linear(c_in, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, c_out)
+ )
+ def forward(self, x):
+ return self.model(x)
+
diff --git a/algorithms/worldmem/models/diffusion.py b/algorithms/worldmem/models/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccaa251491c70d54f634985349c73cf4230d9ecc
--- /dev/null
+++ b/algorithms/worldmem/models/diffusion.py
@@ -0,0 +1,520 @@
+from typing import Optional, Callable
+from collections import namedtuple
+from omegaconf import DictConfig
+import torch
+from torch import nn
+from torch.nn import functional as F
+from einops import rearrange
+from .utils import linear_beta_schedule, cosine_beta_schedule, sigmoid_beta_schedule, extract
+from .dit import DiT_models
+
+ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start", "model_out"])
+
+
+class Diffusion(nn.Module):
+ # Special thanks to lucidrains for the implementation of the base Diffusion model
+ # https://github.com/lucidrains/denoising-diffusion-pytorch
+
+ def __init__(
+ self,
+ x_shape: torch.Size,
+ reference_length: int,
+ action_cond_dim: int,
+ pose_cond_dim,
+ is_causal: bool,
+ cfg: DictConfig,
+ is_dit: bool=False,
+ use_plucker=False,
+ relative_embedding=False,
+ cond_only_on_qk=False,
+ use_reference_attention=False,
+ add_frame_timestep_embedder=False,
+ ref_mode='sequential'
+ ):
+ super().__init__()
+ self.cfg = cfg
+
+ self.x_shape = x_shape
+ self.action_cond_dim = action_cond_dim
+ self.timesteps = cfg.timesteps
+ self.sampling_timesteps = cfg.sampling_timesteps
+ self.beta_schedule = cfg.beta_schedule
+ self.schedule_fn_kwargs = cfg.schedule_fn_kwargs
+ self.objective = cfg.objective
+ self.use_fused_snr = cfg.use_fused_snr
+ self.snr_clip = cfg.snr_clip
+ self.cum_snr_decay = cfg.cum_snr_decay
+ self.ddim_sampling_eta = cfg.ddim_sampling_eta
+ self.clip_noise = cfg.clip_noise
+ self.arch = cfg.architecture
+ self.stabilization_level = cfg.stabilization_level
+ self.is_causal = is_causal
+ self.is_dit = is_dit
+ self.reference_length = reference_length
+ self.pose_cond_dim = pose_cond_dim
+ self.use_plucker = use_plucker
+ self.relative_embedding = relative_embedding
+ self.cond_only_on_qk = cond_only_on_qk
+ self.use_reference_attention = use_reference_attention
+ self.add_frame_timestep_embedder = add_frame_timestep_embedder
+ self.ref_mode = ref_mode
+
+ self._build_model()
+ self._build_buffer()
+
+ def _build_model(self):
+ x_channel = self.x_shape[0]
+ if self.is_dit:
+ self.model = DiT_models["DiT-S/2"](action_cond_dim=self.action_cond_dim,
+ pose_cond_dim=self.pose_cond_dim, reference_length=self.reference_length,
+ use_plucker=self.use_plucker,
+ relative_embedding=self.relative_embedding,
+ cond_only_on_qk=self.cond_only_on_qk,
+ use_reference_attention=self.use_reference_attention,
+ add_frame_timestep_embedder=self.add_frame_timestep_embedder,
+ ref_mode=self.ref_mode)
+ else:
+ raise NotImplementedError
+
+ def _build_buffer(self):
+ if self.beta_schedule == "linear":
+ beta_schedule_fn = linear_beta_schedule
+ elif self.beta_schedule == "cosine":
+ beta_schedule_fn = cosine_beta_schedule
+ elif self.beta_schedule == "sigmoid":
+ beta_schedule_fn = sigmoid_beta_schedule
+ else:
+ raise ValueError(f"unknown beta schedule {self.beta_schedule}")
+
+ betas = beta_schedule_fn(self.timesteps, **self.schedule_fn_kwargs)
+
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
+
+ # sampling related parameters
+ assert self.sampling_timesteps <= self.timesteps
+ self.is_ddim_sampling = self.sampling_timesteps < self.timesteps
+
+ # helper function to register buffer from float64 to float32
+ register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
+
+ register_buffer("betas", betas)
+ register_buffer("alphas_cumprod", alphas_cumprod)
+ register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+
+ register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
+ register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
+ register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
+ register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
+ register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+
+ posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+
+ register_buffer("posterior_variance", posterior_variance)
+
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+
+ register_buffer(
+ "posterior_log_variance_clipped",
+ torch.log(posterior_variance.clamp(min=1e-20)),
+ )
+ register_buffer(
+ "posterior_mean_coef1",
+ betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
+ )
+ register_buffer(
+ "posterior_mean_coef2",
+ (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
+ )
+
+ # calculate p2 reweighting
+
+ # register_buffer(
+ # "p2_loss_weight",
+ # (self.p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
+ # ** -self.p2_loss_weight_gamma,
+ # )
+
+ # derive loss weight
+ # https://arxiv.org/abs/2303.09556
+ # snr: signal noise ratio
+ snr = alphas_cumprod / (1 - alphas_cumprod)
+ clipped_snr = snr.clone()
+ clipped_snr.clamp_(max=self.snr_clip)
+
+ register_buffer("clipped_snr", clipped_snr)
+ register_buffer("snr", snr)
+
+ def add_shape_channels(self, x):
+ return rearrange(x, f"... -> ...{' 1' * len(self.x_shape)}")
+
+ def model_predictions(self, x, t, action_cond=None, current_frame=None,
+ pose_cond=None, mode="training", reference_length=None, frame_idx=None):
+ x = x.permute(1,0,2,3,4)
+ action_cond = action_cond.permute(1,0,2)
+ if pose_cond is not None and pose_cond[0] is not None:
+ try:
+ pose_cond = pose_cond.permute(1,0,2)
+ except:
+ pass
+ t = t.permute(1,0)
+ model_output = self.model(x, t, action_cond, current_frame=current_frame, pose_cond=pose_cond,
+ mode=mode, reference_length=reference_length, frame_idx=frame_idx)
+ model_output = model_output.permute(1,0,2,3,4)
+ x = x.permute(1,0,2,3,4)
+ t = t.permute(1,0)
+
+ if self.objective == "pred_noise":
+ pred_noise = torch.clamp(model_output, -self.clip_noise, self.clip_noise)
+ x_start = self.predict_start_from_noise(x, t, pred_noise)
+
+ elif self.objective == "pred_x0":
+ x_start = model_output
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
+
+ elif self.objective == "pred_v":
+ v = model_output
+ x_start = self.predict_start_from_v(x, t, v)
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
+
+
+ return ModelPrediction(pred_noise, x_start, model_output)
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def predict_noise_from_start(self, x_t, t, x0):
+ return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / extract(
+ self.sqrt_recipm1_alphas_cumprod, t, x_t.shape
+ )
+
+ def predict_v(self, x_start, t, noise):
+ return (
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise
+ - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
+ )
+
+ def predict_start_from_v(self, x_t, t, v):
+ return (
+ extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
+ - extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def q_mean_variance(self, x_start, t):
+ mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def q_sample(self, x_start, t, noise=None):
+ if noise is None:
+ noise = torch.randn_like(x_start)
+ noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
+ return (
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+ )
+
+ def p_mean_variance(self, x, t, action_cond=None, pose_cond=None, reference_length=None):
+ model_pred = self.model_predictions(x=x, t=t, action_cond=action_cond,
+ pose_cond=pose_cond, reference_length=reference_length,
+ frame_idx=frame_idx)
+ x_start = model_pred.pred_x_start
+ return self.q_posterior(x_start=x_start, x_t=x, t=t)
+
+ def compute_loss_weights(self, noise_levels: torch.Tensor):
+
+ snr = self.snr[noise_levels]
+ clipped_snr = self.clipped_snr[noise_levels]
+ normalized_clipped_snr = clipped_snr / self.snr_clip
+ normalized_snr = snr / self.snr_clip
+
+ if not self.use_fused_snr:
+ # min SNR reweighting
+ match self.objective:
+ case "pred_noise":
+ return clipped_snr / snr
+ case "pred_x0":
+ return clipped_snr
+ case "pred_v":
+ return clipped_snr / (snr + 1)
+
+ cum_snr = torch.zeros_like(normalized_snr)
+ for t in range(0, noise_levels.shape[0]):
+ if t == 0:
+ cum_snr[t] = normalized_clipped_snr[t]
+ else:
+ cum_snr[t] = self.cum_snr_decay * cum_snr[t - 1] + (1 - self.cum_snr_decay) * normalized_clipped_snr[t]
+
+ cum_snr = F.pad(cum_snr[:-1], (0, 0, 1, 0), value=0.0)
+ clipped_fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_clipped_snr)
+ fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_snr)
+
+ match self.objective:
+ case "pred_noise":
+ return clipped_fused_snr / fused_snr
+ case "pred_x0":
+ return clipped_fused_snr * self.snr_clip
+ case "pred_v":
+ return clipped_fused_snr * self.snr_clip / (fused_snr * self.snr_clip + 1)
+ case _:
+ raise ValueError(f"unknown objective {self.objective}")
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ action_cond: Optional[torch.Tensor],
+ pose_cond,
+ noise_levels: torch.Tensor,
+ reference_length,
+ frame_idx=None
+ ):
+ noise = torch.randn_like(x)
+ noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
+
+ noised_x = self.q_sample(x_start=x, t=noise_levels, noise=noise)
+
+ model_pred = self.model_predictions(x=noised_x, t=noise_levels, action_cond=action_cond,
+ pose_cond=pose_cond,reference_length=reference_length, frame_idx=frame_idx)
+
+ pred = model_pred.model_out
+ x_pred = model_pred.pred_x_start
+
+ if self.objective == "pred_noise":
+ target = noise
+ elif self.objective == "pred_x0":
+ target = x
+ elif self.objective == "pred_v":
+ target = self.predict_v(x, noise_levels, noise)
+ else:
+ raise ValueError(f"unknown objective {self.objective}")
+
+ # 训练的时候每个frame随便给噪声
+ loss = F.mse_loss(pred, target.detach(), reduction="none")
+ loss_weight = self.compute_loss_weights(noise_levels)
+
+ loss_weight = loss_weight.view(*loss_weight.shape, *((1,) * (loss.ndim - 2)))
+
+ loss = loss * loss_weight
+
+ return x_pred, loss
+
+ def sample_step(
+ self,
+ x: torch.Tensor,
+ action_cond: Optional[torch.Tensor],
+ pose_cond,
+ curr_noise_level: torch.Tensor,
+ next_noise_level: torch.Tensor,
+ guidance_fn: Optional[Callable] = None,
+ current_frame=None,
+ mode="training",
+ reference_length=None,
+ frame_idx=None
+ ):
+ real_steps = torch.linspace(-1, self.timesteps - 1, steps=self.sampling_timesteps + 1, device=x.device).long()
+
+ # convert noise levels (0 ~ sampling_timesteps) to real noise levels (-1 ~ timesteps - 1)
+ curr_noise_level = real_steps[curr_noise_level]
+ next_noise_level = real_steps[next_noise_level]
+
+ if self.is_ddim_sampling:
+ return self.ddim_sample_step(
+ x=x,
+ action_cond=action_cond,
+ pose_cond=pose_cond,
+ curr_noise_level=curr_noise_level,
+ next_noise_level=next_noise_level,
+ guidance_fn=guidance_fn,
+ current_frame=current_frame,
+ mode=mode,
+ reference_length=reference_length,
+ frame_idx=frame_idx
+ )
+
+ # FIXME: temporary code for checking ddpm sampling
+ assert torch.all(
+ (curr_noise_level - 1 == next_noise_level) | ((curr_noise_level == -1) & (next_noise_level == -1))
+ ), "Wrong noise level given for ddpm sampling."
+
+ assert (
+ self.sampling_timesteps == self.timesteps
+ ), "sampling_timesteps should be equal to timesteps for ddpm sampling."
+
+ return self.ddpm_sample_step(
+ x=x,
+ action_cond=action_cond,
+ pose_cond=pose_cond,
+ curr_noise_level=curr_noise_level,
+ guidance_fn=guidance_fn,
+ reference_length=reference_length,
+ frame_idx=frame_idx
+ )
+
+ def ddpm_sample_step(
+ self,
+ x: torch.Tensor,
+ action_cond: Optional[torch.Tensor],
+ pose_cond,
+ curr_noise_level: torch.Tensor,
+ guidance_fn: Optional[Callable] = None,
+ reference_length=None,
+ frame_idx=None,
+ ):
+ clipped_curr_noise_level = torch.where(
+ curr_noise_level < 0,
+ torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
+ curr_noise_level,
+ )
+
+ # treating as stabilization would require us to scale with sqrt of alpha_cum
+ orig_x = x.clone().detach()
+ scaled_context = self.q_sample(
+ x,
+ clipped_curr_noise_level,
+ noise=torch.zeros_like(x),
+ )
+ x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)
+
+ if guidance_fn is not None:
+ raise NotImplementedError("Guidance function is not implemented for ddpm sampling yet.")
+
+ else:
+ model_mean, _, model_log_variance = self.p_mean_variance(
+ x=x,
+ t=clipped_curr_noise_level,
+ action_cond=action_cond,
+ pose_cond=pose_cond,
+ reference_length=reference_length,
+ frame_idx=frame_idx
+ )
+
+ noise = torch.where(
+ self.add_shape_channels(clipped_curr_noise_level > 0),
+ torch.randn_like(x),
+ 0,
+ )
+ noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
+ x_pred = model_mean + torch.exp(0.5 * model_log_variance) * noise
+
+ # only update frames where the noise level decreases
+ return torch.where(self.add_shape_channels(curr_noise_level == -1), orig_x, x_pred)
+
+ def ddim_sample_step(
+ self,
+ x: torch.Tensor,
+ action_cond: Optional[torch.Tensor],
+ pose_cond,
+ curr_noise_level: torch.Tensor,
+ next_noise_level: torch.Tensor,
+ guidance_fn: Optional[Callable] = None,
+ current_frame=None,
+ mode="training",
+ reference_length=None,
+ frame_idx=None
+ ):
+ # convert noise level -1 to self.stabilization_level - 1
+ clipped_curr_noise_level = torch.where(
+ curr_noise_level < 0,
+ torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
+ curr_noise_level,
+ )
+
+ # treating as stabilization would require us to scale with sqrt of alpha_cum
+ orig_x = x.clone().detach()
+ scaled_context = self.q_sample(
+ x,
+ clipped_curr_noise_level,
+ noise=torch.zeros_like(x),
+ )
+ x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)
+
+ alpha = self.alphas_cumprod[clipped_curr_noise_level]
+ alpha_next = torch.where(
+ next_noise_level < 0,
+ torch.ones_like(next_noise_level),
+ self.alphas_cumprod[next_noise_level],
+ )
+ sigma = torch.where(
+ next_noise_level < 0,
+ torch.zeros_like(next_noise_level),
+ self.ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt(),
+ )
+ c = (1 - alpha_next - sigma**2).sqrt()
+
+ alpha_next = self.add_shape_channels(alpha_next)
+ c = self.add_shape_channels(c)
+ sigma = self.add_shape_channels(sigma)
+
+ if guidance_fn is not None:
+ with torch.enable_grad():
+ x = x.detach().requires_grad_()
+
+ model_pred = self.model_predictions(
+ x=x,
+ t=clipped_curr_noise_level,
+ action_cond=action_cond,
+ pose_cond=pose_cond,
+ current_frame=current_frame,
+ mode=mode,
+ reference_length=reference_length,
+ frame_idx=frame_idx
+ )
+
+ guidance_loss = guidance_fn(model_pred.pred_x_start)
+ grad = -torch.autograd.grad(
+ guidance_loss,
+ x,
+ )[0]
+
+ pred_noise = model_pred.pred_noise + (1 - alpha_next).sqrt() * grad
+ x_start = self.predict_start_from_noise(x, clipped_curr_noise_level, pred_noise)
+
+ else:
+ # print(clipped_curr_noise_level)
+ model_pred = self.model_predictions(
+ x=x,
+ t=clipped_curr_noise_level,
+ action_cond=action_cond,
+ pose_cond=pose_cond,
+ current_frame=current_frame,
+ mode=mode,
+ reference_length=reference_length,
+ frame_idx=frame_idx
+ )
+ x_start = model_pred.pred_x_start
+ pred_noise = model_pred.pred_noise
+
+ noise = torch.randn_like(x)
+ noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
+
+ x_pred = x_start * alpha_next.sqrt() + pred_noise * c + sigma * noise
+
+ # only update frames where the noise level decreases
+ mask = curr_noise_level == next_noise_level
+ x_pred = torch.where(
+ self.add_shape_channels(mask),
+ orig_x,
+ x_pred,
+ )
+
+ return x_pred
diff --git a/algorithms/worldmem/models/dit.py b/algorithms/worldmem/models/dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..b04e25e950208ad4baf70446541d16afea3d1dd7
--- /dev/null
+++ b/algorithms/worldmem/models/dit.py
@@ -0,0 +1,577 @@
+"""
+References:
+ - DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
+ - Diffusion Forcing: https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/unet3d.py
+ - Latte: https://github.com/Vchitect/Latte/blob/main/models/latte.py
+"""
+
+from typing import Optional, Literal
+import torch
+from torch import nn
+from .rotary_embedding_torch import RotaryEmbedding
+from einops import rearrange
+from .attention import SpatialAxialAttention, TemporalAxialAttention, MemTemporalAxialAttention, MemFullAttention
+from timm.models.vision_transformer import Mlp
+from timm.layers.helpers import to_2tuple
+import math
+from collections import namedtuple
+from typing import Optional, Callable
+from .cameractrl_module import SimpleCameraPoseEncoder
+
+def modulate(x, shift, scale):
+ fixed_dims = [1] * len(shift.shape[1:])
+ shift = shift.repeat(x.shape[0] // shift.shape[0], *fixed_dims)
+ scale = scale.repeat(x.shape[0] // scale.shape[0], *fixed_dims)
+ while shift.dim() < x.dim():
+ shift = shift.unsqueeze(-2)
+ scale = scale.unsqueeze(-2)
+ return x * (1 + scale) + shift
+
+def gate(x, g):
+ fixed_dims = [1] * len(g.shape[1:])
+ g = g.repeat(x.shape[0] // g.shape[0], *fixed_dims)
+ while g.dim() < x.dim():
+ g = g.unsqueeze(-2)
+ return g * x
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ img_height=256,
+ img_width=256,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ norm_layer=None,
+ flatten=True,
+ ):
+ super().__init__()
+ img_size = (img_height, img_width)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x, random_sample=False):
+ B, C, H, W = x.shape
+ assert random_sample or (H == self.img_size[0] and W == self.img_size[1]), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+
+ x = self.proj(x)
+ if self.flatten:
+ x = rearrange(x, "B C H W -> B (H W) C")
+ else:
+ x = rearrange(x, "B C H W -> B H W C")
+ x = self.norm(x)
+ return x
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256, freq_type='time_step'):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True), # hidden_size is diffusion model hidden size
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+ self.freq_type = freq_type
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000, freq_type='time_step'):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+
+ if freq_type == 'time_step':
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
+ elif freq_type == 'spatial': # ~(-5 5)
+ freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi
+ elif freq_type == 'angle': # 0-360
+ freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi / 180
+
+
+ args = t[:, None].float() * freqs[None]
+
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size, freq_type=self.freq_type)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of DiT.
+ """
+
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class SpatioTemporalDiTBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ reference_length,
+ mlp_ratio=4.0,
+ is_causal=True,
+ spatial_rotary_emb: Optional[RotaryEmbedding] = None,
+ temporal_rotary_emb: Optional[RotaryEmbedding] = None,
+ reference_rotary_emb=None,
+ use_plucker=False,
+ relative_embedding=False,
+ cond_only_on_qk=False,
+ use_reference_attention=False,
+ ref_mode='sequential'
+ ):
+ super().__init__()
+ self.is_causal = is_causal
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+
+ self.s_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.s_attn = SpatialAxialAttention(
+ hidden_size,
+ heads=num_heads,
+ dim_head=hidden_size // num_heads,
+ rotary_emb=spatial_rotary_emb
+ )
+ self.s_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.s_mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+ self.s_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+ self.t_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.t_attn = TemporalAxialAttention(
+ hidden_size,
+ heads=num_heads,
+ dim_head=hidden_size // num_heads,
+ is_causal=is_causal,
+ rotary_emb=temporal_rotary_emb,
+ reference_length=reference_length
+ )
+ self.t_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.t_mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+ self.t_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+ self.use_reference_attention = use_reference_attention
+ if self.use_reference_attention:
+ self.r_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ref_type = "full_ref"
+ if self.ref_type == "temporal_ref":
+ self.r_attn = MemTemporalAxialAttention(
+ hidden_size,
+ heads=num_heads,
+ dim_head=hidden_size // num_heads,
+ is_causal=is_causal,
+ rotary_emb=None
+ )
+ elif self.ref_type == "full_ref":
+ self.r_attn = MemFullAttention(
+ hidden_size,
+ heads=num_heads,
+ dim_head=hidden_size // num_heads,
+ is_causal=is_causal,
+ rotary_emb=reference_rotary_emb,
+ reference_length=reference_length
+ )
+ self.r_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.r_mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+ self.r_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+ self.use_plucker = use_plucker
+ if use_plucker:
+ self.pose_cond_mlp = nn.Linear(hidden_size, hidden_size)
+ self.temporal_pose_cond_mlp = nn.Linear(hidden_size, hidden_size)
+
+ self.reference_length = reference_length
+ self.relative_embedding = relative_embedding
+ self.cond_only_on_qk = cond_only_on_qk
+
+ self.ref_mode = ref_mode
+
+ if self.ref_mode == 'parallel':
+ self.parallel_map = nn.Linear(hidden_size, hidden_size)
+
+ def forward(self, x, c, current_frame=None, timestep=None, is_last_block=False,
+ pose_cond=None, mode="training", c_action_cond=None, reference_length=None):
+ B, T, H, W, D = x.shape
+
+ # spatial block
+
+ s_shift_msa, s_scale_msa, s_gate_msa, s_shift_mlp, s_scale_mlp, s_gate_mlp = self.s_adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate(self.s_attn(modulate(self.s_norm1(x), s_shift_msa, s_scale_msa)), s_gate_msa)
+ x = x + gate(self.s_mlp(modulate(self.s_norm2(x), s_shift_mlp, s_scale_mlp)), s_gate_mlp)
+
+ # temporal block
+ if c_action_cond is not None:
+ t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c_action_cond).chunk(6, dim=-1)
+ else:
+ t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c).chunk(6, dim=-1)
+
+ x_t = x + gate(self.t_attn(modulate(self.t_norm1(x), t_shift_msa, t_scale_msa)), t_gate_msa)
+ x_t = x_t + gate(self.t_mlp(modulate(self.t_norm2(x_t), t_shift_mlp, t_scale_mlp)), t_gate_mlp)
+
+ if self.ref_mode == 'sequential':
+ x = x_t
+
+ # memory block
+ relative_embedding = self.relative_embedding # and mode == "training"
+
+ if self.use_reference_attention:
+ r_shift_msa, r_scale_msa, r_gate_msa, r_shift_mlp, r_scale_mlp, r_gate_mlp = self.r_adaLN_modulation(c).chunk(6, dim=-1)
+
+ if pose_cond is not None:
+ if self.use_plucker:
+ input_cond = self.pose_cond_mlp(pose_cond)
+
+ if relative_embedding:
+ n_frames = x.shape[1] - reference_length
+ x1_relative_embedding = []
+ r_shift_msa_relative_embedding = []
+ r_scale_msa_relative_embedding = []
+ for i in range(n_frames):
+ x1_relative_embedding.append(torch.cat([x[:,i:i+1], x[:, -reference_length:]], dim=1).clone())
+ r_shift_msa_relative_embedding.append(torch.cat([r_shift_msa[:,i:i+1], r_shift_msa[:, -reference_length:]], dim=1).clone())
+ r_scale_msa_relative_embedding.append(torch.cat([r_scale_msa[:,i:i+1], r_scale_msa[:, -reference_length:]], dim=1).clone())
+ x1_zero_frame = torch.cat(x1_relative_embedding, dim=1)
+ r_shift_msa = torch.cat(r_shift_msa_relative_embedding, dim=1)
+ r_scale_msa = torch.cat(r_scale_msa_relative_embedding, dim=1)
+
+ # if current_frame == 18:
+ # import pdb;pdb.set_trace()
+
+ if self.cond_only_on_qk:
+ attn_input = x1_zero_frame
+ extra_condition = input_cond
+ else:
+ attn_input = input_cond + x1_zero_frame
+ extra_condition = None
+ else:
+ attn_input = input_cond + x
+ extra_condition = None
+ # print("input_cond2:", input_cond.abs().mean())
+ # print("c:", c.abs().mean())
+ # input_cond = x1
+
+ x = x + gate(self.r_attn(modulate(self.r_norm1(attn_input), r_shift_msa, r_scale_msa),
+ relative_embedding=relative_embedding,
+ extra_condition=extra_condition,
+ cond_only_on_qk=self.cond_only_on_qk,
+ reference_length=reference_length), r_gate_msa)
+ else:
+ # pose_cond *= 0
+ x = x + gate(self.r_attn(modulate(self.r_norm1(x+pose_cond[:,:,None, None]), r_shift_msa, r_scale_msa),
+ current_frame=current_frame, timestep=timestep,
+ is_last_block=is_last_block,
+ reference_length=reference_length), r_gate_msa)
+ else:
+ x = x + gate(self.r_attn(modulate(self.r_norm1(x), r_shift_msa, r_scale_msa), current_frame=current_frame, timestep=timestep,
+ is_last_block=is_last_block), r_gate_msa)
+
+ x = x + gate(self.r_mlp(modulate(self.r_norm2(x), r_shift_mlp, r_scale_mlp)), r_gate_mlp)
+
+ if self.ref_mode == 'parallel':
+ x = x_t + self.parallel_map(x)
+
+ return x
+
+ # print((x1-x2).abs().sum())
+ # r_shift_msa, r_scale_msa, r_gate_msa, r_shift_mlp, r_scale_mlp, r_gate_mlp = self.r_adaLN_modulation(c).chunk(6, dim=-1)
+ # x2 = x1 + gate(self.r_attn(modulate(self.r_norm1(x_), r_shift_msa, r_scale_msa)), r_gate_msa)
+ # x2 = gate(self.r_mlp(modulate(self.r_norm2(x2), r_shift_mlp, r_scale_mlp)), r_gate_mlp)
+ # x = x1 + x2
+
+ # print(x.mean())
+ # return x
+
+
+class DiT(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+
+ def __init__(
+ self,
+ input_h=18,
+ input_w=32,
+ patch_size=2,
+ in_channels=16,
+ hidden_size=1024,
+ depth=12,
+ num_heads=16,
+ mlp_ratio=4.0,
+ action_cond_dim=25,
+ pose_cond_dim=4,
+ max_frames=32,
+ reference_length=8,
+ use_plucker=False,
+ relative_embedding=False,
+ cond_only_on_qk=False,
+ use_reference_attention=False,
+ add_frame_timestep_embedder=False,
+ ref_mode='sequential'
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+ self.max_frames = max_frames
+
+ self.x_embedder = PatchEmbed(input_h, input_w, patch_size, in_channels, hidden_size, flatten=False)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+
+ self.add_frame_timestep_embedder = add_frame_timestep_embedder
+ if self.add_frame_timestep_embedder:
+ self.frame_timestep_embedder = TimestepEmbedder(hidden_size)
+
+ frame_h, frame_w = self.x_embedder.grid_size
+
+ self.spatial_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256)
+ self.temporal_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads)
+ # self.reference_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256)
+ self.reference_rotary_emb = None
+
+ self.external_cond = nn.Linear(action_cond_dim, hidden_size) if action_cond_dim > 0 else nn.Identity()
+
+ # self.pose_cond = nn.Linear(pose_cond_dim, hidden_size) if pose_cond_dim > 0 else nn.Identity()
+
+ self.use_plucker = use_plucker
+ if not self.use_plucker:
+ self.position_embedder = TimestepEmbedder(hidden_size, freq_type='spatial')
+ self.angle_embedder = TimestepEmbedder(hidden_size, freq_type='angle')
+ else:
+ self.pose_embedder = SimpleCameraPoseEncoder(c_in=6, c_out=hidden_size)
+
+ self.blocks = nn.ModuleList(
+ [
+ SpatioTemporalDiTBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ is_causal=True,
+ reference_length=reference_length,
+ spatial_rotary_emb=self.spatial_rotary_emb,
+ temporal_rotary_emb=self.temporal_rotary_emb,
+ reference_rotary_emb=self.reference_rotary_emb,
+ use_plucker=self.use_plucker,
+ relative_embedding=relative_embedding,
+ cond_only_on_qk=cond_only_on_qk,
+ use_reference_attention=use_reference_attention,
+ ref_mode=ref_mode
+ )
+ for _ in range(depth)
+ ]
+ )
+ self.use_reference_attention = use_reference_attention
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ if self.use_reference_attention:
+ if not self.use_plucker:
+ nn.init.normal_(self.position_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.position_embedder.mlp[2].weight, std=0.02)
+
+ nn.init.normal_(self.angle_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.angle_embedder.mlp[2].weight, std=0.02)
+
+ if self.add_frame_timestep_embedder:
+ nn.init.normal_(self.frame_timestep_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.frame_timestep_embedder.mlp[2].weight, std=0.02)
+
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.s_adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.s_adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(block.t_adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.t_adaLN_modulation[-1].bias, 0)
+
+ if self.use_plucker and self.use_reference_attention:
+ nn.init.constant_(block.pose_cond_mlp.weight, 0)
+ nn.init.constant_(block.pose_cond_mlp.bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def unpatchify(self, x):
+ """
+ x: (N, H, W, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ p = self.x_embedder.patch_size[0]
+ h = x.shape[1]
+ w = x.shape[2]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+ x = torch.einsum("nhwpqc->nchpwq", x)
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
+ return imgs
+
+ def forward(self, x, t, action_cond=None, pose_cond=None, current_frame=None, mode=None,
+ reference_length=None, frame_idx=None):
+ """
+ Forward pass of DiT.
+ x: (B, T, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (B, T,) tensor of diffusion timesteps
+ """
+
+ B, T, C, H, W = x.shape
+
+ # add spatial embeddings
+ x = rearrange(x, "b t c h w -> (b t) c h w")
+
+ x = self.x_embedder(x) # (B*T, C, H, W) -> (B*T, H/2, W/2, D) , C = 16, D = d_model
+ # restore shape
+ x = rearrange(x, "(b t) h w d -> b t h w d", t=T)
+ # embed noise steps
+ t = rearrange(t, "b t -> (b t)")
+
+ c_t = self.t_embedder(t) # (N, D)
+ c = c_t.clone()
+ c = rearrange(c, "(b t) d -> b t d", t=T)
+
+ if torch.is_tensor(action_cond):
+ try:
+ c_action_cond = c + self.external_cond(action_cond)
+ except:
+ import pdb;pdb.set_trace()
+ else:
+ c_action_cond = None
+
+ if torch.is_tensor(pose_cond):
+ if not self.use_plucker:
+ pose_cond = pose_cond.to(action_cond.dtype)
+ b_, t_, d_ = pose_cond.shape
+ pos_emb = self.position_embedder(rearrange(pose_cond[...,:3], "b t d -> (b t d)"))
+ angle_emb = self.angle_embedder(rearrange(pose_cond[...,3:], "b t d -> (b t d)"))
+ pos_emb = rearrange(pos_emb, "(b t d) c -> b t d c", b=b_, t=t_, d=3).sum(-2)
+ angle_emb = rearrange(angle_emb, "(b t d) c -> b t d c", b=b_, t=t_, d=2).sum(-2)
+ pc = pos_emb + angle_emb
+ else:
+ pose_cond = pose_cond[:, :, ::40, ::40]
+ # pc = self.pose_embedder(pose_cond)[0]
+ # pc = pc.permute(0,2,3,4,1)
+ pc = self.pose_embedder(pose_cond)
+ pc = pc.permute(1,0,2,3,4)
+
+ if torch.is_tensor(frame_idx) and self.add_frame_timestep_embedder:
+ bb = frame_idx.shape[1]
+ frame_idx = rearrange(frame_idx, "t b -> (b t)")
+ frame_idx = self.frame_timestep_embedder(frame_idx)
+ frame_idx = rearrange(frame_idx, "(b t) d -> b t d", b=bb)
+ pc = pc + frame_idx[:, :, None, None]
+
+ # pc = pc + rearrange(c_t.clone(), "(b t) d -> b t d", t=T)[:,:,None,None] # add time condition for different timestep scaling
+ else:
+ pc = None
+
+ for i, block in enumerate(self.blocks):
+ x = block(x, c, current_frame=current_frame, timestep=t, is_last_block= (i+1 == len(self.blocks)),
+ pose_cond=pc, mode=mode, c_action_cond=c_action_cond, reference_length=reference_length) # (N, T, H, W, D)
+ x = self.final_layer(x, c) # (N, T, H, W, patch_size ** 2 * out_channels)
+ # unpatchify
+ x = rearrange(x, "b t h w d -> (b t) h w d")
+ x = self.unpatchify(x) # (N, out_channels, H, W)
+ x = rearrange(x, "(b t) c h w -> b t c h w", t=T)
+
+ # print("self.blocks[0].pose_cond_mlp.weight:", self.blocks[0].pose_cond_mlp.weight)
+ # print("self.blocks[0].r_adaLN_modulation[1].weight:", self.blocks[0].r_adaLN_modulation[1].weight)
+ # print("self.blocks[0].t_adaLN_modulation[1].weight:", self.blocks[0].t_adaLN_modulation[1].weight)
+
+ return x
+
+
+def DiT_S_2(action_cond_dim, pose_cond_dim, reference_length,
+use_plucker, relative_embedding,
+cond_only_on_qk, use_reference_attention, add_frame_timestep_embedder,
+ref_mode):
+ return DiT(
+ patch_size=2,
+ hidden_size=1024,
+ depth=16,
+ num_heads=16,
+ action_cond_dim=action_cond_dim,
+ pose_cond_dim=pose_cond_dim,
+ reference_length=reference_length,
+ use_plucker=use_plucker,
+ relative_embedding=relative_embedding,
+ cond_only_on_qk=cond_only_on_qk,
+ use_reference_attention=use_reference_attention,
+ add_frame_timestep_embedder=add_frame_timestep_embedder,
+ ref_mode=ref_mode
+ )
+
+
+DiT_models = {"DiT-S/2": DiT_S_2}
diff --git a/algorithms/worldmem/models/pose_prediction.py b/algorithms/worldmem/models/pose_prediction.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a390b45c5d076b8ca4c376859b8ce9e08348438
--- /dev/null
+++ b/algorithms/worldmem/models/pose_prediction.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class PosePredictionNet(nn.Module):
+ def __init__(self, img_channels=16, img_feat_dim=256, pose_dim=5, action_dim=25, hidden_dim=128):
+ super(PosePredictionNet, self).__init__()
+
+ self.cnn = nn.Sequential(
+ nn.Conv2d(img_channels, 32, kernel_size=3, stride=2, padding=1),
+ nn.ReLU(),
+ nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
+ nn.ReLU(),
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+ nn.ReLU(),
+ nn.AdaptiveAvgPool2d((1, 1))
+ )
+
+ self.fc_img = nn.Linear(128, img_feat_dim)
+
+ self.mlp_motion = nn.Sequential(
+ nn.Linear(pose_dim + action_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.ReLU()
+ )
+
+ self.fc_out = nn.Sequential(
+ nn.Linear(img_feat_dim + hidden_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, pose_dim)
+ )
+
+ def forward(self, img, action, pose):
+ img_feat = self.cnn(img).view(img.size(0), -1)
+ img_feat = self.fc_img(img_feat)
+
+ motion_feat = self.mlp_motion(torch.cat([pose, action], dim=1))
+ fused_feat = torch.cat([img_feat, motion_feat], dim=1)
+ pose_next_pred = self.fc_out(fused_feat)
+
+ return pose_next_pred
\ No newline at end of file
diff --git a/algorithms/worldmem/models/rotary_embedding_torch.py b/algorithms/worldmem/models/rotary_embedding_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9af591e49e8cc367e2789e939e889b210e48c7c
--- /dev/null
+++ b/algorithms/worldmem/models/rotary_embedding_torch.py
@@ -0,0 +1,302 @@
+"""
+Adapted from https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
+"""
+
+from __future__ import annotations
+from math import pi, log
+
+import torch
+from torch.nn import Module, ModuleList
+from torch.amp import autocast
+from torch import nn, einsum, broadcast_tensors, Tensor
+
+from einops import rearrange, repeat
+
+from typing import Literal
+
+# helper functions
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+# broadcat, as tortoise-tts was using it
+
+
+def broadcat(tensors, dim=-1):
+ broadcasted_tensors = broadcast_tensors(*tensors)
+ return torch.cat(broadcasted_tensors, dim=dim)
+
+
+# rotary embedding helper functions
+
+
+def rotate_half(x):
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return rearrange(x, "... d r -> ... (d r)")
+
+
+@autocast("cuda", enabled=False)
+def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
+ dtype = t.dtype
+
+ if t.ndim == 3:
+ seq_len = t.shape[seq_dim]
+ freqs = freqs[-seq_len:]
+
+ rot_dim = freqs.shape[-1]
+ end_index = start_index + rot_dim
+
+ assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+
+ # Split t into three parts: left, middle (to be transformed), and right
+ t_left = t[..., :start_index]
+ t_middle = t[..., start_index:end_index]
+ t_right = t[..., end_index:]
+
+ # Apply rotary embeddings without modifying t in place
+ t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
+
+ out = torch.cat((t_left, t_transformed, t_right), dim=-1)
+
+ return out.type(dtype)
+
+
+# learned rotation helpers
+
+
+def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
+ if exists(freq_ranges):
+ rotations = einsum("..., f -> ... f", rotations, freq_ranges)
+ rotations = rearrange(rotations, "... r f -> ... (r f)")
+
+ rotations = repeat(rotations, "... n -> ... (n r)", r=2)
+ return apply_rotary_emb(rotations, t, start_index=start_index)
+
+
+# classes
+
+
+class RotaryEmbedding(Module):
+ def __init__(
+ self,
+ dim,
+ custom_freqs: Tensor | None = None,
+ freqs_for: Literal["lang", "pixel", "constant"] = "lang",
+ theta=10000,
+ max_freq=10,
+ num_freqs=1,
+ learned_freq=False,
+ use_xpos=False,
+ xpos_scale_base=512,
+ interpolate_factor=1.0,
+ theta_rescale_factor=1.0,
+ seq_before_head_dim=False,
+ cache_if_possible=True,
+ cache_max_seq_len=8192,
+ ):
+ super().__init__()
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+ # has some connection to NTK literature
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
+
+ self.freqs_for = freqs_for
+
+ if exists(custom_freqs):
+ freqs = custom_freqs
+ elif freqs_for == "lang":
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ elif freqs_for == "pixel":
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+ elif freqs_for == "spacetime":
+ time_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+ elif freqs_for == "constant":
+ freqs = torch.ones(num_freqs).float()
+
+ if freqs_for == "spacetime":
+ self.time_freqs = nn.Parameter(time_freqs, requires_grad=learned_freq)
+ self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
+
+ self.cache_if_possible = cache_if_possible
+ self.cache_max_seq_len = cache_max_seq_len
+
+ self.register_buffer("cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False)
+ self.register_buffer("cached_freqs_seq_len", torch.tensor(0), persistent=False)
+
+ self.learned_freq = learned_freq
+
+ # dummy for device
+
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
+
+ # default sequence dimension
+
+ self.seq_before_head_dim = seq_before_head_dim
+ self.default_seq_dim = -3 if seq_before_head_dim else -2
+
+ # interpolation factors
+
+ assert interpolate_factor >= 1.0
+ self.interpolate_factor = interpolate_factor
+
+ # xpos
+
+ self.use_xpos = use_xpos
+
+ if not use_xpos:
+ return
+
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+ self.scale_base = xpos_scale_base
+
+ self.register_buffer("scale", scale, persistent=False)
+ self.register_buffer("cached_scales", torch.zeros(cache_max_seq_len, dim), persistent=False)
+ self.register_buffer("cached_scales_seq_len", torch.tensor(0), persistent=False)
+
+ # add apply_rotary_emb as static method
+
+ self.apply_rotary_emb = staticmethod(apply_rotary_emb)
+
+ @property
+ def device(self):
+ return self.dummy.device
+
+ def get_seq_pos(self, seq_len, device, dtype, offset=0):
+ return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
+
+ def rotate_queries_or_keys(self, t, freqs, seq_dim=None, offset=0, scale=None):
+ seq_dim = default(seq_dim, self.default_seq_dim)
+
+ assert not self.use_xpos or exists(scale), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
+
+ device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
+
+ seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)
+
+ seq_freqs = self.forward(seq, freqs, seq_len=seq_len, offset=offset)
+
+ if seq_dim == -3:
+ seq_freqs = rearrange(seq_freqs, "n d -> n 1 d")
+
+ return apply_rotary_emb(seq_freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)
+
+ def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
+ dtype, device, seq_dim = (
+ q.dtype,
+ q.device,
+ default(seq_dim, self.default_seq_dim),
+ )
+
+ q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
+ assert q_len <= k_len
+
+ q_scale = k_scale = 1.0
+
+ if self.use_xpos:
+ seq = self.get_seq_pos(k_len, dtype=dtype, device=device)
+
+ q_scale = self.get_scale(seq[-q_len:]).type(dtype)
+ k_scale = self.get_scale(seq).type(dtype)
+
+ rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset)
+ rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale**-1)
+
+ rotated_q = rotated_q.type(q.dtype)
+ rotated_k = rotated_k.type(k.dtype)
+
+ return rotated_q, rotated_k
+
+ def rotate_queries_and_keys(self, q, k, freqs, seq_dim=None):
+ seq_dim = default(seq_dim, self.default_seq_dim)
+
+ assert self.use_xpos
+ device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
+
+ seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
+
+ seq_freqs = self.forward(seq, freqs, seq_len=seq_len)
+ scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
+
+ if seq_dim == -3:
+ seq_freqs = rearrange(seq_freqs, "n d -> n 1 d")
+ scale = rearrange(scale, "n d -> n 1 d")
+
+ rotated_q = apply_rotary_emb(seq_freqs, q, scale=scale, seq_dim=seq_dim)
+ rotated_k = apply_rotary_emb(seq_freqs, k, scale=scale**-1, seq_dim=seq_dim)
+
+ rotated_q = rotated_q.type(q.dtype)
+ rotated_k = rotated_k.type(k.dtype)
+
+ return rotated_q, rotated_k
+
+ def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
+ assert self.use_xpos
+
+ should_cache = self.cache_if_possible and exists(seq_len) and (offset + seq_len) <= self.cache_max_seq_len
+
+ if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales_seq_len.item():
+ return self.cached_scales[offset : (offset + seq_len)]
+
+ scale = 1.0
+ if self.use_xpos:
+ power = (t - len(t) // 2) / self.scale_base
+ scale = self.scale ** rearrange(power, "n -> n 1")
+ scale = repeat(scale, "n d -> n (d r)", r=2)
+
+ if should_cache and offset == 0:
+ self.cached_scales[:seq_len] = scale.detach()
+ self.cached_scales_seq_len.copy_(seq_len)
+
+ return scale
+
+ def get_axial_freqs(self, *dims):
+ Colon = slice(None)
+ all_freqs = []
+
+ for ind, dim in enumerate(dims):
+ # only allow pixel freqs for last two dimensions
+ use_pixel = (self.freqs_for == "pixel" or self.freqs_for == "spacetime") and ind >= len(dims) - 2
+ if use_pixel:
+ pos = torch.linspace(-1, 1, steps=dim, device=self.device)
+ else:
+ pos = torch.arange(dim, device=self.device)
+
+ if self.freqs_for == "spacetime" and not use_pixel:
+ seq_freqs = self.forward(pos, self.time_freqs, seq_len=dim)
+ else:
+ seq_freqs = self.forward(pos, self.freqs, seq_len=dim)
+
+ all_axis = [None] * len(dims)
+ all_axis[ind] = Colon
+
+ new_axis_slice = (Ellipsis, *all_axis, Colon)
+ all_freqs.append(seq_freqs[new_axis_slice])
+
+ all_freqs = broadcast_tensors(*all_freqs)
+ return torch.cat(all_freqs, dim=-1)
+
+ @autocast("cuda", enabled=False)
+ def forward(self, t: Tensor, freqs: Tensor, seq_len=None, offset=0):
+ should_cache = self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel" and (offset + seq_len) <= self.cache_max_seq_len
+
+ if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs_seq_len.item():
+ return self.cached_freqs[offset : (offset + seq_len)].detach()
+
+ freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+
+ if should_cache and offset == 0:
+ self.cached_freqs[:seq_len] = freqs.detach()
+ self.cached_freqs_seq_len.copy_(seq_len)
+
+ return freqs
diff --git a/algorithms/worldmem/models/utils.py b/algorithms/worldmem/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..41e6f8a9801649421fe5be46879dc74f5db09a12
--- /dev/null
+++ b/algorithms/worldmem/models/utils.py
@@ -0,0 +1,163 @@
+"""
+Adapted from https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/utils.py
+Action format derived from VPT https://github.com/openai/Video-Pre-Training
+Adapted from https://github.com/etched-ai/open-oasis/blob/master/utils.py
+"""
+
+import math
+import torch
+from torch import nn
+from torchvision.io import read_image, read_video
+from torchvision.transforms.functional import resize
+from einops import rearrange
+from typing import Mapping, Sequence
+from einops import rearrange, parse_shape
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if callable(d) else d
+
+
+def extract(a, t, x_shape):
+ f, b = t.shape
+ out = a[t]
+ return out.reshape(f, b, *((1,) * (len(x_shape) - 2)))
+
+
+def linear_beta_schedule(timesteps):
+ """
+ linear schedule, proposed in original ddpm paper
+ """
+ scale = 1000 / timesteps
+ beta_start = scale * 0.0001
+ beta_end = scale * 0.02
+ return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
+
+
+def cosine_beta_schedule(timesteps, s=0.008):
+ """
+ cosine schedule
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+ """
+ steps = timesteps + 1
+ t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
+ alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+ return torch.clip(betas, 0, 0.999)
+
+
+
+def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
+ """
+ sigmoid schedule
+ proposed in https://arxiv.org/abs/2212.11972 - Figure 8
+ better for images > 64x64, when used during training
+ """
+ steps = timesteps + 1
+ t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
+ v_start = torch.tensor(start / tau).sigmoid()
+ v_end = torch.tensor(end / tau).sigmoid()
+ alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+ return torch.clip(betas, 0, 0.999)
+
+
+ACTION_KEYS = [
+ "inventory",
+ "ESC",
+ "hotbar.1",
+ "hotbar.2",
+ "hotbar.3",
+ "hotbar.4",
+ "hotbar.5",
+ "hotbar.6",
+ "hotbar.7",
+ "hotbar.8",
+ "hotbar.9",
+ "forward",
+ "back",
+ "left",
+ "right",
+ "cameraX",
+ "cameraY",
+ "jump",
+ "sneak",
+ "sprint",
+ "swapHands",
+ "attack",
+ "use",
+ "pickItem",
+ "drop",
+]
+
+
+def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
+ actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
+ for i, current_actions in enumerate(actions):
+ for j, action_key in enumerate(ACTION_KEYS):
+ if action_key.startswith("camera"):
+ if action_key == "cameraX":
+ value = current_actions["camera"][0]
+ elif action_key == "cameraY":
+ value = current_actions["camera"][1]
+ else:
+ raise ValueError(f"Unknown camera action key: {action_key}")
+ max_val = 20
+ bin_size = 0.5
+ num_buckets = int(max_val / bin_size)
+ value = (value - num_buckets) / num_buckets
+ assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
+ else:
+ value = current_actions[action_key]
+ assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
+ actions_one_hot[i, j] = value
+
+ return actions_one_hot
+
+
+IMAGE_EXTENSIONS = {"png", "jpg", "jpeg"}
+VIDEO_EXTENSIONS = {"mp4"}
+
+
+def load_prompt(path, video_offset=None, n_prompt_frames=1):
+ if path.lower().split(".")[-1] in IMAGE_EXTENSIONS:
+ print("prompt is image; ignoring video_offset and n_prompt_frames")
+ prompt = read_image(path)
+ # add frame dimension
+ prompt = rearrange(prompt, "c h w -> 1 c h w")
+ elif path.lower().split(".")[-1] in VIDEO_EXTENSIONS:
+ prompt = read_video(path, pts_unit="sec")[0]
+ if video_offset is not None:
+ prompt = prompt[video_offset:]
+ prompt = prompt[:n_prompt_frames]
+ else:
+ raise ValueError(f"unrecognized prompt file extension; expected one in {IMAGE_EXTENSIONS} or {VIDEO_EXTENSIONS}")
+ assert prompt.shape[0] == n_prompt_frames, f"input prompt {path} had less than n_prompt_frames={n_prompt_frames} frames"
+ prompt = resize(prompt, (360, 640))
+ # add batch dimension
+ prompt = rearrange(prompt, "t c h w -> 1 t c h w")
+ prompt = prompt.float() / 255.0
+ return prompt
+
+
+def load_actions(path, action_offset=None):
+ if path.endswith(".actions.pt"):
+ actions = one_hot_actions(torch.load(path))
+ elif path.endswith(".one_hot_actions.pt"):
+ actions = torch.load(path, weights_only=True)
+ else:
+ raise ValueError("unrecognized action file extension; expected '*.actions.pt' or '*.one_hot_actions.pt'")
+ if action_offset is not None:
+ actions = actions[action_offset:]
+ actions = torch.cat([torch.zeros_like(actions[:1]), actions], dim=0)
+ # add batch dimension
+ actions = rearrange(actions, "t d -> 1 t d")
+ return actions
diff --git a/algorithms/worldmem/models/vae.py b/algorithms/worldmem/models/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cad52b41fd533d4ecdcc964a106c03b170cea64
--- /dev/null
+++ b/algorithms/worldmem/models/vae.py
@@ -0,0 +1,359 @@
+"""
+References:
+ - VQGAN: https://github.com/CompVis/taming-transformers
+ - MAE: https://github.com/facebookresearch/mae
+"""
+
+import numpy as np
+import math
+import functools
+from collections import namedtuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from timm.models.vision_transformer import Mlp
+from timm.layers.helpers import to_2tuple
+from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
+from .dit import PatchEmbed
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False, dim=1):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
+ if dim == 1:
+ self.dims = [1, 2, 3]
+ elif dim == 2:
+ self.dims = [1, 2]
+ else:
+ raise NotImplementedError
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def mode(self):
+ return self.mean
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ frame_height,
+ frame_width,
+ qkv_bias=False,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.frame_height = frame_height
+ self.frame_width = frame_width
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ rotary_freqs = RotaryEmbedding(
+ dim=head_dim // 4,
+ freqs_for="pixel",
+ max_freq=frame_height * frame_width,
+ ).get_axial_freqs(frame_height, frame_width)
+ self.register_buffer("rotary_freqs", rotary_freqs, persistent=False)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ assert N == self.frame_height * self.frame_width
+
+ q, k, v = self.qkv(x).chunk(3, dim=-1)
+
+ q = rearrange(
+ q,
+ "b (H W) (h d) -> b h H W d",
+ H=self.frame_height,
+ W=self.frame_width,
+ h=self.num_heads,
+ )
+ k = rearrange(
+ k,
+ "b (H W) (h d) -> b h H W d",
+ H=self.frame_height,
+ W=self.frame_width,
+ h=self.num_heads,
+ )
+ v = rearrange(
+ v,
+ "b (H W) (h d) -> b h H W d",
+ H=self.frame_height,
+ W=self.frame_width,
+ h=self.num_heads,
+ )
+
+ q = apply_rotary_emb(self.rotary_freqs, q)
+ k = apply_rotary_emb(self.rotary_freqs, k)
+
+ q = rearrange(q, "b h H W d -> b h (H W) d")
+ k = rearrange(k, "b h H W d -> b h (H W) d")
+ v = rearrange(v, "b h H W d -> b h (H W) d")
+
+ x = F.scaled_dot_product_attention(q, k, v)
+ x = rearrange(x, "b h N d -> b N (h d)")
+
+ x = self.proj(x)
+ return x
+
+
+class AttentionBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ frame_height,
+ frame_width,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ attn_causal=False,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads,
+ frame_height,
+ frame_width,
+ qkv_bias=qkv_bias,
+ )
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ )
+
+ def forward(self, x):
+ x = x + self.attn(self.norm1(x))
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class AutoencoderKL(nn.Module):
+ def __init__(
+ self,
+ latent_dim,
+ input_height=256,
+ input_width=256,
+ patch_size=16,
+ enc_dim=768,
+ enc_depth=6,
+ enc_heads=12,
+ dec_dim=768,
+ dec_depth=6,
+ dec_heads=12,
+ mlp_ratio=4.0,
+ norm_layer=functools.partial(nn.LayerNorm, eps=1e-6),
+ use_variational=True,
+ **kwargs,
+ ):
+ super().__init__()
+ self.input_height = input_height
+ self.input_width = input_width
+ self.patch_size = patch_size
+ self.seq_h = input_height // patch_size
+ self.seq_w = input_width // patch_size
+ self.seq_len = self.seq_h * self.seq_w
+ self.patch_dim = 3 * patch_size**2
+
+ self.latent_dim = latent_dim
+ self.enc_dim = enc_dim
+ self.dec_dim = dec_dim
+
+ # patch
+ self.patch_embed = PatchEmbed(input_height, input_width, patch_size, 3, enc_dim)
+
+ # encoder
+ self.encoder = nn.ModuleList(
+ [
+ AttentionBlock(
+ enc_dim,
+ enc_heads,
+ self.seq_h,
+ self.seq_w,
+ mlp_ratio,
+ qkv_bias=True,
+ norm_layer=norm_layer,
+ )
+ for i in range(enc_depth)
+ ]
+ )
+ self.enc_norm = norm_layer(enc_dim)
+
+ # bottleneck
+ self.use_variational = use_variational
+ mult = 2 if self.use_variational else 1
+ self.quant_conv = nn.Linear(enc_dim, mult * latent_dim)
+ self.post_quant_conv = nn.Linear(latent_dim, dec_dim)
+
+ # decoder
+ self.decoder = nn.ModuleList(
+ [
+ AttentionBlock(
+ dec_dim,
+ dec_heads,
+ self.seq_h,
+ self.seq_w,
+ mlp_ratio,
+ qkv_bias=True,
+ norm_layer=norm_layer,
+ )
+ for i in range(dec_depth)
+ ]
+ )
+ self.dec_norm = norm_layer(dec_dim)
+ self.predictor = nn.Linear(dec_dim, self.patch_dim) # decoder to patch
+
+ # initialize this weight first
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # initialization
+ # initialize nn.Linear and nn.LayerNorm
+ self.apply(self._init_weights)
+
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+ w = self.patch_embed.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ # we use xavier_uniform following official JAX ViT:
+ nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0.0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0.0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def patchify(self, x):
+ # patchify
+ bsz, _, h, w = x.shape
+ x = x.reshape(
+ bsz,
+ 3,
+ self.seq_h,
+ self.patch_size,
+ self.seq_w,
+ self.patch_size,
+ ).permute([0, 1, 3, 5, 2, 4]) # [b, c, h, p, w, p] --> [b, c, p, p, h, w]
+ x = x.reshape(bsz, self.patch_dim, self.seq_h, self.seq_w) # --> [b, cxpxp, h, w]
+ x = x.permute([0, 2, 3, 1]).reshape(bsz, self.seq_len, self.patch_dim) # --> [b, hxw, cxpxp]
+ return x
+
+ def unpatchify(self, x):
+ bsz = x.shape[0]
+ # unpatchify
+ x = x.reshape(bsz, self.seq_h, self.seq_w, self.patch_dim).permute([0, 3, 1, 2]) # [b, h, w, cxpxp] --> [b, cxpxp, h, w]
+ x = x.reshape(
+ bsz,
+ 3,
+ self.patch_size,
+ self.patch_size,
+ self.seq_h,
+ self.seq_w,
+ ).permute([0, 1, 4, 2, 5, 3]) # [b, c, p, p, h, w] --> [b, c, h, p, w, p]
+ x = x.reshape(
+ bsz,
+ 3,
+ self.input_height,
+ self.input_width,
+ ) # [b, c, hxp, wxp]
+ return x
+
+ def encode(self, x):
+ # patchify
+ x = self.patch_embed(x)
+
+ # encoder
+ for blk in self.encoder:
+ x = blk(x)
+ x = self.enc_norm(x)
+
+ # bottleneck
+ moments = self.quant_conv(x)
+ if not self.use_variational:
+ moments = torch.cat((moments, torch.zeros_like(moments)), 2)
+ posterior = DiagonalGaussianDistribution(moments, deterministic=(not self.use_variational), dim=2)
+ return posterior
+
+ def decode(self, z):
+ # bottleneck
+ z = self.post_quant_conv(z)
+
+ # decoder
+ for blk in self.decoder:
+ z = blk(z)
+ z = self.dec_norm(z)
+
+ # predictor
+ z = self.predictor(z)
+
+ # unpatchify
+ dec = self.unpatchify(z)
+ return dec
+
+ def autoencode(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if self.use_variational and sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior, z
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def forward(self, inputs, labels, split="train"):
+ rec, post, latent = self.autoencode(inputs)
+ return rec, post, latent
+
+ def get_last_layer(self):
+ return self.predictor.weight
+
+
+def ViT_L_20_Shallow_Encoder(**kwargs):
+ if "latent_dim" in kwargs:
+ latent_dim = kwargs.pop("latent_dim")
+ else:
+ latent_dim = 16
+ return AutoencoderKL(
+ latent_dim=latent_dim,
+ patch_size=20,
+ enc_dim=1024,
+ enc_depth=6,
+ enc_heads=16,
+ dec_dim=1024,
+ dec_depth=12,
+ dec_heads=16,
+ input_height=360,
+ input_width=640,
+ **kwargs,
+ )
+
+
+VAE_models = {
+ "vit-l-20-shallow-encoder": ViT_L_20_Shallow_Encoder,
+}
diff --git a/algorithms/worldmem/pose_prediction.py b/algorithms/worldmem/pose_prediction.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd6148dc7572f1f74a2cbf12293ad698111529a3
--- /dev/null
+++ b/algorithms/worldmem/pose_prediction.py
@@ -0,0 +1,374 @@
+from omegaconf import DictConfig
+import torch
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from algorithms.common.metrics import (
+ FrechetInceptionDistance,
+ LearnedPerceptualImagePatchSimilarity,
+ FrechetVideoDistance,
+)
+from .df_base import DiffusionForcingBase
+from utils.logging_utils import log_video, get_validation_metrics_for_videos
+from .models.vae import VAE_models
+from .models.dit import DiT_models
+from einops import rearrange
+from torch import autocast
+import numpy as np
+from tqdm import tqdm
+import torch.nn.functional as F
+from .models.pose_prediction import PosePredictionNet
+import torchvision.transforms.functional as TF
+import random
+from torchvision.transforms import InterpolationMode
+from PIL import Image
+import math
+from packaging import version as pver
+import torch.distributed as dist
+import matplotlib.pyplot as plt
+
+import torch
+import math
+import wandb
+
+import torch.nn as nn
+from algorithms.common.base_pytorch_algo import BasePytorchAlgo
+
+class PosePrediction(BasePytorchAlgo):
+
+ def __init__(self, cfg: DictConfig):
+
+ super().__init__(cfg)
+
+ def _build_model(self):
+ self.pose_prediction_model = PosePredictionNet()
+ vae = VAE_models["vit-l-20-shallow-encoder"]()
+ self.vae = vae.eval()
+
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+ xs, conditions, pose_conditions= batch
+ pose_conditions[:,:,3:] = pose_conditions[:,:,3:] // 15
+ xs = self.encode(xs)
+
+ b,f,c,h,w = xs.shape
+ xs = xs[:,:-1].reshape(-1, c, h, w)
+ conditions = conditions[:,1:].reshape(-1, 25)
+ offset_gt = pose_conditions[:,1:] - pose_conditions[:,:-1]
+ pose_conditions = pose_conditions[:,:-1].reshape(-1, 5)
+ offset_gt = offset_gt.reshape(-1, 5)
+ offset_gt[:, 3][offset_gt[:, 3]==23] = -1
+ offset_gt[:, 3][offset_gt[:, 3]==-23] = 1
+ offset_gt[:, 4][offset_gt[:, 4]==23] = -1
+ offset_gt[:, 4][offset_gt[:, 4]==-23] = 1
+
+ offset_pred = self.pose_prediction_model(xs, conditions, pose_conditions)
+ criterion = nn.MSELoss()
+ loss = criterion(offset_pred, offset_gt)
+ if batch_idx % 200 == 0:
+ self.log("training/loss", loss.cpu())
+ output_dict = {
+ "loss": loss}
+ return output_dict
+
+ def encode(self, x):
+ # vae encoding
+ B = x.shape[1]
+ T = x.shape[0]
+ H, W = x.shape[-2:]
+ scaling_factor = 0.07843137255
+
+ x = rearrange(x, "t b c h w -> (t b) c h w")
+ with torch.no_grad():
+ with autocast("cuda", dtype=torch.half):
+ x = self.vae.encode(x * 2 - 1).mean * scaling_factor
+ x = rearrange(x, "(t b) (h w) c -> t b c h w", t=T, h=H // self.vae.patch_size, w=W // self.vae.patch_size)
+ # x = x[:, :n_prompt_frames]
+ return x
+
+ def decode(self, x):
+ total_frames = x.shape[0]
+ scaling_factor = 0.07843137255
+ x = rearrange(x, "t b c h w -> (t b) (h w) c")
+ with torch.no_grad():
+ with autocast("cuda", dtype=torch.half):
+ x = (self.vae.decode(x / scaling_factor) + 1) / 2
+
+ x = rearrange(x, "(t b) c h w-> t b c h w", t=total_frames)
+ return x
+
+ def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
+ xs, conditions, pose_conditions= batch
+ pose_conditions[:,:,3:] = pose_conditions[:,:,3:] // 15
+ xs = self.encode(xs)
+
+ b,f,c,h,w = xs.shape
+ xs = xs[:,:-1].reshape(-1, c, h, w)
+ conditions = conditions[:,1:].reshape(-1, 25)
+ offset_gt = pose_conditions[:,1:] - pose_conditions[:,:-1]
+ pose_conditions = pose_conditions[:,:-1].reshape(-1, 5)
+ offset_gt = offset_gt.reshape(-1, 5)
+ offset_gt[:, 3][offset_gt[:, 3]==23] = -1
+ offset_gt[:, 3][offset_gt[:, 3]==-23] = 1
+ offset_gt[:, 4][offset_gt[:, 4]==23] = -1
+ offset_gt[:, 4][offset_gt[:, 4]==-23] = 1
+
+ offset_pred = self.pose_prediction_model(xs, conditions, pose_conditions)
+
+ criterion = nn.MSELoss()
+ loss = criterion(offset_pred, offset_gt)
+
+ if batch_idx % 200 == 0:
+ self.log("validation/loss", loss.cpu())
+ output_dict = {
+ "loss": loss}
+ return
+
+ @torch.no_grad()
+ def interactive(self, batch, context_frames, device):
+ with torch.cuda.amp.autocast():
+ condition_similar_length = self.condition_similar_length
+ # xs_raw, conditions, pose_conditions, c2w_mat, masks, frame_idx = self._preprocess_batch(batch)
+
+ first_frame, new_conditions, new_pose_conditions, new_c2w_mat, new_frame_idx = batch
+
+ if self.frames is None:
+ first_frame_encode = self.encode(first_frame[None, None].to(device))
+ self.frames = first_frame_encode.to(device)
+ self.actions = new_conditions[None, None].to(device)
+ self.poses = new_pose_conditions[None, None].to(device)
+ self.memory_c2w = new_c2w_mat[None, None].to(device)
+ self.frame_idx = torch.tensor([[new_frame_idx]]).to(device)
+ return first_frame
+ else:
+ self.actions = torch.cat([self.actions, new_conditions[None, None].to(device)])
+ self.poses = torch.cat([self.poses, new_pose_conditions[None, None].to(device)])
+ self.memory_c2w = torch.cat([self.memory_c2w, new_c2w_mat[None, None].to(device)])
+ self.frame_idx = torch.cat([self.frame_idx, torch.tensor([[new_frame_idx]]).to(device)])
+
+ conditions = self.actions.clone()
+ pose_conditions = self.poses.clone()
+ c2w_mat = self.memory_c2w .clone()
+ frame_idx = self.frame_idx.clone()
+
+
+ curr_frame = 0
+ horizon = 1
+ batch_size = 1
+ n_frames = curr_frame + horizon
+ # context
+ n_context_frames = context_frames // self.frame_stack
+ xs_pred = self.frames[:n_context_frames].clone()
+ curr_frame += n_context_frames
+
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
+
+ # generation on frame
+ scheduling_matrix = self._generate_scheduling_matrix(horizon)
+ chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:])).to(xs_pred.device)
+ chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
+
+ xs_pred = torch.cat([xs_pred, chunk], 0)
+
+ # sliding window: only input the last n_tokens frames
+ start_frame = max(0, curr_frame + horizon - self.n_tokens)
+
+ pbar.set_postfix(
+ {
+ "start": start_frame,
+ "end": curr_frame + horizon,
+ }
+ )
+
+ if condition_similar_length:
+
+ if curr_frame < condition_similar_length:
+ random_idx = [i for i in range(curr_frame)] + [0] * (condition_similar_length-curr_frame)
+ random_idx = np.repeat(np.array(random_idx)[:,None], xs_pred.shape[1], -1)
+ else:
+ num_samples = 10000
+ radius = 30
+ samples = torch.rand((num_samples, 1), device=pose_conditions.device)
+ angles = 2 * np.pi * torch.rand((num_samples,), device=pose_conditions.device)
+ # points = radius * torch.sqrt(samples) * torch.stack((torch.cos(angles), torch.sin(angles)), dim=1)
+
+ points = generate_points_in_sphere(num_samples, radius).to(pose_conditions.device)
+ points = points[:, None].repeat(1, pose_conditions.shape[1], 1)
+ points += pose_conditions[curr_frame, :, :3][None]
+ fov_half_h = torch.tensor(105/2, device=pose_conditions.device)
+ fov_half_v = torch.tensor(75/2, device=pose_conditions.device)
+ # in_fov1 = is_inside_fov(points, pose_conditions[curr_frame, :, [0, 2]], pose_conditions[curr_frame, :, -1], fov_half)
+
+ in_fov1 = is_inside_fov_3d_hv(points, pose_conditions[curr_frame, :, :3],
+ pose_conditions[curr_frame, :, -2], pose_conditions[curr_frame, :, -1],
+ fov_half_h, fov_half_v)
+
+ in_fov_list = []
+ for pc in pose_conditions[:curr_frame]:
+ in_fov_list.append(is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1],
+ fov_half_h, fov_half_v))
+
+ in_fov_list = torch.stack(in_fov_list)
+ # v3
+ random_idx = []
+
+ for csl in range(self.condition_similar_length // 2):
+ overlap_ratio = ((in_fov1[None].bool() & in_fov_list).sum(1))/in_fov1.sum()
+ # mask = distance > (in_fov1.bool().sum(0) / 4)
+ #_, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
+
+ # if csl > self.condition_similar_length:
+ # _, r_idx = torch.topk(overlap_ratio, k=1, dim=0)
+ # else:
+ # _, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
+
+ _, r_idx = torch.topk(overlap_ratio, k=1, dim=0)
+ # _, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
+
+ # if curr_frame >=93:
+ # import pdb;pdb.set_trace()
+
+ # start_time = time.time()
+ cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
+ range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
+ cos_sim = cos_sim.mean((-2,-1))
+
+ mask_sim = cos_sim>0.9
+ in_fov_list = in_fov_list & ~mask_sim[:,None].to(in_fov_list.device)
+
+ random_idx.append(r_idx)
+
+ for bi in range(conditions.shape[1]):
+ if len(torch.nonzero(conditions[:,bi,24] == 1))==0:
+ pass
+ else:
+ last_idx = torch.nonzero(conditions[:,bi,24] == 1)[-1]
+ in_fov_list[:last_idx,:,bi] = False
+
+ for csl in range(self.condition_similar_length // 2):
+ overlap_ratio = ((in_fov1[None].bool() & in_fov_list).sum(1))/in_fov1.sum()
+ # mask = distance > (in_fov1.bool().sum(0) / 4)
+ #_, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
+
+ # if csl > self.condition_similar_length:
+ # _, r_idx = torch.topk(overlap_ratio, k=1, dim=0)
+ # else:
+ # _, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
+
+ _, r_idx = torch.topk(overlap_ratio, k=1, dim=0)
+ # _, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
+
+ # if curr_frame >=93:
+ # import pdb;pdb.set_trace()
+
+ # start_time = time.time()
+ cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
+ range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
+ cos_sim = cos_sim.mean((-2,-1))
+
+ mask_sim = cos_sim>0.9
+ in_fov_list = in_fov_list & ~mask_sim[:,None].to(in_fov_list.device)
+
+ random_idx.append(r_idx)
+
+ random_idx = torch.cat(random_idx).cpu()
+ condition_similar_length = len(random_idx)
+
+ xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
+
+ if condition_similar_length:
+ # import pdb;pdb.set_trace()
+ padding = torch.zeros((condition_similar_length,) + conditions.shape[1:], device=conditions.device, dtype=conditions.dtype)
+ input_condition = torch.cat([conditions[start_frame : curr_frame + horizon], padding], dim=0)
+ if self.pose_cond_dim:
+ # if not self.use_plucker:
+ input_pose_condition = torch.cat([pose_conditions[start_frame : curr_frame + horizon], pose_conditions[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]], dim=0).clone()
+
+ if self.use_plucker:
+ if self.all_zero_frame:
+ frame_idx_list = []
+ input_pose_condition = []
+ for i in range(start_frame, curr_frame + horizon):
+ input_pose_condition.append(convert_to_plucker(torch.cat([c2w_mat[i:i+1],c2w_mat[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]]).clone(), 0, focal_length=self.focal_length, is_old_setting=self.old_setting).to(xs_pred.dtype))
+ frame_idx_list.append(torch.cat([frame_idx[i:i+1]-frame_idx[i:i+1], frame_idx[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]-frame_idx[i:i+1]]))
+ input_pose_condition = torch.cat(input_pose_condition)
+ frame_idx_list = torch.cat(frame_idx_list)
+
+ # print(frame_idx_list[:,0])
+ else:
+ # print(curr_frame-start_frame)
+ # input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]], dim=0).clone()
+ # import pdb;pdb.set_trace()
+ if self.last_frame_refer:
+ input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[-1:]], dim=0).clone()
+ else:
+ input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]], dim=0).clone()
+
+ if self.zero_curr:
+ # print("="*50)
+ input_pose_condition = convert_to_plucker(input_pose_condition, curr_frame-start_frame, focal_length=self.focal_length, is_old_setting=self.old_setting)
+ # input_pose_condition[:curr_frame-start_frame] = input_pose_condition[curr_frame-start_frame:curr_frame-start_frame+1]
+ # input_pose_condition = convert_to_plucker(input_pose_condition, -self.condition_similar_length-1, focal_length=self.focal_length)
+ else:
+ input_pose_condition = convert_to_plucker(input_pose_condition, -condition_similar_length, focal_length=self.focal_length, is_old_setting=self.old_setting)
+ frame_idx_list = None
+ else:
+ input_pose_condition = torch.cat([pose_conditions[start_frame : curr_frame + horizon], pose_conditions[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]], dim=0).clone()
+ frame_idx_list = None
+ else:
+ input_condition = conditions[start_frame : curr_frame + horizon]
+ input_pose_condition = None
+ frame_idx_list = None
+
+ for m in range(scheduling_matrix.shape[0] - 1):
+ from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[
+ :, None
+ ].repeat(batch_size, axis=1)
+ to_noise_levels = np.concatenate(
+ (
+ np.zeros((curr_frame,), dtype=np.int64),
+ scheduling_matrix[m + 1],
+ )
+ )[
+ :, None
+ ].repeat(batch_size, axis=1)
+
+ if condition_similar_length:
+ from_noise_levels = np.concatenate([from_noise_levels, np.zeros((condition_similar_length,from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
+ to_noise_levels = np.concatenate([to_noise_levels, np.zeros((condition_similar_length,from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
+
+ from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
+ to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
+
+
+ if input_pose_condition is not None:
+ input_pose_condition = input_pose_condition.to(xs_pred.dtype)
+
+ xs_pred[start_frame:] = self.diffusion_model.sample_step(
+ xs_pred[start_frame:],
+ input_condition,
+ input_pose_condition,
+ from_noise_levels[start_frame:],
+ to_noise_levels[start_frame:],
+ current_frame=curr_frame,
+ mode="validation",
+ reference_length=condition_similar_length,
+ frame_idx=frame_idx_list
+ )
+
+ # if curr_frame > 14:
+ # import pdb;pdb.set_trace()
+
+ # if xs_pred_back is not None:
+ # xs_pred = torch.cat([xs_pred[:6], xs_pred_back[6:12], xs_pred[6:]], dim=0)
+
+ # import pdb;pdb.set_trace()
+ if condition_similar_length: # and curr_frame+1!=n_frames:
+ xs_pred = xs_pred[:-condition_similar_length]
+
+ curr_frame += horizon
+ pbar.update(horizon)
+
+ self.frames = torch.cat([self.frames, xs_pred[n_context_frames:]])
+
+ xs_pred = self.decode(xs_pred[n_context_frames:])
+
+ return xs_pred[-1,0].cpu()
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..bed324196f84af33576aeeaac19d41c4a19e6f9f
--- /dev/null
+++ b/app.py
@@ -0,0 +1,365 @@
+import gradio as gr
+import time
+
+import sys
+import subprocess
+import time
+from pathlib import Path
+
+import hydra
+from omegaconf import DictConfig, OmegaConf
+from omegaconf.omegaconf import open_dict
+
+from utils.print_utils import cyan
+from utils.ckpt_utils import download_latest_checkpoint, is_run_id
+from utils.cluster_utils import submit_slurm_job
+from utils.distributed_utils import is_rank_zero
+import numpy as np
+import torch
+from datasets.video.minecraft_video_dataset import *
+import torchvision.transforms as transforms
+import cv2
+import subprocess
+from PIL import Image
+from datetime import datetime
+
+ACTION_KEYS = [
+ "inventory",
+ "ESC",
+ "hotbar.1",
+ "hotbar.2",
+ "hotbar.3",
+ "hotbar.4",
+ "hotbar.5",
+ "hotbar.6",
+ "hotbar.7",
+ "hotbar.8",
+ "hotbar.9",
+ "forward",
+ "back",
+ "left",
+ "right",
+ "cameraY",
+ "cameraX",
+ "jump",
+ "sneak",
+ "sprint",
+ "swapHands",
+ "attack",
+ "use",
+ "pickItem",
+ "drop",
+]
+
+# Mapping of input keys to action names
+KEY_TO_ACTION = {
+ "Q": ("forward", 1),
+ "E": ("back", 1),
+ "W": ("cameraY", -1),
+ "S": ("cameraY", 1),
+ "A": ("cameraX", -1),
+ "D": ("cameraX", 1),
+ "U": ("drop", 1),
+ "N": ("noop", 1),
+ "1": ("hotbar.1", 1),
+}
+
+def parse_input_to_tensor(input_str):
+ """
+ Convert an input string into a (sequence_length, 25) tensor, where each row is a one-hot representation
+ of the corresponding action key.
+
+ Args:
+ input_str (str): A string consisting of "WASD" characters (e.g., "WASDWS").
+
+ Returns:
+ torch.Tensor: A tensor of shape (sequence_length, 25), where each row is a one-hot encoded action.
+ """
+ # Get the length of the input sequence
+ seq_len = len(input_str)
+
+ # Initialize a zero tensor of shape (seq_len, 25)
+ action_tensor = torch.zeros((seq_len, 25))
+
+ # Iterate through the input string and update the corresponding positions
+ for i, char in enumerate(input_str):
+ action, value = KEY_TO_ACTION.get(char.upper()) # Convert to uppercase to handle case insensitivity
+ if action and action in ACTION_KEYS:
+ index = ACTION_KEYS.index(action)
+ action_tensor[i, index] = value # Set the corresponding action index to 1
+
+ return action_tensor
+
+def load_image_as_tensor(image_path: str) -> torch.Tensor:
+ """
+ Load an image and convert it to a 0-1 normalized tensor.
+
+ Args:
+ image_path (str): Path to the image file.
+
+ Returns:
+ torch.Tensor: Image tensor of shape (C, H, W), normalized to [0,1].
+ """
+ if isinstance(image_path, str):
+ image = Image.open(image_path).convert("RGB") # Ensure it's RGB
+ else:
+ image = image_path
+ transform = transforms.Compose([
+ transforms.ToTensor(), # Converts to tensor and normalizes to [0,1]
+ ])
+ return transform(image)
+
+def run_local(cfg: DictConfig):
+ # delay some imports in case they are not needed in non-local envs for submission
+ from experiments import build_experiment
+
+ # Get yaml names
+ hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
+ cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)
+
+ with open_dict(cfg):
+ if cfg_choice["experiment"] is not None:
+ cfg.experiment._name = cfg_choice["experiment"]
+ if cfg_choice["dataset"] is not None:
+ cfg.dataset._name = cfg_choice["dataset"]
+ if cfg_choice["algorithm"] is not None:
+ cfg.algorithm._name = cfg_choice["algorithm"]
+
+ # launch experiment
+ experiment = build_experiment(cfg, None, cfg.checkpoint_path)
+ return experiment.exec_interactive(cfg.experiment.tasks[0])
+
+memory_frames = []
+memory_curr_frame = 0
+input_history = ""
+ICE_PLAINS_IMAGE = "assets/ice_plains.png"
+DESERT_IMAGE = "assets/desert.png"
+SAVANNA_IMAGE = "assets/savanna.png"
+PLAINS_IMAGE = "assets/plans.png"
+PLACE_IMAGE = "assets/place.png"
+SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
+SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
+
+DEFAULT_IMAGE = ICE_PLAINS_IMAGE
+device = "cuda:0"
+
+def save_video(frames, path="output.mp4", fps=10):
+ h, w, _ = frames[0].shape
+ out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'XVID'), fps, (w, h))
+ for frame in frames:
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+ out.release()
+
+ ffmpeg_cmd = [
+ "ffmpeg", "-y", "-i", path, "-c:v", "libx264", "-crf", "23", "-preset", "medium", path
+ ]
+ subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+ return path
+
+@hydra.main(
+ version_base=None,
+ config_path="configurations",
+ config_name="config",
+)
+def run(cfg: DictConfig):
+ algo = run_local(cfg)
+ algo.to("cuda:0")
+
+ actions = torch.zeros((1, 25))
+ poses = torch.zeros((1, 5))
+
+ memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
+
+ _ = algo.interactive(memory_frames[0],
+ actions[0],
+ poses[0],
+ memory_curr_frame,
+ device="cuda:0")
+
+ def set_denoising_steps(denoising_steps, sampling_timesteps_state):
+ algo.sampling_timesteps = denoising_steps
+ algo.diffusion_model.sampling_timesteps = denoising_steps
+ sampling_timesteps_state = denoising_steps
+ print("set denoising steps to", algo.sampling_timesteps)
+ return sampling_timesteps_state
+
+
+ def update_image_and_log(keys):
+ actions = parse_input_to_tensor(keys)
+ global input_history
+ global memory_curr_frame
+ for i in range(len(actions)):
+ memory_curr_frame += 1
+ new_frame = algo.interactive(memory_frames[0],
+ actions[i],
+ None,
+ memory_curr_frame,
+ device="cuda:0")
+
+ memory_frames.append(new_frame)
+
+ out_video = torch.stack(memory_frames)
+ out_video = out_video.permute(0,2,3,1).numpy()
+ out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
+ out_video = (out_video * 255).astype(np.uint8)
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ os.makedirs("outputs_gradio", exist_ok=True)
+ filename = f"outputs_gradio/{timestamp}.mp4"
+ save_video(out_video, filename)
+
+ input_history += keys
+ return out_video[-1], filename, input_history
+
+ def reset():
+ global memory_curr_frame
+ global input_history
+ global memory_frames
+
+ algo.reset()
+ memory_frames = []
+ memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
+ memory_curr_frame = 0
+ input_history = ""
+
+ _ = algo.interactive(memory_frames[0],
+ actions[0],
+ poses[0],
+ memory_curr_frame,
+ device="cuda:0")
+ return input_history, DEFAULT_IMAGE
+
+ def on_image_click(SELECTED_IMAGE):
+ global DEFAULT_IMAGE
+ DEFAULT_IMAGE = SELECTED_IMAGE
+ reset()
+ return SELECTED_IMAGE
+
+ css = """
+ h1 {
+ text-align: center;
+ display:block;
+ }
+ """
+
+ # update_image_and_log("W")
+ with gr.Blocks(css=css) as demo:
+ gr.Markdown(
+ """
+ # WORLDMEM: Long-term Consistent World Generation with Memory
+
+
+
+ """
+ )
+
+ with gr.Row(variant="panel"):
+ video_display = gr.Video(autoplay=True, loop=True)
+ image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")
+
+ with gr.Row(variant="panel"):
+ with gr.Column(scale=2):
+ input_box = gr.Textbox(label="Action Sequence", placeholder="Enter action sequence here...", lines=1, max_lines=1)
+ log_output = gr.Textbox(label="History Log", interactive=False)
+ with gr.Column(scale=1):
+ slider = gr.Slider(minimum=10, maximum=50, value=algo.sampling_timesteps, step=1, label="Denoising Steps")
+ submit_button = gr.Button("Generate")
+ reset_btn = gr.Button("Reset")
+
+ sampling_timesteps_state = gr.State(algo.sampling_timesteps)
+
+ example_actions = ["DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW", "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
+ "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSSAAAAAAAAAAAAAAAAAAAAAAAA", "SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEEAAAAAAAAAAAAAAAAAAAAAA"]
+
+ def set_action(action):
+ return action
+
+ gr.Markdown("### Action sequence examples.")
+ with gr.Row():
+ buttons = []
+ for action in example_actions[:2]:
+ with gr.Column(scale=len(action)):
+ buttons.append(gr.Button(action))
+ with gr.Row():
+ for action in example_actions[2:4]:
+ with gr.Column(scale=len(action)):
+ buttons.append(gr.Button(action))
+ with gr.Row():
+ for action in example_actions[4:5]:
+ with gr.Column(scale=len(action)):
+ buttons.append(gr.Button(action))
+
+ for button, action in zip(buttons, example_actions):
+ button.click(set_action, inputs=[gr.State(value=action)], outputs=input_box)
+
+
+ gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")
+
+ with gr.Row():
+ image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
+ image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
+ image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
+ image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
+ image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
+ image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
+
+ gr.Markdown(
+ """
+ ## Instructions & Notes:
+
+ 1. Enter an action sequence in the **"Action Sequence"** text box and click **"Generate"** to begin.
+ 2. You can continue generation by clicking **"Generation"** again and again. Previous sequences are logged in the history panel.
+ 3. Click **"Reset"** to clear the current sequence and start fresh.
+ 4. Action sequences can be composed using the following keys:
+ - W: turn up
+ - S: turn down
+ - A: turn left
+ - D: turn right
+ - Q: move forward
+ - E: move backward
+ - N: no-op (do nothing)
+ - 1: switch to hotbar 1
+ - U: use item
+ 5. Higher denoising steps produce more detailed results but take longer. **20 steps** is a good balance between quality and speed.
+ 6. If you find this project interesting or useful, please consider giving it a ⭐️ on [GitHub]()!
+ 7. For feedback or suggestions, feel free to open a GitHub issue or contact me directly at **zeqixiao1@gmail.com**.
+ """
+ )
+ # input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
+ submit_button.click(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
+ reset_btn.click(reset, outputs=[log_output, image_display])
+ image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=image_display)
+ image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=image_display)
+ image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=image_display)
+ image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=image_display)
+ image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=image_display)
+ image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=image_display)
+
+ slider.change(fn=set_denoising_steps, inputs=[slider, sampling_timesteps_state], outputs=sampling_timesteps_state)
+
+ # 允许公开访问
+ demo.launch(share=True)
+ demo.launch(server_name="0.0.0.0", server_port=30066)
+
+if __name__ == "__main__":
+ run() # pylint: disable=no-value-for-parameter
diff --git a/app.sh b/app.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9505359bdf7058398572b340ec468239d869affa
--- /dev/null
+++ b/app.sh
@@ -0,0 +1,50 @@
+wandb disabled
+# srun -p a6000_xgpan -w MICL-PanXGSvr2 --gres=gpu:1 --ntasks-per-node=1 --cpus-per-task=8 \
+export WANDB_API_KEY=a4f0741e80f509317597ad944a7292fabcb68bdf
+
+CHECKPOINT_PATH="checkpoints/diffusion_only.ckpt"
+
+python -m app +name=pumpkin \
+ algorithm=df_video_worldmemminecraft \
+ +checkpoint_path=$CHECKPOINT_PATH \
+ experiment.tasks=[interactive] \
+ dataset.validation_multiplier=1 \
+ dataset=video_minecraft \
+ +customized_load=true \
+ +dataset.n_frames_valid=100 \
+ +algorithm.n_tokens=8 \
+ +load_vae=false \
+ +load_t_to_r=false \
+ +zero_init_gate=false \
+ experiment.validation.batch_size=1 \
+ +algorithm.pose_cond_dim=5 \
+ +algorithm.condition_similar_length=8 \
+ +dataset.condition_similar_length=8 \
+ +algorithm.use_plucker=true \
+ +dataset.use_plucker=true \
+ +dataset.padding_pool=10 \
+ +dataset.focal_length=0.35 \
+ +algorithm.focal_length=0.35 \
+ +only_tune_refer=false \
+ +dataset.customized_validation=true \
+ +algorithm.customized_validation=true \
+ algorithm.context_frames=90 \
+ +algorithm.vis_gt=true \
+ +algorithm.relative_embedding=true \
+ dataset.save_dir=data/test_pumpkin \
+ +algorithm.log_video=true \
+ experiment.training.data.num_workers=4 \
+ experiment.validation.data.num_workers=4 \
+ +dataset.angle_range=30 \
+ +dataset.pos_range=0.5 \
+ +algorithm.cond_only_on_qk=true \
+ +algorithm.add_pose_embed=false \
+ +algorithm.use_domain_adapter=false \
+ +algorithm.use_reference_attention=true \
+ +algorithm.add_frame_timestep_embedder=true \
+ +dataset.add_frame_timestep_embedder=true \
+ experiment.validation.limit_batch=1 \
+ algorithm.diffusion.sampling_timesteps=20 \
+ +algorithm.is_interactive=true \
+ +vae_path=checkpoints/vae_only.ckpt \
+ +pose_predictor_path=checkpoints/pose_prediction_model_only.ckpt
diff --git a/configurations/README.md b/configurations/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9e62d0478e30b5453f1b81cf1d5a44ef932137dc
--- /dev/null
+++ b/configurations/README.md
@@ -0,0 +1,7 @@
+# configurations
+
+We use [Hydra](https://hydra.cc/docs/intro/) to manage configurations. Change/Add the yaml files in this folder
+to change the default configurations. You can also override the default configurations by
+passing command line arguments.
+
+All configurations are automatically saved in wandb run.
\ No newline at end of file
diff --git a/configurations/algorithm/base_algo.yaml b/configurations/algorithm/base_algo.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3a116a5d5147fb8aede0a857ff057677186b8a54
--- /dev/null
+++ b/configurations/algorithm/base_algo.yaml
@@ -0,0 +1,3 @@
+# This will be passed as the cfg to Algo.__init__(cfg) of your algorithm class
+
+debug: ${debug} # inherited from configurations/config.yaml
diff --git a/configurations/algorithm/base_pytorch_algo.yaml b/configurations/algorithm/base_pytorch_algo.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8870c21cd59ff4e5e88e7fc5d97fdfe120f93f5
--- /dev/null
+++ b/configurations/algorithm/base_pytorch_algo.yaml
@@ -0,0 +1,4 @@
+defaults:
+ - base_algo # inherits from configurations/algorithm/base_algo.yaml
+
+lr: ${experiment.training.lr}
diff --git a/configurations/algorithm/df_base.yaml b/configurations/algorithm/df_base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..65be51cfe5837e3941a967955abb373cb40aaf76
--- /dev/null
+++ b/configurations/algorithm/df_base.yaml
@@ -0,0 +1,42 @@
+defaults:
+ - base_pytorch_algo
+
+# dataset-dependent configurations
+x_shape: ${dataset.observation_shape}
+frame_stack: 1
+frame_skip: 1
+data_mean: ${dataset.data_mean}
+data_std: ${dataset.data_std}
+external_cond_dim: 0 #${dataset.action_dim}
+context_frames: ${dataset.context_length}
+# training hyperparameters
+weight_decay: 1e-4
+warmup_steps: 10000
+optimizer_beta: [0.9, 0.999]
+# diffusion-related
+uncertainty_scale: 1
+guidance_scale: 0.0
+chunk_size: 1 # -1 for full trajectory diffusion, number to specify diffusion chunk size
+scheduling_matrix: autoregressive
+noise_level: random_all
+causal: True
+
+diffusion:
+ # training
+ objective: pred_x0
+ beta_schedule: cosine
+ schedule_fn_kwargs: {}
+ clip_noise: 20.0
+ use_snr: False
+ use_cum_snr: False
+ use_fused_snr: False
+ snr_clip: 5.0
+ cum_snr_decay: 0.98
+ timesteps: 1000
+ # sampling
+ sampling_timesteps: 50 # fixme, numer of diffusion steps, should be increased
+ ddim_sampling_eta: 1.0
+ stabilization_level: 10
+ # architecture
+ architecture:
+ network_size: 64
diff --git a/configurations/algorithm/df_video_worldmemminecraft.yaml b/configurations/algorithm/df_video_worldmemminecraft.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1d8c960d39895b4b288b9252b6804a2d766b3ca0
--- /dev/null
+++ b/configurations/algorithm/df_video_worldmemminecraft.yaml
@@ -0,0 +1,42 @@
+defaults:
+ - df_base
+
+n_frames: ${dataset.n_frames}
+frame_skip: ${dataset.frame_skip}
+metadata: ${dataset.metadata}
+
+# training hyperparameters
+weight_decay: 2e-3
+warmup_steps: 10000
+optimizer_beta: [0.9, 0.99]
+action_cond_dim: 25
+
+diffusion:
+ # training
+ beta_schedule: sigmoid
+ objective: pred_v
+ use_fused_snr: True
+ cum_snr_decay: 0.96
+ clip_noise: 20.
+ # sampling
+ sampling_timesteps: 20
+ ddim_sampling_eta: 0.0
+ stabilization_level: 15
+ # architecture
+ architecture:
+ network_size: 64
+ attn_heads: 4
+ attn_dim_head: 64
+ dim_mults: [1, 2, 4, 8]
+ resolution: ${dataset.resolution}
+ attn_resolutions: [16, 32, 64, 128]
+ use_init_temporal_attn: True
+ use_linear_attn: True
+ time_emb_type: rotary
+
+metrics:
+ # - fvd
+ # - fid
+ # - lpips
+
+_name: df_video_worldmemminecraft
\ No newline at end of file
diff --git a/configurations/algorithm/pose_prediction.yaml b/configurations/algorithm/pose_prediction.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b22625c0e8a34de154e4a6266d4fe929c8b57b01
--- /dev/null
+++ b/configurations/algorithm/pose_prediction.yaml
@@ -0,0 +1,19 @@
+defaults:
+ - df_base
+
+n_frames: ${dataset.n_frames}
+frame_skip: ${dataset.frame_skip}
+metadata: ${dataset.metadata}
+
+# training hyperparameters
+weight_decay: 2e-3
+warmup_steps: 10000
+optimizer_beta: [0.9, 0.99]
+
+
+metrics:
+ # - fvd
+ # - fid
+ # - lpips
+
+_name: pose_prediction
\ No newline at end of file
diff --git a/configurations/config.yaml b/configurations/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f7ac1f4c5ca1954f5ad7c6f5e645f4c3e0d9ea6d
--- /dev/null
+++ b/configurations/config.yaml
@@ -0,0 +1,16 @@
+# configuration parsing starts here
+defaults:
+ - experiment: exp_video # experiment yaml file name in configurations/experiments folder [fixme]
+ - dataset: video_minecraft_oasis # dataset yaml file name in configurations/dataset folder [fixme]
+ - algorithm: df_video # algorithm yaml file name in configurations/algorithm folder [fixme]
+ - cluster: null # optional, cluster yaml file name in configurations/cluster folder. Leave null for local compute
+
+debug: false # global debug flag will be passed into configuration of experiment, dataset and algorithm
+
+wandb:
+ entity: xizaoqu # wandb account name / organization name [fixme]
+ project: diffusion-forcing # wandb project name; if not provided, defaults to root folder name [fixme]
+ mode: online # set wandb logging to online, offline or dryrun
+
+resume: null # wandb run id to resume logging and loading checkpoint from
+load: null # wandb run id containing checkpoint or a path to a checkpoint file
diff --git a/configurations/dataset/base_dataset.yaml b/configurations/dataset/base_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e4fc0a9d01954d5968ce3438c5af00e97a8308ec
--- /dev/null
+++ b/configurations/dataset/base_dataset.yaml
@@ -0,0 +1,3 @@
+# This will be passed as the cfg to Dataset.__init__(cfg) of your dataset class
+
+debug: ${debug} # inherited from configurations/config.yaml
diff --git a/configurations/dataset/base_video.yaml b/configurations/dataset/base_video.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b364d9b62dea195afc4f36dc23b842b9a039bc80
--- /dev/null
+++ b/configurations/dataset/base_video.yaml
@@ -0,0 +1,14 @@
+defaults:
+ - base_dataset
+
+metadata: "data/${dataset.name}/metadata.json"
+data_mean: "data/${dataset.name}/data_mean.npy"
+data_std: "data/${dataset.name}/data_std.npy"
+save_dir: ???
+n_frames: 32
+context_length: 4
+resolution: 128
+observation_shape: [3, "${dataset.resolution}", "${dataset.resolution}"]
+external_cond_dim: 0
+validation_multiplier: 1
+frame_skip: 1
\ No newline at end of file
diff --git a/configurations/dataset/video_minecraft.yaml b/configurations/dataset/video_minecraft.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b79b93dfa7baa70a66c8a29a3ac79181a13d2e25
--- /dev/null
+++ b/configurations/dataset/video_minecraft.yaml
@@ -0,0 +1,14 @@
+defaults:
+ - base_video
+
+save_dir: data/minecraft_simple_backforward
+n_frames: 16 # TODO: increase later
+resolution: 128
+data_mean: 0.5
+data_std: 0.5
+action_cond_dim: 25
+context_length: 1
+frame_skip: 1
+validation_multiplier: 1
+
+_name: video_minecraft_oasis
\ No newline at end of file
diff --git a/configurations/dataset/video_minecraft_pose.yaml b/configurations/dataset/video_minecraft_pose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..00685125867ca84ab52457afa685206f23069846
--- /dev/null
+++ b/configurations/dataset/video_minecraft_pose.yaml
@@ -0,0 +1,14 @@
+defaults:
+ - base_video
+
+save_dir: data/minecraft_simple_backforward
+n_frames: 16 # TODO: increase later
+resolution: 128
+data_mean: 0.5
+data_std: 0.5
+external_cond_dim: 25
+context_length: 1
+frame_skip: 1
+validation_multiplier: 1
+
+_name: video_minecraft_pose
\ No newline at end of file
diff --git a/configurations/experiment/base_experiment.yaml b/configurations/experiment/base_experiment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f6884b5e8d8c937671fa88d2a37369234e2cdf30
--- /dev/null
+++ b/configurations/experiment/base_experiment.yaml
@@ -0,0 +1,2 @@
+debug: ${debug} # inherited from configurations/config.yaml
+tasks: [main] # tasks to run sequantially, such as [training, test], useful when your project has multiple stages and you want to run only a subset of them.
diff --git a/configurations/experiment/base_pytorch.yaml b/configurations/experiment/base_pytorch.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e2beebd81f7997756a6efa28c21d4762aa9a7bd0
--- /dev/null
+++ b/configurations/experiment/base_pytorch.yaml
@@ -0,0 +1,50 @@
+# inherites from base_experiment.yaml
+# most of the options have docs at https://lightning.ai/docs/pytorch/stable/common/trainer.html
+
+defaults:
+ - base_experiment
+
+tasks: [training] # tasks to run sequantially, change when your project has multiple stages and you want to run only a subset of them.
+num_nodes: 1 # number of gpu servers used in large scale distributed training
+
+training:
+ precision: 16-mixed # set float precision, 16-mixed is faster while 32 is more stable
+ compile: False # whether to compile the model with torch.compile
+ lr: 0.001 # learning rate
+ batch_size: 16 # training batch size; effective batch size is this number * gpu * nodes iff using distributed training
+ max_epochs: 1000 # set to -1 to train forever
+ max_steps: -1 # set to -1 to train forever, will override max_epochs
+ max_time: null # set to something like "00:12:00:00" to enable
+ data:
+ num_workers: 4 # number of CPU threads for data preprocessing.
+ shuffle: True # whether training data will be shuffled
+ optim:
+ accumulate_grad_batches: 1 # accumulate gradients for n batches before backprop
+ gradient_clip_val: 0 # clip gradients with norm above this value, set to 0 to disable
+ checkpointing:
+ # these are arguments to pytorch lightning's callback, `ModelCheckpoint` class
+ every_n_train_steps: 5000 # save a checkpoint every n train steps
+ every_n_epochs: null # mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``
+ train_time_interval: null # in format of "00:12:00:00", mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
+ enable_version_counter: False # If this is ``False``, later checkpoint will be overwrite previous ones.
+
+validation:
+ precision: 16-mixed
+ compile: False # whether to compile the model with torch.compile
+ batch_size: 16 # validation batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
+ val_every_n_step: 2000 # controls how frequent do we run validation, can be float (fraction of epoches) or int (steps) or null (if val_every_n_epoch is set)
+ val_every_n_epoch: null # if you want to do validation every n epoches, requires val_every_n_step to be null.
+ limit_batch: null # if null, run through validation set. Otherwise limit the number of batches to use for validation.
+ inference_mode: True # whether to run validation in inference mode (enable_grad won't work!)
+ data:
+ num_workers: 4 # number of CPU threads for data preprocessing, for validation.
+ shuffle: False # whether validation data will be shuffled
+
+test:
+ precision: 16-mixed
+ compile: False # whether to compile the model with torch.compile
+ batch_size: 4 # test batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
+ limit_batch: null # if null, run through test set. Otherwise limit the number of batches to use for test.
+ data:
+ num_workers: 4 # number of CPU threads for data preprocessing, for test.
+ shuffle: False # whether test data will be shuffled
diff --git a/configurations/experiment/exp_pose.yaml b/configurations/experiment/exp_pose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c03e0ced423b35976050d3555b55cbc6f4511f85
--- /dev/null
+++ b/configurations/experiment/exp_pose.yaml
@@ -0,0 +1,31 @@
+defaults:
+ - base_pytorch
+
+tasks: [training]
+
+training:
+ lr: 8e-5
+ precision: 16-mixed
+ batch_size: 4
+ max_epochs: -1
+ max_steps: 2000005
+ checkpointing:
+ every_n_train_steps: 2500
+ optim:
+ gradient_clip_val: 1.0
+
+validation:
+ val_every_n_step: 300
+ val_every_n_epoch: null
+ batch_size: 4
+ limit_batch: 1
+
+test:
+ limit_batch: 1
+ batch_size: 1
+
+logging:
+ metrics:
+ # - fvd
+ # - fid
+ # - lpips
diff --git a/configurations/experiment/exp_video.yaml b/configurations/experiment/exp_video.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c03e0ced423b35976050d3555b55cbc6f4511f85
--- /dev/null
+++ b/configurations/experiment/exp_video.yaml
@@ -0,0 +1,31 @@
+defaults:
+ - base_pytorch
+
+tasks: [training]
+
+training:
+ lr: 8e-5
+ precision: 16-mixed
+ batch_size: 4
+ max_epochs: -1
+ max_steps: 2000005
+ checkpointing:
+ every_n_train_steps: 2500
+ optim:
+ gradient_clip_val: 1.0
+
+validation:
+ val_every_n_step: 300
+ val_every_n_epoch: null
+ batch_size: 4
+ limit_batch: 1
+
+test:
+ limit_batch: 1
+ batch_size: 1
+
+logging:
+ metrics:
+ # - fvd
+ # - fid
+ # - lpips
diff --git a/datasets/README.md b/datasets/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3f6be77fc6ecc0c1b44a3c07666b64e6b8f0af35
--- /dev/null
+++ b/datasets/README.md
@@ -0,0 +1,11 @@
+The `datasets` folder is used to contain dataset code or environment code.
+Don't store actual data like images here! For those, please use the `data` folder instead of `datasets`.
+
+Create a folder to create your own pytorch dataset definition. Then, update the `__init__.py`
+at every level to register all datasets.
+
+Each dataset class takes in a DictConfig file `cfg` in its `__init__`, which allows you to pass in arguments via configuration file in `configurations/dataset` or [command line override](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/).
+
+---
+
+This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
diff --git a/datasets/__init__.py b/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9baaadcc3481236c8a9db73ce14e76f9633242a0
--- /dev/null
+++ b/datasets/__init__.py
@@ -0,0 +1 @@
+from .video import MinecraftVideoDataset
\ No newline at end of file
diff --git a/datasets/__pycache__/__init__.cpython-310.pyc b/datasets/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ff4816e468ebf8c94df2cadd926ceda9a3ff88e
Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-310.pyc differ
diff --git a/datasets/video/__init__.py b/datasets/video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a7b54176eba55b521128f27c9f217075e51c29e
--- /dev/null
+++ b/datasets/video/__init__.py
@@ -0,0 +1,2 @@
+from .minecraft_video_dataset import MinecraftVideoDataset
+from .minecraft_video_dataset_pose import MinecraftVideoPoseDataset
\ No newline at end of file
diff --git a/datasets/video/__pycache__/__init__.cpython-310.pyc b/datasets/video/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1aa24fc2ae29563c777e12646f432c386239a905
Binary files /dev/null and b/datasets/video/__pycache__/__init__.cpython-310.pyc differ
diff --git a/datasets/video/__pycache__/base_video_dataset.cpython-310.pyc b/datasets/video/__pycache__/base_video_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f02d5f68cd74f908954bb6ee3cc5f0a5ba4198f4
Binary files /dev/null and b/datasets/video/__pycache__/base_video_dataset.cpython-310.pyc differ
diff --git a/datasets/video/__pycache__/dmlab_video_dataset.cpython-310.pyc b/datasets/video/__pycache__/dmlab_video_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46dddd4a727adcb71d873265b6ce7f1dd54ef6fc
Binary files /dev/null and b/datasets/video/__pycache__/dmlab_video_dataset.cpython-310.pyc differ
diff --git a/datasets/video/__pycache__/minecraft_video_dataset.cpython-310.pyc b/datasets/video/__pycache__/minecraft_video_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce61a17704af97a3ad812299082f5d7b79a4e5d1
Binary files /dev/null and b/datasets/video/__pycache__/minecraft_video_dataset.cpython-310.pyc differ
diff --git a/datasets/video/__pycache__/minecraft_video_dataset_oasis.cpython-310.pyc b/datasets/video/__pycache__/minecraft_video_dataset_oasis.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98476a72df46cebb192953bb36190fed7424bd99
Binary files /dev/null and b/datasets/video/__pycache__/minecraft_video_dataset_oasis.cpython-310.pyc differ
diff --git a/datasets/video/__pycache__/minecraft_video_dataset_oasis_filter.cpython-310.pyc b/datasets/video/__pycache__/minecraft_video_dataset_oasis_filter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..57ff6ff4344f583e7fa535aa51072b64c4c8f8f8
Binary files /dev/null and b/datasets/video/__pycache__/minecraft_video_dataset_oasis_filter.cpython-310.pyc differ
diff --git a/datasets/video/__pycache__/minecraft_video_dataset_pose.cpython-310.pyc b/datasets/video/__pycache__/minecraft_video_dataset_pose.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fea476e780dfea92efc84ef1a99241bafaeb62dc
Binary files /dev/null and b/datasets/video/__pycache__/minecraft_video_dataset_pose.cpython-310.pyc differ
diff --git a/datasets/video/base_video_dataset.py b/datasets/video/base_video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..63947af82fa84c81f81a0ec931562616b21f8265
--- /dev/null
+++ b/datasets/video/base_video_dataset.py
@@ -0,0 +1,158 @@
+from typing import Sequence
+import torch
+import random
+import os
+import numpy as np
+import cv2
+from omegaconf import DictConfig
+from torchvision import transforms
+from pathlib import Path
+from abc import abstractmethod, ABC
+import json
+
+
+class BaseVideoDataset(torch.utils.data.Dataset, ABC):
+ """
+ Base class for video datasets. Videos may be of variable length.
+
+ Folder structure of each dataset:
+ - [save_dir] (specified in config, e.g., data/phys101)
+ - /[split] (one per split)
+ - /data_folder_name (e.g., videos)
+ metadata.json
+ """
+
+ def __init__(self, cfg: DictConfig, split: str = "training"):
+ super().__init__()
+ self.cfg = cfg
+ self.split = split
+ self.resolution = cfg.resolution
+ self.external_cond_dim = cfg.external_cond_dim
+ self.n_frames = (
+ cfg.n_frames * cfg.frame_skip
+ if split == "training"
+ else cfg.n_frames * cfg.frame_skip * cfg.validation_multiplier
+ )
+ self.frame_skip = cfg.frame_skip
+ self.save_dir = Path(cfg.save_dir)
+ self.save_dir.mkdir(exist_ok=True, parents=True)
+ self.split_dir = self.save_dir / f"{split}"
+
+ self.metadata_path = self.save_dir / "metadata.json"
+
+ self.data_paths = self.get_data_paths(self.split)
+
+ if self.split == 'training':
+ self.metadata = [1200] * len(self.data_paths) # total 1500 f
+ else:
+ self.metadata = [1] * len(self.data_paths) # total 1500 f
+ # self.clips_per_video = np.clip(np.array(self.metadata[split]) - self.n_frames + 1, a_min=1, a_max=None).astype(
+ # np.int32
+ # )
+ self.clips_per_video = np.clip(np.array(self.metadata) - self.n_frames + 1, a_min=1, a_max=None).astype(
+ np.int32
+ )
+ self.cum_clips_per_video = np.cumsum(self.clips_per_video)
+ self.transform = transforms.Resize((self.resolution, self.resolution), antialias=True)
+
+ # shuffle but keep the same order for each epoch, so validation sample is diverse yet deterministic
+ random.seed(0)
+ self.idx_remap = list(range(self.__len__()))
+ random.shuffle(self.idx_remap)
+
+ @abstractmethod
+ def download_dataset(self) -> Sequence[int]:
+ """
+ Download dataset from the internet and build it in save_dir
+
+ Returns a list of video lengths
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_data_paths(self, split):
+ """Return a list of data paths (e.g. xxx.mp4) for a given split"""
+ raise NotImplementedError
+
+ def get_data_lengths(self, split):
+ """Return a list of num_frames for each data path (e.g. xxx.mp4) for a given split"""
+ lengths = []
+ for path in self.get_data_paths(split):
+ length = cv2.VideoCapture(str(path)).get(cv2.CAP_PROP_FRAME_COUNT)
+ lengths.append(length)
+ return lengths
+
+ def split_idx(self, idx):
+ video_idx = np.argmax(self.cum_clips_per_video > idx)
+ frame_idx = idx - np.pad(self.cum_clips_per_video, (1, 0))[video_idx]
+ return video_idx, frame_idx
+
+ @staticmethod
+ def load_video(path: Path):
+ """
+ Load video from a path
+ :param filename: path to the video
+ :return: video as a numpy array
+ """
+
+ cap = cv2.VideoCapture(str(path))
+
+ frames = []
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret:
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frames.append(frame)
+ else:
+ break
+
+ cap.release()
+ frames = np.stack(frames, dtype=np.uint8)
+ return np.transpose(frames, (0, 3, 1, 2)) # (T, C, H, W)
+
+ @staticmethod
+ def load_image(filename: Path):
+ """
+ Load image from a path
+ :param filename: path to the image
+ :return: image as a numpy array
+ """
+ image = cv2.imread(str(filename))
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ return np.transpose(image, (2, 0, 1))
+
+ def __len__(self):
+ return self.clips_per_video.sum()
+
+ def __getitem__(self, idx):
+ idx = self.idx_remap[idx]
+ video_idx, frame_idx = self.split_idx(idx)
+ video_path = self.data_paths[video_idx]
+ video = self.load_video(video_path)[frame_idx : frame_idx + self.n_frames]
+
+ pad_len = self.n_frames - len(video)
+
+ nonterminal = np.ones(self.n_frames)
+ if len(video) < self.n_frames:
+ video = np.pad(video, ((0, pad_len), (0, 0), (0, 0), (0, 0)))
+ nonterminal[-pad_len:] = 0
+
+ video = torch.from_numpy(video / 256.0).float()
+ video = self.transform(video)
+
+ if self.external_cond_dim:
+ external_cond = np.load(
+ # pylint: disable=no-member
+ self.condition_dir
+ / f"{video_path.name.replace('.mp4', '.npy')}"
+ )
+ if len(external_cond) < self.n_frames:
+ external_cond = np.pad(external_cond, ((0, pad_len),))
+ external_cond = torch.from_numpy(external_cond).float()
+ return (
+ video[:: self.frame_skip],
+ external_cond[:: self.frame_skip],
+ nonterminal[:: self.frame_skip],
+ )
+ else:
+ return video[:: self.frame_skip], nonterminal[:: self.frame_skip]
diff --git a/datasets/video/minecraft_video_dataset.py b/datasets/video/minecraft_video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..da640386269081f3d6a7df998a5281929d38864e
--- /dev/null
+++ b/datasets/video/minecraft_video_dataset.py
@@ -0,0 +1,262 @@
+import os
+import io
+import tarfile
+import numpy as np
+import torch
+from typing import Sequence, Mapping
+from omegaconf import DictConfig
+from pytorchvideo.data.encoded_video import EncodedVideo
+import random
+
+from .base_video_dataset import BaseVideoDataset
+
+
+
+
+ACTION_KEYS = [
+ "inventory",
+ "ESC",
+ "hotbar.1",
+ "hotbar.2",
+ "hotbar.3",
+ "hotbar.4",
+ "hotbar.5",
+ "hotbar.6",
+ "hotbar.7",
+ "hotbar.8",
+ "hotbar.9",
+ "forward",
+ "back",
+ "left",
+ "right",
+ "cameraY",
+ "cameraX",
+ "jump",
+ "sneak",
+ "sprint",
+ "swapHands",
+ "attack",
+ "use",
+ "pickItem",
+ "drop",
+]
+
+def convert_action_space(actions):
+ vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
+ vec_25[actions[:,0]==1, 11] = 1
+ vec_25[actions[:,0]==2, 12] = 1
+ vec_25[actions[:,4]==11, 16] = -1
+ vec_25[actions[:,4]==13, 16] = 1
+ vec_25[actions[:,3]==11, 15] = -1
+ vec_25[actions[:,3]==13, 15] = 1
+ vec_25[actions[:,5]==6, 24] = 1
+ vec_25[actions[:,5]==1, 24] = 1
+ vec_25[actions[:,1]==1, 13] = 1
+ vec_25[actions[:,1]==2, 14] = 1
+ vec_25[actions[:,7]==1, 2] = 1
+ return vec_25
+
+# Dataset class
+class MinecraftVideoDataset(BaseVideoDataset):
+ """
+ Minecraft video dataset for training and validation.
+
+ Args:
+ cfg (DictConfig): Configuration object.
+ split (str): Dataset split ("training" or "validation").
+ """
+ def __init__(self, cfg: DictConfig, split: str = "training"):
+ if split == "test":
+ split = "validation"
+ super().__init__(cfg, split)
+ self.n_frames = cfg.n_frames_valid if split == "validation" and hasattr(cfg, "n_frames_valid") else cfg.n_frames
+ self.use_plucker = cfg.use_plucker
+ self.condition_similar_length = cfg.condition_similar_length
+ self.customized_validation = cfg.customized_validation
+ self.angle_range = cfg.angle_range
+ self.pos_range = cfg.pos_range
+ self.add_frame_timestep_embedder = cfg.add_frame_timestep_embedder
+ self.training_dropout = 0.1
+ self.sample_more_place = getattr(cfg, "sample_more_place", False)
+ self.within_context = getattr(cfg, "within_context", False)
+ self.sample_more_event = getattr(cfg, "sample_more_event", False)
+ self.causal_frame = getattr(cfg, "causal_frame", False)
+
+ def get_data_paths(self, split: str):
+ """
+ Retrieve all video file paths for the given split.
+
+ Args:
+ split (str): Dataset split ("training" or "validation").
+
+ Returns:
+ List[Path]: List of video file paths.
+ """
+ data_dir = self.save_dir / split
+ paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
+ if not paths:
+ sub_dirs = os.listdir(data_dir)
+ for sub_dir in sub_dirs:
+ sub_path = data_dir / sub_dir
+ paths += sorted(list(sub_path.glob("**/*.mp4")), key=lambda x: x.name)
+ return paths
+
+ def download_dataset(self):
+ pass
+
+ def __getitem__(self, idx: int):
+ """
+ Retrieve a single data sample by index.
+
+ Args:
+ idx (int): Index of the data sample.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray]: Video, actions, poses, and timesteps.
+ """
+ max_retries = 1000
+ for _ in range(max_retries):
+ try:
+ return self.load_data(idx)
+ except Exception as e:
+ print(f"Retrying due to error: {e}")
+ idx = (idx + 1) % len(self)
+
+ def load_data(self, idx):
+ idx = self.idx_remap[idx]
+ file_idx, frame_idx = self.split_idx(idx)
+ action_path = self.data_paths[file_idx]
+ video_path = self.data_paths[file_idx]
+
+ action_path = video_path.with_suffix(".npz")
+ actions_pool = np.load(action_path)['actions']
+ poses_pool = np.load(action_path)['poses']
+
+
+ poses_pool[0,1] = poses_pool[1,1] # wrong first in place
+
+ assert poses_pool[:,1].max() - poses_pool[:,1].min() < 2, f"wrong~~~~{poses_pool[:,1].max() - poses_pool[:,1].min()}-{video_path}"
+
+
+ if len(poses_pool) < len(actions_pool):
+ poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))
+
+ actions_pool = convert_action_space(actions_pool)
+ video_raw = EncodedVideo.from_path(video_path, decode_audio=False)
+
+ frame_idx = frame_idx + 100 # avoid first frames # first frame is useless
+
+ if self.split == "validation":
+ frame_idx = 240
+
+ if self.sample_more_place and self.split == "training":
+ if random.uniform(0, 1) > 0.5:
+ place_mask = (actions_pool[:,24]==1)
+ place_mask[:100] = 0
+ valid_indices = np.where(place_mask)[0]
+ random_index = np.random.choice(valid_indices)
+ frame_idx = random_index - random.randint(1, self.n_frames-1)
+
+ total_frame = video_raw.duration.numerator
+ fps = 10 # video_raw.duration.denominator
+ total_frame = total_frame * fps / video_raw.duration.denominator
+ video = video_raw.get_clip(start_sec=frame_idx/fps, end_sec=(frame_idx+self.n_frames)/fps)["video"]
+ video = video.permute(1, 2, 3, 0).numpy()
+
+ if self.split != "validation" and 'degrees' in np.load(action_path).keys():
+ degrees = np.load(action_path)['degrees']
+ actions_pool[:,16] *= degrees
+
+ actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames])
+
+ poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])
+ pad_len = self.n_frames - len(video)
+ poses_pool[:,:3] -= poses[:1,:3]
+ poses_pool[:,-1] = -poses_pool[:,-1]
+ poses_pool[:,3:] %= 360
+
+ poses[:,:3] -= poses[:1,:3] # do not normalize angle
+ poses[:,-1] = -poses[:,-1]
+ poses[:,3:] %= 360
+
+ assert len(video) >= self.n_frames, f"{video_path}"
+
+ if self.split == "training" and self.condition_similar_length>0:
+ if random.uniform(0, 1) > self.training_dropout:
+ refer_frame_dis = poses[:,None] - poses_pool[None,:]
+ refer_frame_dis = np.abs(refer_frame_dis)
+ refer_frame_dis[...,3:][refer_frame_dis[...,3:] > 180] = 360 - refer_frame_dis[...,3:][refer_frame_dis[...,3:] > 180]
+ valid_index = ((((refer_frame_dis[..., :3] <= self.pos_range).sum(-1))>=3) & (((refer_frame_dis[..., 3:] <= self.angle_range).sum(-1))>=2) & \
+ ((((refer_frame_dis[..., :3] > 0).sum(-1))>=1) | (((refer_frame_dis[..., 3:] > 0).sum(-1))>=1))
+ ).sum(0)
+ valid_index[:100] = 0 # mute bad initial scene
+
+ if self.add_frame_timestep_embedder and self.causal_frame and (actions_pool[:frame_idx,24]==1).sum() > 0:
+ valid_index[frame_idx:] = 0
+
+ mask = valid_index >= 1
+ mask[0] = False
+ candidate_indices = np.argwhere(mask)
+
+ mask2 = valid_index >= 0
+ mask2[0] = False
+
+ count = min(self.condition_similar_length, candidate_indices.shape[0])
+ selected_indices = candidate_indices[np.random.choice(candidate_indices.shape[0], count, replace=True)][:,0]
+
+ if count < self.condition_similar_length:
+ candidate_indices2 = np.argwhere(mask2)
+ selected_indices2 = candidate_indices2[np.random.choice(candidate_indices2.shape[0], self.condition_similar_length-count, replace=True)][:,0]
+ selected_indices = np.concatenate([selected_indices, selected_indices2])
+
+ if self.sample_more_event:
+ if random.uniform(0, 1) > 0.3:
+ valid_idx = torch.nonzero(actions_pool[:frame_idx,24]==1)[:,0]
+ if len(valid_idx) > self.condition_similar_length //2:
+ valid_idx = valid_idx[-self.condition_similar_length //2:]
+
+ if len(valid_idx) > 0:
+ selected_indices[-len(valid_idx):] = valid_idx + 4
+
+ else:
+ selected_indices = np.array(list(range(self.condition_similar_length))) * 0 + random.randint(0, frame_idx)
+
+ video_pool = []
+ for si in selected_indices:
+ video_pool.append(video_raw.get_clip(start_sec=si/fps, end_sec=(si+1)/fps)["video"][:,0].permute(1,2,0))
+
+ video_pool = np.stack(video_pool)
+ video = np.concatenate([video, video_pool])
+ actions = np.concatenate([actions, actions_pool[selected_indices]])
+ poses = np.concatenate([poses, poses_pool[selected_indices]])
+
+ timestep = np.concatenate([np.array(list(range(frame_idx, frame_idx + self.n_frames))), selected_indices])
+
+ else:
+ timestep = np.array(list(range(self.n_frames)))
+
+ video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()
+
+ if self.split == "validation" and not self.customized_validation:
+ num_frame = actions.shape[0]
+
+ actions[:] = 0
+ actions[:,16] = 1
+ poses[:] = 0
+ for ff in range(1, num_frame):
+ poses[ff,4] = poses[ff-1,4] + actions[ff,16] * -15
+
+ if self.within_context:
+ actions[:] = 0
+ actions[:self.n_frames//2+1,16] = 1
+ actions[self.n_frames//2+1:,16] = -1
+ poses[:] = 0
+ for ff in range(1, num_frame):
+ poses[ff,4] = poses[ff-1,4] + actions[ff,16] * -15
+
+ return (
+ video[:: self.frame_skip],
+ actions[:: self.frame_skip],
+ poses[:: self.frame_skip],
+ timestep
+ )
\ No newline at end of file
diff --git a/datasets/video/minecraft_video_dataset_oasis_filter.py b/datasets/video/minecraft_video_dataset_oasis_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f959e33dbf847d0f3c06c4cd5b50435d3a3a3470
--- /dev/null
+++ b/datasets/video/minecraft_video_dataset_oasis_filter.py
@@ -0,0 +1,99 @@
+import torch
+from typing import Sequence
+import numpy as np
+import io
+from omegaconf import DictConfig
+from tqdm import tqdm
+
+from typing import Mapping, Sequence
+import os
+import math
+from packaging import version as pver
+from PIL import Image
+import random
+import shutil
+import os
+from pathlib import Path
+import traceback
+
+class OASISMinecraftVideoFilterDataset(torch.utils.data.Dataset):
+ """
+ Minecraft dataset
+ """
+
+ def __init__(self, source_dir, target_dir, split):
+ self.source_dir = Path(source_dir)
+ self.split_dir = self.source_dir / f"{split}"
+ self.data_paths = self.get_data_paths(split)
+ self.target_dir = Path(target_dir) / f"{split}"
+ self.target_dir.mkdir(exist_ok=True, parents=True)
+
+ def get_data_paths(self, split):
+ data_dir = self.source_dir / split
+ paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
+
+ if len(paths) == 0:
+ sub_path = os.listdir(data_dir)
+ for sp in sub_path:
+ data_dir = self.source_dir / split / sp
+ paths = paths+sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
+ return paths
+
+ def __len__(self):
+ return len(self.data_paths)
+
+ def __getitem__(self, idx):
+
+ return self.sub_get(idx)
+ # try:
+ # return self.sub_get(idx)
+ # except Exception as e:
+ # traceback.print_exc()
+ # # return self.sub_get(0)
+
+
+ def sub_get(self, idx):
+ action_path = self.data_paths[idx]
+ video_path = self.data_paths[idx]
+
+ action_path = video_path.with_suffix(".npz")
+ actions_pool = np.load(action_path)['actions']
+ poses_pool = np.load(action_path)['poses']
+
+ poses_pool[0,1] = poses_pool[1,1] # wrong first in place
+
+ print(poses_pool.shape)
+
+ if poses_pool[:,1].max() - poses_pool[:,1].min() < 2:
+ target_action_path = self.target_dir / action_path.parent.name / action_path.name
+ target_video_path = self.target_dir / video_path.parent.name / video_path.name
+ target_action_path.parent.mkdir(exist_ok=True, parents=True)
+ target_video_path.parent.mkdir(exist_ok=True, parents=True)
+
+ try:
+ shutil.copy2(action_path, target_action_path)
+ shutil.copy2(video_path, target_video_path)
+ except:
+ import pdb;pdb.set_trace()
+
+ return poses_pool[:10]
+
+
+
+if __name__ == "__main__":
+ import torch
+ from unittest.mock import MagicMock
+ import tqdm
+
+ cfg = MagicMock()
+ cfg.resolution = 64
+ cfg.external_cond_dim = 0
+ cfg.n_frames = 64
+ cfg.save_dir = "data/minecraft"
+ cfg.validation_multiplier = 1
+
+ dataset = MinecraftVideoDataset(cfg, "training")
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=16)
+
+ for batch in tqdm.tqdm(dataloader):
+ pass
diff --git a/datasets/video/minecraft_video_dataset_pose.py b/datasets/video/minecraft_video_dataset_pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0b85c18a8a9fe5140d074cb910b8cb01cd6497c
--- /dev/null
+++ b/datasets/video/minecraft_video_dataset_pose.py
@@ -0,0 +1,421 @@
+import torch
+from typing import Sequence
+import numpy as np
+import io
+import tarfile
+from pytorchvideo.data.encoded_video import EncodedVideo
+from omegaconf import DictConfig
+from tqdm import tqdm
+
+from .base_video_dataset import BaseVideoDataset
+from typing import Mapping, Sequence
+import os
+import math
+from packaging import version as pver
+from PIL import Image
+import random
+
+def euler_to_rotation_matrix(pitch, yaw):
+ """
+ Convert euler angles (pitch, yaw) to a 3x3 rotation matrix.
+ pitch: rotation around x-axis (in radians)
+ yaw: rotation around y-axis (in radians)
+ """
+ # Rotation matrix around x-axis (pitch)
+ R_x = np.array([
+ [1, 0, 0],
+ [0, math.cos(pitch), -math.sin(pitch)],
+ [0, math.sin(pitch), math.cos(pitch)]
+ ])
+
+ # Rotation matrix around y-axis (yaw)
+ R_y = np.array([
+ [math.cos(yaw), 0, math.sin(yaw)],
+ [0, 1, 0],
+ [-math.sin(yaw), 0, math.cos(yaw)]
+ ])
+
+ # Combined rotation matrix
+ R = np.dot(R_y, R_x)
+ return R
+
+def custom_meshgrid(*args):
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
+ return torch.meshgrid(*args)
+ else:
+ return torch.meshgrid(*args, indexing='ij')
+
+def camera_to_world_to_world_to_camera(camera_to_world):
+ """
+ Convert Camera-to-World matrix to World-to-Camera matrix by inverting the transformation.
+ """
+ # Extract rotation (R) and translation (T)
+ R = camera_to_world[:3, :3]
+ T = camera_to_world[:3, 3]
+
+ # Calculate World-to-Camera (inverse) matrix
+ world_to_camera = np.eye(4)
+
+ # The rotation part of World-to-Camera is the transpose of Camera-to-World's rotation
+ world_to_camera[:3, :3] = R.T
+
+ # The translation part is the negative of the rotated translation
+ world_to_camera[:3, 3] = -np.dot(R.T, T)
+
+ return world_to_camera
+
+def euler_to_camera_to_world_matrix(pose):
+
+ x, y, z, pitch, yaw = pose
+ # Convert pitch and yaw to radians
+ pitch = math.radians(pitch)
+ yaw = math.radians(yaw)
+
+ # Get the rotation matrix from Euler angles
+ R = euler_to_rotation_matrix(pitch, yaw)
+
+ # Create the 4x4 transformation matrix (rotation + translation)
+ camera_to_world = np.eye(4)
+
+ # Set the rotation part (upper 3x3)
+ camera_to_world[:3, :3] = R
+
+ # Set the translation part (last column)
+ camera_to_world[:3, 3] = [x, y, z]
+
+ return camera_to_world
+
+def tensor_to_gif(tensor, output_path, fps=10):
+ """
+ Converts a PyTorch tensor of shape (F, 3, H, W) to a GIF.
+
+ Args:
+ tensor (torch.Tensor): Input tensor of shape (F, 3, H, W) with values in range [0, 1] or [0, 255].
+ output_path (str): Path to save the output GIF.
+ fps (int): Frames per second for the GIF.
+ """
+ # Ensure the tensor is in [0, 255] range
+ if tensor.max() <= 1.0:
+ tensor = (tensor * 255).byte()
+ else:
+ tensor = tensor.byte()
+
+ # Convert tensor to numpy array and rearrange to (F, H, W, 3)
+ frames = tensor.permute(0, 2, 3, 1).cpu().numpy()
+
+ # Convert frames to PIL Images
+ pil_frames = [Image.fromarray(frame) for frame in frames]
+
+ # Save as GIF
+ pil_frames[0].save(
+ output_path,
+ save_all=True,
+ append_images=pil_frames[1:],
+ duration=int(1000 / fps),
+ loop=0
+ )
+
+def get_relative_pose(cam_params, zero_first_frame_scale):
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
+ source_cam_c2w = abs_c2ws[0]
+ if zero_first_frame_scale:
+ cam_to_origin = 0
+ else:
+ cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
+ target_cam_c2w = np.array([
+ [1, 0, 0, 0],
+ [0, 1, 0, -cam_to_origin],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]
+ ])
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
+ ret_poses = np.array(ret_poses, dtype=np.float32)
+ return ret_poses
+
+def ray_condition(K, c2w, H, W, device):
+ # c2w: B, V, 4, 4
+ # K: B, V, 4
+
+ B = K.shape[0]
+
+ j, i = custom_meshgrid(
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
+ )
+ i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
+ j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
+
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
+
+ zs = torch.ones_like(i) # [B, HxW]
+ xs = (i - cx) / fx * zs
+ ys = (j - cy) / fy * zs
+ zs = zs.expand_as(ys)
+
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
+
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
+ rays_o = c2w[..., :3, 3] # B, V, 3
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
+ # c2w @ dirctions
+ rays_dxo = torch.linalg.cross(rays_o, rays_d)
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
+
+ return plucker
+
+class Camera(object):
+ def __init__(self, entry, focal_length=0.35):
+ self.fx = focal_length # 0.35 correspond to 110 fov
+ self.fy = focal_length*640/360
+ self.cx = 0.5
+ self.cy = 0.5
+ self.c2w_mat = euler_to_camera_to_world_matrix(entry)
+ self.w2c_mat = camera_to_world_to_world_to_camera(np.copy(self.c2w_mat))
+
+
+ACTION_KEYS = [
+ "inventory",
+ "ESC",
+ "hotbar.1",
+ "hotbar.2",
+ "hotbar.3",
+ "hotbar.4",
+ "hotbar.5",
+ "hotbar.6",
+ "hotbar.7",
+ "hotbar.8",
+ "hotbar.9",
+ "forward",
+ "back",
+ "left",
+ "right",
+ "cameraY",
+ "cameraX",
+ "jump",
+ "sneak",
+ "sprint",
+ "swapHands",
+ "attack",
+ "use",
+ "pickItem",
+ "drop",
+]
+
+def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
+ actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
+ for i, current_actions in enumerate(actions):
+ for j, action_key in enumerate(ACTION_KEYS):
+ if action_key.startswith("camera"):
+ if action_key == "cameraX":
+ value = current_actions["camera"][0]
+ elif action_key == "cameraY":
+ value = current_actions["camera"][1]
+ else:
+ raise ValueError(f"Unknown camera action key: {action_key}")
+ max_val = 20
+ bin_size = 0.5
+ num_buckets = int(max_val / bin_size)
+ value = (value - num_buckets) / num_buckets
+ assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
+ else:
+ value = current_actions[action_key]
+ assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
+ actions_one_hot[i, j] = value
+
+ return actions_one_hot
+
+def simpletomulti(actions):
+ vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
+ vec_25[actions==1, 11] = 1
+ vec_25[actions==2, 16] = -1
+ vec_25[actions==3, 16] = 1
+ vec_25[actions==4, 15] = -1
+ vec_25[actions==5, 15] = 1
+ return vec_25
+
+def simpletomulti2(actions):
+ vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
+ vec_25[actions[:,0]==1, 11] = 1
+ vec_25[actions[:,0]==2, 12] = 1
+ vec_25[actions[:,4]==11, 16] = -1
+ vec_25[actions[:,4]==13, 16] = 1
+ vec_25[actions[:,3]==11, 15] = -1
+ vec_25[actions[:,3]==13, 15] = 1
+ vec_25[actions[:,5]==6, 24] = 1
+ vec_25[actions[:,5]==1, 24] = 1
+ vec_25[actions[:,1]==1, 13] = 1
+ vec_25[actions[:,1]==2, 14] = 1
+ vec_25[actions[:,7]==1, 2] = 1
+ return vec_25
+
+class MinecraftVideoPoseDataset(BaseVideoDataset):
+ """
+ Minecraft dataset
+ """
+
+ def __init__(self, cfg: DictConfig, split: str = "training"):
+ if split == "test":
+ split = "validation"
+ super().__init__(cfg, split)
+
+ if hasattr(cfg, "n_frames_valid") and split == "validation":
+ self.n_frames = cfg.n_frames_valid
+
+ def get_data_paths(self, split):
+ data_dir = self.save_dir / split
+ paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
+
+ if len(paths) == 0:
+ sub_path = os.listdir(data_dir)
+ for sp in sub_path:
+ data_dir = self.save_dir / split / sp
+ paths = paths+sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
+ return paths
+
+ def get_data_lengths(self, split):
+ lengths = [300] * len(self.get_data_paths(split))
+ return lengths
+
+ def download_dataset(self) -> Sequence[int]:
+ from internetarchive import download
+
+ part_suffixes = [
+ "aa",
+ "ab",
+ "ac",
+ "ad",
+ "ae",
+ "af",
+ "ag",
+ "ah",
+ "ai",
+ "aj",
+ "ak",
+ ]
+ for part_suffix in part_suffixes:
+ identifier = f"minecraft_marsh_dataset_{part_suffix}"
+ file_name = f"minecraft.tar.part{part_suffix}"
+ download(identifier, file_name, destdir=self.save_dir, verbose=True)
+
+ combined_bytes = io.BytesIO()
+ for part_suffix in part_suffixes:
+ identifier = f"minecraft_marsh_dataset_{part_suffix}"
+ file_name = f"minecraft.tar.part{part_suffix}"
+ part_file = self.save_dir / identifier / file_name
+ with open(part_file, "rb") as part:
+ combined_bytes.write(part.read())
+ combined_bytes.seek(0)
+ with tarfile.open(fileobj=combined_bytes, mode="r") as combined_archive:
+ combined_archive.extractall(self.save_dir)
+ (self.save_dir / "minecraft/test").rename(self.save_dir / "validation")
+ (self.save_dir / "minecraft/train").rename(self.save_dir / "training")
+ (self.save_dir / "minecraft").rmdir()
+ for part_suffix in part_suffixes:
+ identifier = f"minecraft_marsh_dataset_{part_suffix}"
+ file_name = f"minecraft.tar.part{part_suffix}"
+ part_file = self.save_dir / identifier / file_name
+ part_file.rmdir()
+
+ def __getitem__(self, idx):
+ # return self.load_data(idx)
+
+ max_retries = 1000
+ for mr in range(max_retries):
+ try:
+ return self.load_data(idx)
+ except Exception as e:
+ print(f"{mr} Error: {e}")
+ # idx = self.idx_remap[idx]
+ # file_idx, frame_idx = self.split_idx(idx)
+ # video_path = self.data_paths[file_idx]
+ # os.remove(video_path)
+ idx = (idx + 1) % self.__len__()
+
+ def load_data(self, idx):
+ idx = self.idx_remap[idx]
+ file_idx, frame_idx = self.split_idx(idx)
+ action_path = self.data_paths[file_idx]
+ video_path = self.data_paths[file_idx]
+
+ action_path = video_path.with_suffix(".npz")
+ actions_pool = np.load(action_path)['actions']
+ poses_pool = np.load(action_path)['poses']
+
+ poses_pool[0,1] = poses_pool[1,1] # wrong first in place
+
+ assert poses_pool[:,1].max() - poses_pool[:,1].min() < 2, f"wrong~~~~{poses_pool[:,1].max() - poses_pool[:,1].min()}-{video_path}"
+
+ if len(poses_pool) < len(actions_pool):
+ poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))
+
+ actions_pool = simpletomulti2(actions_pool)
+ video_raw = EncodedVideo.from_path(video_path, decode_audio=False)
+
+ frame_idx = frame_idx + 100 # avoid first frames # first frame is useless
+
+ if self.split == "validation":
+ frame_idx = 240
+
+ total_frame = video_raw.duration.numerator
+ fps = 10 # video_raw.duration.denominator
+ total_frame = total_frame * fps / video_raw.duration.denominator
+ video = video_raw.get_clip(start_sec=frame_idx/fps, end_sec=(frame_idx+self.n_frames)/fps)["video"]
+
+ video = video.permute(1, 2, 3, 0).numpy()
+
+ if self.split != "validation" and 'degrees' in np.load(action_path).keys():
+ degrees = np.load(action_path)['degrees']
+ actions_pool[:,16] *= degrees
+
+ actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames]) # (t, )
+
+ poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])
+ pad_len = self.n_frames - len(video)
+ poses_pool[:,:3] -= poses[:1,:3]
+ # poses_pool[:,3:] = -poses_pool[:,3:]
+ poses_pool[:,-1] = -poses_pool[:,-1]
+ poses_pool[:,3:] %= 360
+
+ poses[:,:3] -= poses[:1,:3] # do not normalize angle
+ # poses[:,3:] = -poses[:,3:]
+ poses[:,-1] = -poses[:,-1]
+ poses[:,3:] %= 360
+
+ nonterminal = np.ones(self.n_frames)
+ if len(video) < self.n_frames:
+ video = np.pad(video, ((0, pad_len), (0, 0), (0, 0), (0, 0)))
+ actions = np.pad(actions, ((0, pad_len),))
+ poses = np.pad(actions, ((0, pad_len),))
+ nonterminal[-pad_len:] = 0
+
+ video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()
+
+ return (
+ video[:: self.frame_skip],
+ actions[:: self.frame_skip],
+ poses[:: self.frame_skip]
+ )
+
+
+if __name__ == "__main__":
+ import torch
+ from unittest.mock import MagicMock
+ import tqdm
+
+ cfg = MagicMock()
+ cfg.resolution = 64
+ cfg.external_cond_dim = 0
+ cfg.n_frames = 64
+ cfg.save_dir = "data/minecraft"
+ cfg.validation_multiplier = 1
+
+ dataset = MinecraftVideoDataset(cfg, "training")
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=16)
+
+ for batch in tqdm.tqdm(dataloader):
+ pass
diff --git a/experiments/README.md b/experiments/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6d43524825bc1877f413e6dc262cd2057bd00ab2
--- /dev/null
+++ b/experiments/README.md
@@ -0,0 +1,19 @@
+# experiments
+
+`experiments` folder contains code of experiments. Each file in the experiment folder represents a certain type of
+benchmark specific to a project. Such experiment can be instantiated with a certain dataset and a certain algorithm.
+
+You should create a new `.py` file for your experiment,
+inherent from any suitable base classes in `experiments/exp_base.py`,
+and then register your new experiment in `experiments/__init__.py`.
+
+You run an experiment by running `python -m main [options]` in the root directory of the
+project. You should not log any data in this folder, but storing them under `outputs` under root project
+directory.
+
+This folder is only intend to contain formal experiments. For debug code and unit tests, put them under `debug` folder.
+For scripts that's not meant to be an experiment please use `scripts` folder.
+
+---
+
+This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
diff --git a/experiments/__init__.py b/experiments/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c104df75f32f8140d660a8f6d4c22ea1156b680
--- /dev/null
+++ b/experiments/__init__.py
@@ -0,0 +1,35 @@
+from typing import Optional, Union
+from omegaconf import DictConfig
+import pathlib
+from lightning.pytorch.loggers.wandb import WandbLogger
+
+from .exp_base import BaseExperiment
+from .exp_video import VideoPredictionExperiment
+from .exp_pose import PoseExperiment
+
+# each key has to be a yaml file under '[project_root]/configurations/experiment' without .yaml suffix
+exp_registry = dict(
+ exp_video=VideoPredictionExperiment,
+ exp_pose=PoseExperiment
+)
+
+
+def build_experiment(
+ cfg: DictConfig,
+ logger: Optional[WandbLogger] = None,
+ ckpt_path: Optional[Union[str, pathlib.Path]] = None,
+) -> BaseExperiment:
+ """
+ Build an experiment instance based on registry
+ :param cfg: configuration file
+ :param logger: optional logger for the experiment
+ :param ckpt_path: optional checkpoint path for saving and loading
+ :return:
+ """
+ if cfg.experiment._name not in exp_registry:
+ raise ValueError(
+ f"Experiment {cfg.experiment._name} not found in registry {list(exp_registry.keys())}. "
+ "Make sure you register it correctly in 'experiments/__init__.py' under the same name as yaml file."
+ )
+
+ return exp_registry[cfg.experiment._name](cfg, logger, ckpt_path)
diff --git a/experiments/__pycache__/__init__.cpython-310.pyc b/experiments/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92d15aea10e54726b71cbdc27c4abac8c4986fa6
Binary files /dev/null and b/experiments/__pycache__/__init__.cpython-310.pyc differ
diff --git a/experiments/__pycache__/exp_base.cpython-310.pyc b/experiments/__pycache__/exp_base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6836333efcd16c010e394a6cd0410993c2f52cf1
Binary files /dev/null and b/experiments/__pycache__/exp_base.cpython-310.pyc differ
diff --git a/experiments/__pycache__/exp_planning.cpython-310.pyc b/experiments/__pycache__/exp_planning.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e79ec7637b63688559201233cd04a00bc7505ab5
Binary files /dev/null and b/experiments/__pycache__/exp_planning.cpython-310.pyc differ
diff --git a/experiments/__pycache__/exp_pose.cpython-310.pyc b/experiments/__pycache__/exp_pose.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3b3bb5395b351a3d82642d52eb47ed38ffc5072
Binary files /dev/null and b/experiments/__pycache__/exp_pose.cpython-310.pyc differ
diff --git a/experiments/__pycache__/exp_video.cpython-310.pyc b/experiments/__pycache__/exp_video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50a0a14598bd3e499dc4fcd4391c4346435e4118
Binary files /dev/null and b/experiments/__pycache__/exp_video.cpython-310.pyc differ
diff --git a/experiments/exp_base.py b/experiments/exp_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c7da8efa4300dba7027c0e1fcff8f24270db052
--- /dev/null
+++ b/experiments/exp_base.py
@@ -0,0 +1,463 @@
+"""
+This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
+template [repo](https://github.com/buoyancy99/research-template).
+By its MIT license, you must keep the above sentence in `README.md`
+and the `LICENSE` file to credit the author.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Optional, Union, Literal, List, Dict
+import pathlib
+import os
+
+import hydra
+import torch
+from lightning.pytorch.strategies.ddp import DDPStrategy
+
+import lightning.pytorch as pl
+from lightning.pytorch.loggers.wandb import WandbLogger
+from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
+from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
+from pytorch_lightning.utilities import rank_zero_info
+
+from omegaconf import DictConfig
+
+from utils.print_utils import cyan
+from utils.distributed_utils import is_rank_zero
+from safetensors.torch import load_model
+from pathlib import Path
+
+
+torch.set_float32_matmul_precision("high")
+
+def load_custom_checkpoint(algo, optimizer, checkpoint_path):
+ if not checkpoint_path:
+ rank_zero_info("No checkpoint path provided, skipping checkpoint loading.")
+ return None
+
+ if not isinstance(checkpoint_path, Path):
+ checkpoint_path = Path(checkpoint_path)
+
+ if checkpoint_path.suffix == ".pt":
+ ckpt = torch.load(checkpoint_path, weights_only=True)
+ algo.load_state_dict(ckpt, strict=False)
+ elif checkpoint_path.suffix == ".ckpt":
+ ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+ algo.load_state_dict(ckpt['state_dict'], strict=False)
+ elif checkpoint_path.suffix == ".safetensors":
+ load_model(algo, checkpoint_path, strict=False)
+ elif os.path.isdir(checkpoint_path):
+ ckpt_files = [f for f in os.listdir(checkpoint_path) if f.endswith('.ckpt')]
+ if not ckpt_files:
+ raise FileNotFoundError("在指定文件夹中未找到任何 .ckpt 文件!")
+ selected_ckpt = max(ckpt_files)
+ selected_ckpt_path = os.path.join(checkpoint_path, selected_ckpt)
+ print(f"加载的 checkpoint 文件为: {selected_ckpt_path}")
+
+ ckpt = torch.load(selected_ckpt_path, map_location=torch.device('cpu'))
+ algo.load_state_dict(ckpt['state_dict'], strict=False)
+
+ rank_zero_info("Model weights loaded.")
+
+class BaseExperiment(ABC):
+ """
+ Abstract class for an experiment. This generalizes the pytorch lightning Trainer & lightning Module to more
+ flexible experiments that doesn't fit in the typical ml loop, e.g. multi-stage reinforcement learning benchmarks.
+ """
+
+ # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
+ compatible_algorithms: Dict = NotImplementedError
+
+ def __init__(
+ self,
+ root_cfg: DictConfig,
+ logger: Optional[WandbLogger] = None,
+ ckpt_path: Optional[Union[str, pathlib.Path]] = None,
+ ) -> None:
+ """
+ Constructor
+
+ Args:
+ cfg: configuration file that contains everything about the experiment
+ logger: a pytorch-lightning WandbLogger instance
+ ckpt_path: an optional path to saved checkpoint
+ """
+ super().__init__()
+ self.root_cfg = root_cfg
+ self.cfg = root_cfg.experiment
+ self.debug = root_cfg.debug
+ self.logger = logger
+ self.ckpt_path = ckpt_path
+ self.algo = None
+ self.customized_load = root_cfg.customized_load
+ self.load_vae = root_cfg.load_vae
+ self.load_t_to_r = root_cfg.load_t_to_r
+ self.zero_init_gate=root_cfg.zero_init_gate
+ self.only_tune_refer = root_cfg.only_tune_refer
+ self.vae_path = root_cfg.vae_path # "/mnt/xiaozeqi/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/vit-l-20.safetensors"
+ self.pose_predictor_path = root_cfg.pose_predictor_path # "/mnt/xiaozeqi/diffusionforcing/outputs/2025-03-28/16-45-11/checkpoints/epoch0step595000.ckpt"
+
+ def _build_algo(self):
+ """
+ Build the lightning module
+ :return: a pytorch-lightning module to be launched
+ """
+ algo_name = self.root_cfg.algorithm._name
+ if algo_name not in self.compatible_algorithms:
+ raise ValueError(
+ f"Algorithm {algo_name} not found in compatible_algorithms for this Experiment class. "
+ "Make sure you define compatible_algorithms correctly and make sure that each key has "
+ "same name as yaml file under '[project_root]/configurations/algorithm' without .yaml suffix"
+ )
+ return self.compatible_algorithms[algo_name](self.root_cfg.algorithm)
+
+ def exec_task(self, task: str) -> None:
+ """
+ Executing a certain task specified by string. Each task should be a stage of experiment.
+ In most computer vision / nlp applications, tasks should be just train and test.
+ In reinforcement learning, you might have more stages such as collecting dataset etc
+
+ Args:
+ task: a string specifying a task implemented for this experiment
+ """
+ if hasattr(self, task) and callable(getattr(self, task)):
+ if is_rank_zero:
+ print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
+ getattr(self, task)()
+ else:
+ raise ValueError(
+ f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
+ )
+
+ def exec_interactive(self, task: str) -> None:
+ """
+ Executing a certain task specified by string. Each task should be a stage of experiment.
+ In most computer vision / nlp applications, tasks should be just train and test.
+ In reinforcement learning, you might have more stages such as collecting dataset etc
+
+ Args:
+ task: a string specifying a task implemented for this experiment
+ """
+ if hasattr(self, task) and callable(getattr(self, task)):
+ if is_rank_zero:
+ print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
+ return getattr(self, task)()
+ else:
+ raise ValueError(
+ f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
+ )
+
+class BaseLightningExperiment(BaseExperiment):
+ """
+ Abstract class for pytorch lightning experiments. Useful for computer vision & nlp where main components are
+ simply models, datasets and train loop.
+ """
+
+ # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
+ compatible_algorithms: Dict = NotImplementedError
+
+ # each key has to be a yaml file under '[project_root]/configurations/dataset' without .yaml suffix
+ compatible_datasets: Dict = NotImplementedError
+
+ def _build_trainer_callbacks(self):
+ callbacks = []
+ if self.logger:
+ callbacks.append(LearningRateMonitor("step", True))
+
+ def _build_training_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
+ train_dataset = self._build_dataset("training")
+ shuffle = (
+ False if isinstance(train_dataset, torch.utils.data.IterableDataset) else self.cfg.training.data.shuffle
+ )
+ if train_dataset:
+ return torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=self.cfg.training.batch_size,
+ num_workers=min(os.cpu_count(), self.cfg.training.data.num_workers),
+ shuffle=shuffle,
+ persistent_workers=True,
+ )
+ else:
+ return None
+
+ def _build_validation_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
+ validation_dataset = self._build_dataset("validation")
+ shuffle = (
+ False
+ if isinstance(validation_dataset, torch.utils.data.IterableDataset)
+ else self.cfg.validation.data.shuffle
+ )
+ if validation_dataset:
+ return torch.utils.data.DataLoader(
+ validation_dataset,
+ batch_size=self.cfg.validation.batch_size,
+ num_workers=min(os.cpu_count(), self.cfg.validation.data.num_workers),
+ shuffle=shuffle,
+ persistent_workers=True,
+ )
+ else:
+ return None
+
+ def _build_test_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
+ test_dataset = self._build_dataset("test")
+ shuffle = False if isinstance(test_dataset, torch.utils.data.IterableDataset) else self.cfg.test.data.shuffle
+ if test_dataset:
+ return torch.utils.data.DataLoader(
+ test_dataset,
+ batch_size=self.cfg.test.batch_size,
+ num_workers=min(os.cpu_count(), self.cfg.test.data.num_workers),
+ shuffle=shuffle,
+ persistent_workers=True,
+ )
+ else:
+ return None
+
+ def training(self) -> None:
+ """
+ All training happens here
+ """
+ if not self.algo:
+ self.algo = self._build_algo()
+ if self.cfg.training.compile:
+ self.algo = torch.compile(self.algo)
+
+ callbacks = []
+ if self.logger:
+ callbacks.append(LearningRateMonitor("step", True))
+ if "checkpointing" in self.cfg.training:
+ callbacks.append(
+ ModelCheckpoint(
+ pathlib.Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]) / "checkpoints",
+ **self.cfg.training.checkpointing,
+ )
+ )
+
+ # TODO do not upload checkpoint to wandb
+
+ # trainer = pl.Trainer(
+ # accelerator="auto",
+ # logger=self.logger if self.logger else False,
+ # devices=torch.cuda.device_count(),
+ # num_nodes=self.cfg.num_nodes,
+ # strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
+ # callbacks=callbacks,
+ # gradient_clip_val=self.cfg.training.optim.gradient_clip_val,
+ # val_check_interval=self.cfg.validation.val_every_n_step,
+ # limit_val_batches=self.cfg.validation.limit_batch,
+ # check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch,
+ # accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches,
+ # precision=self.cfg.training.precision,
+ # detect_anomaly=False, # self.cfg.debug,
+ # num_sanity_val_steps=int(self.cfg.debug),
+ # max_epochs=self.cfg.training.max_epochs,
+ # max_steps=self.cfg.training.max_steps,
+ # max_time=self.cfg.training.max_time,
+ # )
+
+ trainer = pl.Trainer(
+ accelerator="auto",
+ devices="auto", # 自动选择设备
+ strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
+ logger=self.logger or False, # 简化写法
+ callbacks=callbacks,
+ gradient_clip_val=self.cfg.training.optim.gradient_clip_val or 0.0, # 确保默认值
+ val_check_interval=self.cfg.validation.val_every_n_step if self.cfg.validation.val_every_n_step else None,
+ limit_val_batches=self.cfg.validation.limit_batch,
+ check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch if not self.cfg.validation.val_every_n_step else None,
+ accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches or 1, # 默认累积为1
+ precision=self.cfg.training.precision or 32, # 默认32位精度
+ detect_anomaly=False, # 默认关闭异常检测
+ num_sanity_val_steps=int(self.cfg.debug) if self.cfg.debug else 0,
+ max_epochs=self.cfg.training.max_epochs,
+ max_steps=self.cfg.training.max_steps,
+ max_time=self.cfg.training.max_time
+ )
+
+
+ if self.customized_load:
+ if self.load_vae:
+ load_custom_checkpoint(algo=self.algo.diffusion_model.model,optimizer=None,checkpoint_path=self.ckpt_path)
+ load_custom_checkpoint(algo=self.algo.vae,optimizer=None,checkpoint_path=self.vae_path)
+ else:
+ load_custom_checkpoint(algo=self.algo,optimizer=None,checkpoint_path=self.ckpt_path)
+
+ if self.load_t_to_r:
+ param_list = []
+ for name, para in self.algo.diffusion_model.named_parameters():
+ if 't_' in name and 't_embedder' not in name:
+ print(name)
+ param_list.append(para)
+
+ it = 0
+ for name, para in self.algo.diffusion_model.named_parameters():
+ if 'r_' in name:
+ para.requires_grad_(False)
+ try:
+ para.copy_(param_list[it].detach().cpu())
+ except:
+ import pdb;pdb.set_trace()
+ para.requires_grad_(True)
+ it += 1
+
+ if self.zero_init_gate:
+ for name, para in self.algo.diffusion_model.named_parameters():
+ if 'r_adaLN_modulation' in name:
+ para.requires_grad_(False)
+ para[2*1024:3*1024] = 0
+ para[5*1024:6*1024] = 0
+ para.requires_grad_(True)
+
+ if self.only_tune_refer:
+ for name, para in self.algo.diffusion_model.named_parameters():
+ para.requires_grad_(False)
+ if 'r_' in name or 'pose_embedder' in name or 'pose_cond_mlp' in name or 'lora_' in name:
+ para.requires_grad_(True)
+
+ trainer.fit(
+ self.algo,
+ train_dataloaders=self._build_training_loader(),
+ val_dataloaders=self._build_validation_loader(),
+ ckpt_path=None,
+ )
+ else:
+
+ if self.only_tune_refer:
+ for name, para in self.algo.diffusion_model.named_parameters():
+ para.requires_grad_(False)
+ if 'r_' in name or 'pose_embedder' in name or 'pose_cond_mlp' in name or 'lora_' in name:
+ para.requires_grad_(True)
+
+ trainer.fit(
+ self.algo,
+ train_dataloaders=self._build_training_loader(),
+ val_dataloaders=self._build_validation_loader(),
+ ckpt_path=self.ckpt_path,
+ )
+
+ def validation(self) -> None:
+ """
+ All validation happens here
+ """
+ if not self.algo:
+ self.algo = self._build_algo()
+ if self.cfg.validation.compile:
+ self.algo = torch.compile(self.algo)
+
+ callbacks = []
+
+ trainer = pl.Trainer(
+ accelerator="auto",
+ logger=self.logger,
+ devices="auto",
+ num_nodes=self.cfg.num_nodes,
+ strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
+ callbacks=callbacks,
+ # limit_val_batches=self.cfg.validation.limit_batch,
+ limit_val_batches=self.cfg.validation.limit_batch,
+ precision=self.cfg.validation.precision,
+ detect_anomaly=False, # self.cfg.debug,
+ inference_mode=self.cfg.validation.inference_mode,
+ )
+
+ if self.customized_load:
+
+ if self.load_vae:
+ load_custom_checkpoint(algo=self.algo.diffusion_model.model,optimizer=None,checkpoint_path=self.ckpt_path)
+ load_custom_checkpoint(algo=self.algo.vae,optimizer=None,checkpoint_path=self.vae_path)
+ else:
+ load_custom_checkpoint(algo=self.algo,optimizer=None,checkpoint_path=self.ckpt_path)
+
+ if self.load_t_to_r:
+ param_list = []
+ for name, para in self.algo.diffusion_model.named_parameters():
+ if 't_' in name and 't_embedder' not in name:
+ print(name)
+ param_list.append(para)
+
+ it = 0
+ for name, para in self.algo.diffusion_model.named_parameters():
+ if 'r_' in name:
+ para.requires_grad_(False)
+ try:
+ para.copy_(param_list[it].detach().cpu())
+ except:
+ import pdb;pdb.set_trace()
+ para.requires_grad_(True)
+ it += 1
+
+ if self.zero_init_gate:
+ for name, para in self.algo.diffusion_model.named_parameters():
+ if 'r_adaLN_modulation' in name:
+ para.requires_grad_(False)
+ para[2*1024:3*1024] = 0
+ para[5*1024:6*1024] = 0
+ para.requires_grad_(True)
+
+ trainer.validate(
+ self.algo,
+ dataloaders=self._build_validation_loader(),
+ ckpt_path=None,
+ )
+ else:
+ trainer.validate(
+ self.algo,
+ dataloaders=self._build_validation_loader(),
+ ckpt_path=self.ckpt_path,
+ )
+
+ def test(self) -> None:
+ """
+ All testing happens here
+ """
+ if not self.algo:
+ self.algo = self._build_algo()
+ if self.cfg.test.compile:
+ self.algo = torch.compile(self.algo)
+
+ callbacks = []
+
+ trainer = pl.Trainer(
+ accelerator="auto",
+ logger=self.logger,
+ devices="auto",
+ num_nodes=self.cfg.num_nodes,
+ strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
+ callbacks=callbacks,
+ limit_test_batches=self.cfg.test.limit_batch,
+ precision=self.cfg.test.precision,
+ detect_anomaly=False, # self.cfg.debug,
+ )
+
+ # Only load the checkpoint if only testing. Otherwise, it will have been loaded
+ # and further trained during train.
+ trainer.test(
+ self.algo,
+ dataloaders=self._build_test_loader(),
+ ckpt_path=self.ckpt_path,
+ )
+ if not self.algo:
+ self.algo = self._build_algo()
+ if self.cfg.validation.compile:
+ self.algo = torch.compile(self.algo)
+
+
+ def interactive(self):
+
+ if not self.algo:
+ self.algo = self._build_algo()
+ if self.cfg.validation.compile:
+ self.algo = torch.compile(self.algo)
+
+ if self.customized_load:
+ load_custom_checkpoint(algo=self.algo.diffusion_model,optimizer=None,checkpoint_path=self.ckpt_path)
+ load_custom_checkpoint(algo=self.algo.vae,optimizer=None,checkpoint_path=self.vae_path)
+ load_custom_checkpoint(algo=self.algo.pose_prediction_model,optimizer=None,checkpoint_path=self.pose_predictor_path)
+ return self.algo
+ else:
+ raise NotImplementedError
+
+ def _build_dataset(self, split: str) -> Optional[torch.utils.data.Dataset]:
+ if split in ["training", "test", "validation"]:
+ return self.compatible_datasets[self.root_cfg.dataset._name](self.root_cfg.dataset, split=split)
+ else:
+ raise NotImplementedError(f"split '{split}' is not implemented")
diff --git a/experiments/exp_pose.py b/experiments/exp_pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..c94b9994dc85fff8dd6723df832f2369aa2eaa13
--- /dev/null
+++ b/experiments/exp_pose.py
@@ -0,0 +1,310 @@
+"""
+This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
+template [repo](https://github.com/buoyancy99/research-template).
+By its MIT license, you must keep the above sentence in `README.md`
+and the `LICENSE` file to credit the author.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Optional, Union, Literal, List, Dict
+import pathlib
+import os
+
+import hydra
+import torch
+from lightning.pytorch.strategies.ddp import DDPStrategy
+
+import lightning.pytorch as pl
+from lightning.pytorch.loggers.wandb import WandbLogger
+from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
+from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
+from pytorch_lightning.utilities import rank_zero_info
+
+from omegaconf import DictConfig
+
+from utils.print_utils import cyan
+from utils.distributed_utils import is_rank_zero
+from safetensors.torch import load_model
+from pathlib import Path
+from algorithms.worldmem import PosePrediction
+from datasets.video import MinecraftVideoPoseDataset
+
+
+torch.set_float32_matmul_precision("high")
+
+def load_custom_checkpoint(algo, optimizer, checkpoint_path):
+ if not checkpoint_path:
+ rank_zero_info("No checkpoint path provided, skipping checkpoint loading.")
+ return None
+
+ if not isinstance(checkpoint_path, Path):
+ checkpoint_path = Path(checkpoint_path)
+
+ if checkpoint_path.suffix == ".pt":
+ ckpt = torch.load(checkpoint_path, weights_only=True)
+ algo.load_state_dict(ckpt, strict=False)
+ elif checkpoint_path.suffix == ".ckpt":
+ ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+ algo.load_state_dict(ckpt['state_dict'], strict=False)
+ elif checkpoint_path.suffix == ".safetensors":
+ load_model(algo, checkpoint_path, strict=False)
+ elif os.path.isdir(checkpoint_path):
+ ckpt_files = [f for f in os.listdir(checkpoint_path) if f.endswith('.ckpt')]
+ if not ckpt_files:
+ raise FileNotFoundError("在指定文件夹中未找到任何 .ckpt 文件!")
+ selected_ckpt = max(ckpt_files)
+ selected_ckpt_path = os.path.join(checkpoint_path, selected_ckpt)
+ print(f"加载的 checkpoint 文件为: {selected_ckpt_path}")
+
+ ckpt = torch.load(selected_ckpt_path, map_location=torch.device('cpu'))
+ algo.load_state_dict(ckpt['state_dict'], strict=False)
+
+ rank_zero_info("Model weights loaded.")
+
+class PoseExperiment(ABC):
+ """
+ Abstract class for an experiment. This generalizes the pytorch lightning Trainer & lightning Module to more
+ flexible experiments that doesn't fit in the typical ml loop, e.g. multi-stage reinforcement learning benchmarks.
+ """
+
+ # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
+ compatible_algorithms = dict(
+ pose_prediction=PosePrediction
+ )
+
+ compatible_datasets = dict(
+ video_minecraft_pose=MinecraftVideoPoseDataset
+ )
+
+ def __init__(
+ self,
+ root_cfg: DictConfig,
+ logger: Optional[WandbLogger] = None,
+ ckpt_path: Optional[Union[str, pathlib.Path]] = None,
+ ) -> None:
+ """
+ Constructor
+
+ Args:
+ cfg: configuration file that contains everything about the experiment
+ logger: a pytorch-lightning WandbLogger instance
+ ckpt_path: an optional path to saved checkpoint
+ """
+ super().__init__()
+ self.root_cfg = root_cfg
+ self.cfg = root_cfg.experiment
+ self.debug = root_cfg.debug
+ self.logger = logger
+ self.ckpt_path = ckpt_path
+ self.algo = None
+ self.vae_path = "/cpfs01/user/xiaozeqi/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/vit-l-20.safetensors"
+
+ def _build_algo(self):
+ """
+ Build the lightning module
+ :return: a pytorch-lightning module to be launched
+ """
+ algo_name = self.root_cfg.algorithm._name
+ if algo_name not in self.compatible_algorithms:
+ raise ValueError(
+ f"Algorithm {algo_name} not found in compatible_algorithms for this Experiment class. "
+ "Make sure you define compatible_algorithms correctly and make sure that each key has "
+ "same name as yaml file under '[project_root]/configurations/algorithm' without .yaml suffix"
+ )
+ return self.compatible_algorithms[algo_name](self.root_cfg.algorithm)
+
+ def exec_task(self, task: str) -> None:
+ """
+ Executing a certain task specified by string. Each task should be a stage of experiment.
+ In most computer vision / nlp applications, tasks should be just train and test.
+ In reinforcement learning, you might have more stages such as collecting dataset etc
+
+ Args:
+ task: a string specifying a task implemented for this experiment
+ """
+ if hasattr(self, task) and callable(getattr(self, task)):
+ if is_rank_zero:
+ print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
+ getattr(self, task)()
+ else:
+ raise ValueError(
+ f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
+ )
+
+
+ def _build_trainer_callbacks(self):
+ callbacks = []
+ if self.logger:
+ callbacks.append(LearningRateMonitor("step", True))
+
+ def _build_training_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
+ train_dataset = self._build_dataset("training")
+ shuffle = (
+ False if isinstance(train_dataset, torch.utils.data.IterableDataset) else self.cfg.training.data.shuffle
+ )
+ if train_dataset:
+ return torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=self.cfg.training.batch_size,
+ num_workers=min(os.cpu_count(), self.cfg.training.data.num_workers),
+ shuffle=shuffle,
+ persistent_workers=True,
+ )
+ else:
+ return None
+
+ def _build_validation_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
+ validation_dataset = self._build_dataset("validation")
+ shuffle = (
+ False
+ if isinstance(validation_dataset, torch.utils.data.IterableDataset)
+ else self.cfg.validation.data.shuffle
+ )
+ if validation_dataset:
+ return torch.utils.data.DataLoader(
+ validation_dataset,
+ batch_size=self.cfg.validation.batch_size,
+ num_workers=min(os.cpu_count(), self.cfg.validation.data.num_workers),
+ shuffle=shuffle,
+ persistent_workers=True,
+ )
+ else:
+ return None
+
+ def _build_test_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
+ test_dataset = self._build_dataset("test")
+ shuffle = False if isinstance(test_dataset, torch.utils.data.IterableDataset) else self.cfg.test.data.shuffle
+ if test_dataset:
+ return torch.utils.data.DataLoader(
+ test_dataset,
+ batch_size=self.cfg.test.batch_size,
+ num_workers=min(os.cpu_count(), self.cfg.test.data.num_workers),
+ shuffle=shuffle,
+ persistent_workers=True,
+ )
+ else:
+ return None
+
+ def training(self) -> None:
+ """
+ All training happens here
+ """
+ if not self.algo:
+ self.algo = self._build_algo()
+ if self.cfg.training.compile:
+ self.algo = torch.compile(self.algo)
+
+ callbacks = []
+ if self.logger:
+ callbacks.append(LearningRateMonitor("step", True))
+ if "checkpointing" in self.cfg.training:
+ callbacks.append(
+ ModelCheckpoint(
+ pathlib.Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]) / "checkpoints",
+ **self.cfg.training.checkpointing,
+ )
+ )
+
+ trainer = pl.Trainer(
+ accelerator="auto",
+ devices="auto", # 自动选择设备
+ strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
+ logger=self.logger or False, # 简化写法
+ callbacks=callbacks,
+ gradient_clip_val=self.cfg.training.optim.gradient_clip_val or 0.0, # 确保默认值
+ val_check_interval=self.cfg.validation.val_every_n_step if self.cfg.validation.val_every_n_step else None,
+ limit_val_batches=self.cfg.validation.limit_batch,
+ check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch if not self.cfg.validation.val_every_n_step else None,
+ accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches or 1, # 默认累积为1
+ precision=self.cfg.training.precision or 32, # 默认32位精度
+ detect_anomaly=False, # 默认关闭异常检测
+ num_sanity_val_steps=int(self.cfg.debug) if self.cfg.debug else 0,
+ max_epochs=self.cfg.training.max_epochs,
+ max_steps=self.cfg.training.max_steps,
+ max_time=self.cfg.training.max_time
+ )
+
+ load_custom_checkpoint(algo=self.algo.vae,optimizer=None,checkpoint_path=self.vae_path)
+
+ trainer.fit(
+ self.algo,
+ train_dataloaders=self._build_training_loader(),
+ val_dataloaders=self._build_validation_loader(),
+ ckpt_path=self.ckpt_path,
+ )
+
+ def validation(self) -> None:
+ """
+ All validation happens here
+ """
+ if not self.algo:
+ self.algo = self._build_algo()
+ if self.cfg.validation.compile:
+ self.algo = torch.compile(self.algo)
+
+ callbacks = []
+
+ trainer = pl.Trainer(
+ accelerator="auto",
+ logger=self.logger,
+ devices="auto",
+ num_nodes=self.cfg.num_nodes,
+ strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
+ callbacks=callbacks,
+ # limit_val_batches=self.cfg.validation.limit_batch,
+ limit_val_batches=self.cfg.validation.limit_batch,
+ precision=self.cfg.validation.precision,
+ detect_anomaly=False, # self.cfg.debug,
+ inference_mode=self.cfg.validation.inference_mode,
+ )
+
+ load_custom_checkpoint(algo=self.algo.vae,optimizer=None,checkpoint_path=self.vae_path)
+
+ trainer.validate(
+ self.algo,
+ dataloaders=self._build_validation_loader(),
+ ckpt_path=self.ckpt_path,
+ )
+
+ def test(self) -> None:
+ """
+ All testing happens here
+ """
+ if not self.algo:
+ self.algo = self._build_algo()
+ if self.cfg.test.compile:
+ self.algo = torch.compile(self.algo)
+
+ callbacks = []
+
+ trainer = pl.Trainer(
+ accelerator="auto",
+ logger=self.logger,
+ devices="auto",
+ num_nodes=self.cfg.num_nodes,
+ strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
+ callbacks=callbacks,
+ limit_test_batches=self.cfg.test.limit_batch,
+ precision=self.cfg.test.precision,
+ detect_anomaly=False, # self.cfg.debug,
+ )
+
+ # Only load the checkpoint if only testing. Otherwise, it will have been loaded
+ # and further trained during train.
+ trainer.test(
+ self.algo,
+ dataloaders=self._build_test_loader(),
+ ckpt_path=self.ckpt_path,
+ )
+ if not self.algo:
+ self.algo = self._build_algo()
+ if self.cfg.validation.compile:
+ self.algo = torch.compile(self.algo)
+
+ def _build_dataset(self, split: str) -> Optional[torch.utils.data.Dataset]:
+ if split in ["training", "test", "validation"]:
+ return self.compatible_datasets[self.root_cfg.dataset._name](self.root_cfg.dataset, split=split)
+ else:
+ raise NotImplementedError(f"split '{split}' is not implemented")
+
+
diff --git a/experiments/exp_video.py b/experiments/exp_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7db48f533fb178f23c19043c582349203b10843
--- /dev/null
+++ b/experiments/exp_video.py
@@ -0,0 +1,25 @@
+from datasets.video import (
+ MinecraftVideoDataset,
+ MinecraftVideoPoseDataset
+)
+
+from algorithms.worldmem import WorldMemMinecraft
+from algorithms.worldmem import PosePrediction
+from .exp_base import BaseLightningExperiment
+
+
+class VideoPredictionExperiment(BaseLightningExperiment):
+ """
+ A video prediction experiment
+ """
+
+ compatible_algorithms = dict(
+ df_video_worldmemminecraft=WorldMemMinecraft,
+ pose_prediction=PosePrediction
+ )
+
+ compatible_datasets = dict(
+ # video datasets
+ video_minecraft=MinecraftVideoDataset,
+ video_minecraft_pose=MinecraftVideoPoseDataset
+ )
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a19cfdf54471f9b82611ca52d518b3df60ccb59
--- /dev/null
+++ b/main.py
@@ -0,0 +1,219 @@
+"""
+This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
+template [repo](https://github.com/buoyancy99/research-template).
+By its MIT license, you must keep the above sentence in `README.md`
+and the `LICENSE` file to credit the author.
+
+Main file for the project. This will create and run new experiments and load checkpoints from wandb.
+Borrowed part of the code from David Charatan and wandb.
+"""
+
+import sys
+import subprocess
+import time
+from pathlib import Path
+
+import hydra
+from omegaconf import DictConfig, OmegaConf
+from omegaconf.omegaconf import open_dict
+
+from utils.print_utils import cyan
+from utils.ckpt_utils import download_latest_checkpoint, is_run_id
+from utils.cluster_utils import submit_slurm_job
+from utils.distributed_utils import is_rank_zero
+
+def get_latest_checkpoint(checkpoint_folder: Path, pattern: str = '*.ckpt'):
+ # 获取文件夹中所有符合 pattern 的文件
+ checkpoint_files = list(checkpoint_folder.glob(pattern))
+ if not checkpoint_files:
+ return None # 如果没有找到 checkpoint 文件,返回 None
+ # 根据文件修改时间(st_mtime)选取最新的文件
+ latest_checkpoint = max(checkpoint_files, key=lambda f: f.stat().st_mtime)
+ return latest_checkpoint
+
+def run_local(cfg: DictConfig):
+ # delay some imports in case they are not needed in non-local envs for submission
+ from experiments import build_experiment
+ from utils.wandb_utils import OfflineWandbLogger, SpaceEfficientWandbLogger
+
+ # Get yaml names
+ hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
+ cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)
+
+ with open_dict(cfg):
+ if cfg_choice["experiment"] is not None:
+ cfg.experiment._name = cfg_choice["experiment"]
+ if cfg_choice["dataset"] is not None:
+ cfg.dataset._name = cfg_choice["dataset"]
+ if cfg_choice["algorithm"] is not None:
+ cfg.algorithm._name = cfg_choice["algorithm"]
+
+ # import pdb;pdb.set_trace()
+ # Set up the output directory.
+ output_dir = getattr(cfg, "output_dir", None)
+ if output_dir is not None:
+ OmegaConf.set_readonly(hydra_cfg, False)
+ hydra_cfg.runtime.output_dir = output_dir
+ OmegaConf.set_readonly(hydra_cfg, True)
+
+ output_dir = Path(hydra_cfg.runtime.output_dir)
+
+ if is_rank_zero:
+ print(cyan(f"Outputs will be saved to:"), output_dir)
+ (output_dir.parents[1] / "latest-run").unlink(missing_ok=True)
+ (output_dir.parents[1] / "latest-run").symlink_to(output_dir, target_is_directory=True)
+
+ # Set up logging with wandb.
+ if cfg.wandb.mode != "disabled":
+ # If resuming, merge into the existing run on wandb.
+ resume = cfg.get("resume", None)
+ name = f"{cfg.name} ({output_dir.parent.name}/{output_dir.name})" if resume is None else None
+
+ if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
+ logger_cls = OfflineWandbLogger
+ else:
+ logger_cls = SpaceEfficientWandbLogger
+
+ offline = cfg.wandb.mode != "online"
+ logger = logger_cls(
+ name=name,
+ save_dir=str(output_dir),
+ offline=offline,
+ entity=cfg.wandb.entity,
+ project=cfg.wandb.project,
+ log_model=False,
+ config=OmegaConf.to_container(cfg),
+ id=resume,
+ resume="auto"
+ )
+
+ else:
+ logger = None
+
+ # Load ckpt
+ resume = cfg.get("resume", None)
+ load = cfg.get("load", None)
+ checkpoint_path = None
+ load_id = None
+ if load and not is_run_id(load):
+ checkpoint_path = load
+ if resume:
+ load_id = resume
+ elif load and is_run_id(load):
+ load_id = load
+ else:
+ load_id = None
+
+ if load_id:
+ run_path = f"{cfg.wandb.entity}/{cfg.wandb.project}/{load_id}"
+ checkpoint_path = Path("outputs/downloaded") / run_path / "model.ckpt"
+ checkpoint_path = output_dir / get_latest_checkpoint(output_dir / "checkpoints")
+
+ if checkpoint_path and is_rank_zero:
+ print(f"Will load checkpoint from {checkpoint_path}")
+
+ # launch experiment
+ experiment = build_experiment(cfg, logger, checkpoint_path)
+ for task in cfg.experiment.tasks:
+ experiment.exec_task(task)
+
+
+def run_slurm(cfg: DictConfig):
+ python_args = " ".join(sys.argv[1:]) + " +_on_compute_node=True"
+ project_root = Path.cwd()
+ while not (project_root / ".git").exists():
+ project_root = project_root.parent
+ if project_root == Path("/"):
+ raise Exception("Could not find repo directory!")
+
+ slurm_log_dir = submit_slurm_job(
+ cfg,
+ python_args,
+ project_root,
+ )
+
+ if "cluster" in cfg and cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online":
+ print("Job submitted to a compute node without internet. This requires manual syncing on login node.")
+ osh_command_dir = project_root / ".wandb_osh_command_dir"
+
+ osh_proc = None
+ # if click.confirm("Do you want us to run the sync loop for you?", default=True):
+ osh_proc = subprocess.Popen(["wandb-osh", "--command-dir", osh_command_dir])
+ print(f"Running wandb-osh in background... PID: {osh_proc.pid}")
+ print(f"To kill the sync process, run 'kill {osh_proc.pid}' in the terminal.")
+ print(
+ f"You can manually start a sync loop later by running the following:",
+ cyan(f"wandb-osh --command-dir {osh_command_dir}"),
+ )
+
+ print(
+ "Once the job gets allocated and starts running, we will print a command below "
+ "for you to trace the errors and outputs: (Ctrl + C to exit without waiting)"
+ )
+ msg = f"tail -f {slurm_log_dir}/* \n"
+ try:
+ while not list(slurm_log_dir.glob("*.out")) and not list(slurm_log_dir.glob("*.err")):
+ time.sleep(1)
+ print(cyan("To trace the outputs and errors, run the following command:"), msg)
+ except KeyboardInterrupt:
+ print("Keyboard interrupt detected. Exiting...")
+ print(
+ cyan("To trace the outputs and errors, manually wait for the job to start and run the following command:"),
+ msg,
+ )
+
+
+@hydra.main(
+ version_base=None,
+ config_path="configurations",
+ config_name="config",
+)
+def run(cfg: DictConfig):
+ if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
+ with open_dict(cfg):
+ if cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online":
+ cfg.wandb.mode = "offline"
+
+ if "name" not in cfg:
+ raise ValueError("must specify a name for the run with command line argument '+name=[name]'")
+
+ if not cfg.wandb.get("entity", None):
+ raise ValueError(
+ "must specify wandb entity in 'configurations/config.yaml' or with command line"
+ " argument 'wandb.entity=[entity]' \n An entity is your wandb user name or group"
+ " name. This is used for logging. If you don't have an wandb account, please signup at https://wandb.ai/"
+ )
+
+ if cfg.wandb.project is None:
+ cfg.wandb.project = str(Path(__file__).parent.name)
+
+ # If resuming or loading a wandb ckpt and not on a compute node, download the checkpoint.
+ resume = cfg.get("resume", None)
+ load = cfg.get("load", None)
+
+ if resume and load:
+ raise ValueError(
+ "When resuming a wandb run with `resume=[wandb id]`, checkpoint will be loaded from the cloud"
+ "and `load` should not be specified."
+ )
+
+ if resume:
+ load_id = resume
+ elif load and is_run_id(load):
+ load_id = load
+ else:
+ load_id = None
+
+ # if load_id and "_on_compute_node" not in cfg:
+ # run_path = f"{cfg.wandb.entity}/{cfg.wandb.project}/{load_id}"
+ # download_latest_checkpoint(run_path, Path("outputs/downloaded"))
+
+ if "cluster" in cfg and not "_on_compute_node" in cfg:
+ print(cyan("Slurm detected, submitting to compute node instead of running locally..."))
+ run_slurm(cfg)
+ else:
+ run_local(cfg)
+
+
+if __name__ == "__main__":
+ run() # pylint: disable=no-value-for-parameter
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f0251e0578279f5fa4ea6937d0786aaab057f5c2
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,26 @@
+torch~=2.4.0
+torchvision~=0.19.1
+lightning~=2.1.2
+wandb~=0.17.0
+hydra-core~=1.3.2
+omegaconf~=2.3.0
+torchmetrics[image]==0.11.4
+wandb-osh==1.2.1
+gluonts[torch]==0.13.1
+pytorchvideo~=0.1.5
+colorama
+tqdm
+opencv-python
+matplotlib
+click
+moviepy==1.0.3
+imageio
+einops
+pandas
+pyzmq
+pyrealsense2
+internetarchive
+h5py
+rotary_embedding_torch
+diffusers
+timm
\ No newline at end of file
diff --git a/scripts/README.md b/scripts/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e13a23d87572d924da424714e5e92e7e818ec279
--- /dev/null
+++ b/scripts/README.md
@@ -0,0 +1,10 @@
+# scirpts
+
+`scripts` folder contains bash scripts for you to scale up your project on cloud.
+Don't put your jupyter notebooks here! They belongs to `debug` folder.
+
+General scripts that are useful for all projects can be put in the `script` folder directly.
+
+---
+
+This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
diff --git a/scripts/dummy_script.sh b/scripts/dummy_script.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3f638d0c841837769a467ab9fdb1ab486ccb4e9f
--- /dev/null
+++ b/scripts/dummy_script.sh
@@ -0,0 +1 @@
+echo 'hello world'
\ No newline at end of file
diff --git a/split_checkpoint.py b/split_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..c136afcd92ea716754b2e33d6f9318308816ea8b
--- /dev/null
+++ b/split_checkpoint.py
@@ -0,0 +1,9 @@
+import torch
+
+ckpt_path = "/mnt/xiaozeqi/diffusionforcing/outputs/2025-03-28/16-45-11/checkpoints/epoch0step595000.ckpt"
+checkpoint = torch.load(ckpt_path, map_location="cpu") # map_location 可根据需要更换
+
+state_dict = checkpoint['state_dict']
+pose_prediction_model_dict = {k.replace('pose_prediction_model.', ''): v for k, v in state_dict.items() if k.startswith('pose_prediction_model.')}
+
+torch.save({'state_dict': pose_prediction_model_dict}, "pose_prediction_model_only.ckpt")
\ No newline at end of file
diff --git a/utils/README.md b/utils/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b0152f0645fbd11637c5b2217fe3021938ba33e3
--- /dev/null
+++ b/utils/README.md
@@ -0,0 +1,7 @@
+# utils
+
+This is where you can put useful utilities like visualization, 3d conversion, logging etc
+
+---
+
+This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a53f42f9450cbbe2f79a953ce458be6f42c344ea
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/utils/__pycache__/ckpt_utils.cpython-310.pyc b/utils/__pycache__/ckpt_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f65b4b2ba7cd2e84b4ca7d4c7a0101c017415d06
Binary files /dev/null and b/utils/__pycache__/ckpt_utils.cpython-310.pyc differ
diff --git a/utils/__pycache__/cluster_utils.cpython-310.pyc b/utils/__pycache__/cluster_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cee453b9af03745da328207dc4ec008e9a356ae
Binary files /dev/null and b/utils/__pycache__/cluster_utils.cpython-310.pyc differ
diff --git a/utils/__pycache__/distributed_utils.cpython-310.pyc b/utils/__pycache__/distributed_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..581b2f13cf762c4a664117beded39d2ee05c993f
Binary files /dev/null and b/utils/__pycache__/distributed_utils.cpython-310.pyc differ
diff --git a/utils/__pycache__/logging_utils.cpython-310.pyc b/utils/__pycache__/logging_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d469010f79e4bcdc9f4e58c0db6330cd05e0f6d1
Binary files /dev/null and b/utils/__pycache__/logging_utils.cpython-310.pyc differ
diff --git a/utils/__pycache__/print_utils.cpython-310.pyc b/utils/__pycache__/print_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d03845855a76c7ef5f1e13580e2eb3e7f9548d51
Binary files /dev/null and b/utils/__pycache__/print_utils.cpython-310.pyc differ
diff --git a/utils/__pycache__/wandb_utils.cpython-310.pyc b/utils/__pycache__/wandb_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0fb23b0c7f876b6de113d98779da5500ebdaca9
Binary files /dev/null and b/utils/__pycache__/wandb_utils.cpython-310.pyc differ
diff --git a/utils/ckpt_utils.py b/utils/ckpt_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aa27f102f50b160b0c7ff5ed5b713f7bfe77b8b
--- /dev/null
+++ b/utils/ckpt_utils.py
@@ -0,0 +1,32 @@
+from pathlib import Path
+import wandb
+
+
+def is_run_id(run_id: str) -> bool:
+ """Check if a string is a run ID."""
+ return len(run_id) == 8 and run_id.isalnum()
+
+
+def version_to_int(artifact) -> int:
+ """Convert versions of the form vX to X. For example, v12 to 12."""
+ return int(artifact.version[1:])
+
+
+def download_latest_checkpoint(run_path: str, download_dir: Path) -> Path:
+ api = wandb.Api()
+ run = api.run(run_path)
+
+ # Find the latest saved model checkpoint.
+ latest = None
+ for artifact in run.logged_artifacts():
+ if artifact.type != "model" or artifact.state != "COMMITTED":
+ continue
+
+ if latest is None or version_to_int(artifact) > version_to_int(latest):
+ latest = artifact
+
+ # Download the checkpoint.
+ download_dir.mkdir(exist_ok=True, parents=True)
+ root = download_dir / run_path
+ latest.download(root=root)
+ return root / "model.ckpt"
diff --git a/utils/cluster_utils.py b/utils/cluster_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ae29488a44a59888ca7a97feb55e39447b2b393
--- /dev/null
+++ b/utils/cluster_utils.py
@@ -0,0 +1,40 @@
+"""
+utils for submitting to clusters, such as slurm
+"""
+
+import os
+from omegaconf import DictConfig, OmegaConf
+from datetime import datetime
+from pathlib import Path
+
+from utils.print_utils import cyan
+
+# This is set below.
+REPO_DIR = None
+
+
+def submit_slurm_job(
+ cfg: DictConfig,
+ python_args: str,
+ project_root: Path,
+):
+ log_dir = project_root / "slurm_logs" / f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}-{cfg.name}"
+ log_dir.mkdir(exist_ok=True, parents=True)
+ (project_root / "slurm_logs" / "latest").unlink(missing_ok=True)
+ (project_root / "slurm_logs" / "latest").symlink_to(log_dir, target_is_directory=True)
+
+ params = dict(name=cfg.name, log_dir=log_dir, project_root=project_root, python_args=python_args)
+ params.update(cfg.cluster.params)
+
+ slurm_script = cfg.cluster.launch_template.format(**params)
+
+ slurm_script_path = log_dir / "job.slurm"
+ with slurm_script_path.open("w") as f:
+ f.write(slurm_script)
+
+ os.system(f"chmod +x {slurm_script_path}")
+ os.system(f"sbatch {slurm_script_path}")
+
+ print(f"\n{cyan('script:')} {slurm_script_path}\n{cyan('slurm errors and logs:')} {log_dir}\n")
+
+ return log_dir
diff --git a/utils/distributed_utils.py b/utils/distributed_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dde3e98aef29b3fe6b7eb1c58f589b1d0d6c99ce
--- /dev/null
+++ b/utils/distributed_utils.py
@@ -0,0 +1,3 @@
+import wandb
+
+is_rank_zero = wandb.run is not None
diff --git a/utils/logging_utils.py b/utils/logging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5b8ac222ff86634440ccc8efa0c3a1f22660843
--- /dev/null
+++ b/utils/logging_utils.py
@@ -0,0 +1,435 @@
+from typing import Optional
+import wandb
+import numpy as np
+import torch
+
+import matplotlib.pyplot as plt
+import cv2
+import matplotlib.pyplot as plt
+from tqdm import trange, tqdm
+import matplotlib.animation as animation
+from pathlib import Path
+
+plt.set_loglevel("warning")
+
+from torchmetrics.functional import mean_squared_error, peak_signal_noise_ratio
+from torchmetrics.functional import (
+ structural_similarity_index_measure,
+ universal_image_quality_index,
+)
+from algorithms.common.metrics import (
+ FrechetVideoDistance,
+ LearnedPerceptualImagePatchSimilarity,
+ FrechetInceptionDistance,
+)
+
+
+# FIXME: clean up & check this util
+def log_video(
+ observation_hat,
+ observation_gt=None,
+ step=0,
+ namespace="train",
+ prefix="video",
+ context_frames=0,
+ color=(255, 0, 0),
+ logger=None,
+):
+ """
+ take in video tensors in range [-1, 1] and log into wandb
+
+ :param observation_hat: predicted observation tensor of shape (frame, batch, channel, height, width)
+ :param observation_gt: ground-truth observation tensor of shape (frame, batch, channel, height, width)
+ :param step: an int indicating the step number
+ :param namespace: a string specify a name space this video logging falls under, e.g. train, val
+ :param prefix: a string specify a prefix for the video name
+ :param context_frames: an int indicating how many frames in observation_hat are ground truth given as context
+ :param color: a tuple of 3 numbers specifying the color of the border for ground truth frames
+ :param logger: optional logger to use. use global wandb if not specified
+ """
+ if not logger:
+ logger = wandb
+
+ # observation_gt = torch.zeros_like(observation_hat)
+ # observation_hat[:context_frames] = observation_gt[:context_frames]
+ # Add red border of 1 pixel width to the context frames
+ # for i, c in enumerate(color):
+ # c = c / 255.0
+ # observation_hat[:context_frames, :, i, [0, -1], :] = c
+ # observation_hat[:context_frames, :, i, :, [0, -1]] = c
+
+ # if observation_gt is not None:
+ # observation_gt[:context_frames, :, i, [0, -1], :] = c
+ # observation_gt[:context_frames, :, i, :, [0, -1]] = c
+
+ if observation_gt is not None:
+ video = torch.cat([observation_hat, observation_gt], -2).detach().cpu().numpy()
+ else:
+ video = torch.cat([observation_hat], -1).detach().cpu().numpy()
+ video = np.transpose(np.clip(video, a_min=0.0, a_max=1.0) * 255, (1, 0, 2, 3, 4)).astype(np.uint8)
+ # video[..., 1:] = video[..., :1] # remove framestack, only visualize current frame
+ n_samples = len(video)
+ # use wandb directly here since pytorch lightning doesn't support logging videos yet
+ for i in range(n_samples):
+ logger.log(
+ {
+ f"{namespace}/{prefix}_{i}": wandb.Video(video[i], fps=5),
+ f"trainer/global_step": step,
+ }
+ )
+
+
+def get_validation_metrics_for_videos(
+ observation_hat,
+ observation_gt,
+ lpips_model: Optional[LearnedPerceptualImagePatchSimilarity] = None,
+ fid_model: Optional[FrechetInceptionDistance] = None,
+ fvd_model: Optional[FrechetVideoDistance] = None,
+):
+ """
+ :param observation_hat: predicted observation tensor of shape (frame, batch, channel, height, width)
+ :param observation_gt: ground-truth observation tensor of shape (frame, batch, channel, height, width)
+ :param lpips_model: a LearnedPerceptualImagePatchSimilarity object from algorithm.common.metrics
+ :param fid_model: a FrechetInceptionDistance object from algorithm.common.metrics
+ :param fvd_model: a FrechetVideoDistance object from algorithm.common.metrics
+ :return: a tuple of metrics
+ """
+ frame, batch, channel, height, width = observation_hat.shape
+ output_dict = {}
+ observation_gt = observation_gt.type_as(observation_hat) # some metrics don't fully support fp16
+
+ if frame < 9:
+ fvd_model = None # FVD requires at least 9 frames
+
+ observation_hat = observation_hat.float()
+ observation_gt = observation_gt.float()
+
+ # observation_hat = observation_hat.float().to(next(lpips_model.parameters()).device)
+ # observation_gt = observation_gt.float().to(next(lpips_model.parameters()).device)
+ # if fvd_model is not None:
+ # output_dict["fvd"] = fvd_model.compute(torch.clamp(observation_hat, -1.0, 1.0), torch.clamp(observation_gt, -1.0, 1.0))
+
+ frame_wise_psnr = []
+ for f in range(observation_hat.shape[0]):
+ frame_wise_psnr.append(peak_signal_noise_ratio(observation_hat[f], observation_gt[f], data_range=2.0))
+ frame_wise_psnr = torch.stack(frame_wise_psnr)
+
+ output_dict["frame_wise_psnr"] = frame_wise_psnr
+ observation_hat = observation_hat.view(-1, channel, height, width)
+ observation_gt = observation_gt.view(-1, channel, height, width)
+
+ output_dict["mse"] = mean_squared_error(observation_hat, observation_gt)
+
+ output_dict["psnr"] = peak_signal_noise_ratio(observation_hat, observation_gt, data_range=2.0)
+ # output_dict["ssim"] = structural_similarity_index_measure(observation_hat, observation_gt, data_range=2.0)
+ # output_dict["uiqi"] = universal_image_quality_index(observation_hat, observation_gt)
+ # operations for LPIPS and FID
+ observation_hat = torch.clamp(observation_hat, -1.0, 1.0)
+ observation_gt = torch.clamp(observation_gt, -1.0, 1.0)
+
+ if lpips_model is not None:
+ lpips_model.update(observation_hat, observation_gt)
+ lpips = lpips_model.compute().item()
+ # Reset the states of non-functional metrics
+ output_dict["lpips"] = lpips
+ lpips_model.reset()
+
+ if fid_model is not None:
+ observation_hat_uint8 = ((observation_hat + 1.0) / 2 * 255).type(torch.uint8)
+ observation_gt_uint8 = ((observation_gt + 1.0) / 2 * 255).type(torch.uint8)
+ fid_model.update(observation_gt_uint8, real=True)
+ fid_model.update(observation_hat_uint8, real=False)
+ fid = fid_model.compute()
+ output_dict["fid"] = fid
+ # Reset the states of non-functional metrics
+ fid_model.reset()
+
+ return output_dict
+
+
+def is_grid_env(env_id):
+ return "maze2d" in env_id or "diagonal2d" in env_id
+
+
+def get_maze_grid(env_id):
+ # import gym
+ # maze_string = gym.make(env_id).str_maze_spec
+ if "large" in env_id:
+ maze_string = "############\\#OOOO#OOOOO#\\#O##O#O#O#O#\\#OOOOOO#OOO#\\#O####O###O#\\#OO#O#OOOOO#\\##O#O#O#O###\\#OO#OOO#OGO#\\############"
+ if "medium" in env_id:
+ maze_string = "########\\#OO##OO#\\#OO#OOO#\\##OOO###\\#OO#OOO#\\#O#OO#O#\\#OOO#OG#\\########"
+ if "umaze" in env_id:
+ maze_string = "#####\\#GOO#\\###O#\\#OOO#\\#####"
+ lines = maze_string.split("\\")
+ grid = [line[1:-1] for line in lines]
+ return grid[1:-1]
+
+
+def get_random_start_goal(env_id, batch_size):
+ maze_grid = get_maze_grid(env_id)
+ s2i = {"O": 0, "#": 1, "G": 2}
+ maze_grid = [[s2i[s] for s in r] for r in maze_grid]
+ maze_grid = np.array(maze_grid)
+ x, y = np.nonzero(maze_grid == 0)
+ indices = np.random.randint(len(x), size=batch_size)
+ start = np.stack([x[indices], y[indices]], -1) + 1
+ x, y = np.nonzero(maze_grid == 2)
+ goal = np.concatenate([x, y], -1)
+ goal = np.tile(goal[None, :], (batch_size, 1)) + 1
+ return start, goal
+
+
+def plot_maze_layout(ax, maze_grid):
+ ax.clear()
+
+ if maze_grid is not None:
+ for i, row in enumerate(maze_grid):
+ for j, cell in enumerate(row):
+ if cell == "#":
+ square = plt.Rectangle((i + 0.5, j + 0.5), 1, 1, edgecolor="black", facecolor="black")
+ ax.add_patch(square)
+
+ ax.set_aspect("equal")
+ ax.grid(True, color="white", linewidth=4)
+ ax.set_axisbelow(True)
+ ax.spines["top"].set_linewidth(4)
+ ax.spines["right"].set_linewidth(4)
+ ax.spines["bottom"].set_linewidth(4)
+ ax.spines["left"].set_linewidth(4)
+ ax.set_facecolor("lightgray")
+ ax.tick_params(
+ axis="both",
+ which="both",
+ bottom=False,
+ top=False,
+ left=False,
+ right=False,
+ labelbottom=False,
+ labelleft=False,
+ )
+ ax.set_xticks(np.arange(0.5, len(maze_grid) + 0.5))
+ ax.set_yticks(np.arange(0.5, len(maze_grid[0]) + 0.5))
+ ax.set_xlim(0.5, len(maze_grid) + 0.5)
+ ax.set_ylim(0.5, len(maze_grid[0]) + 0.5)
+ ax.grid(True, color="white", which="minor", linewidth=4)
+
+
+def plot_start_goal(ax, start_goal: None):
+ def draw_star(center, radius, num_points=5, color="black"):
+ angles = np.linspace(0.0, 2 * np.pi, num_points, endpoint=False) + 5 * np.pi / (2 * num_points)
+ inner_radius = radius / 2.0
+
+ points = []
+ for angle in angles:
+ points.extend(
+ [
+ center[0] + radius * np.cos(angle),
+ center[1] + radius * np.sin(angle),
+ center[0] + inner_radius * np.cos(angle + np.pi / num_points),
+ center[1] + inner_radius * np.sin(angle + np.pi / num_points),
+ ]
+ )
+
+ star = plt.Polygon(np.array(points).reshape(-1, 2), color=color)
+ ax.add_patch(star)
+
+ start_x, start_y = start_goal[0]
+ start_outer_circle = plt.Circle((start_x, start_y), 0.16, facecolor="white", edgecolor="black")
+ ax.add_patch(start_outer_circle)
+ start_inner_circle = plt.Circle((start_x, start_y), 0.08, color="black")
+ ax.add_patch(start_inner_circle)
+
+ goal_x, goal_y = start_goal[1]
+ goal_outer_circle = plt.Circle((goal_x, goal_y), 0.16, facecolor="white", edgecolor="black")
+ ax.add_patch(goal_outer_circle)
+ draw_star((goal_x, goal_y), radius=0.08)
+
+
+def make_trajectory_images(env_id, trajectory, batch_size, start, goal, plot_end_points=True):
+ images = []
+ for batch_idx in range(batch_size):
+ fig, ax = plt.subplots()
+ if is_grid_env(env_id):
+ maze_grid = get_maze_grid(env_id)
+ else:
+ maze_grid = None
+ plot_maze_layout(ax, maze_grid)
+ ax.scatter(trajectory[:, batch_idx, 0], trajectory[:, batch_idx, 1], c=np.arange(len(trajectory)), cmap="Reds"),
+ if plot_end_points:
+ start_goal = (start[batch_idx], goal[batch_idx])
+ plot_start_goal(ax, start_goal)
+ # plt.title(f"sample_{batch_idx}")
+ fig.tight_layout()
+ fig.canvas.draw()
+ img_shape = fig.canvas.get_width_height()[::-1] + (4,)
+ img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).copy().reshape(img_shape)
+ images.append(img)
+
+ plt.close()
+ return images
+
+
+def make_convergence_animation(
+ env_id,
+ plan_history,
+ trajectory,
+ start,
+ goal,
+ open_loop_horizon,
+ namespace,
+ interval=100,
+ plot_end_points=True,
+ batch_idx=0,
+):
+ # - plan_history: contains for each time step all the MPC predicted plans for each pyramid noise level.
+ # Structured as a list of length (episode_len // open_loop_horizon), where each
+ # element corresponds to a control_time_step and stores a list of length pyramid_height,
+ # where each element is a plan at a different pyramid noise level and stored as a tensor of
+ # shape (episode_len // open_loop_horizon - control_time_step,
+ # batch_size, x_stacked_shape)
+
+ # select index and prune history
+ start, goal = start[batch_idx], goal[batch_idx]
+ trajectory = trajectory[:, batch_idx]
+ plan_history = [[pm[:, batch_idx] for pm in pt] for pt in plan_history]
+ trajectory, plan_history = prune_history(plan_history, trajectory, goal, open_loop_horizon)
+
+ # animate the convergence of the first plan
+ fig, ax = plt.subplots()
+ if "large" in env_id:
+ fig.set_size_inches(3.5, 5)
+ else:
+ fig.set_size_inches(3, 3)
+ ax.set_axis_off()
+ fig.subplots_adjust(left=0, bottom=0, right=1, top=1)
+
+ if is_grid_env(env_id):
+ maze_grid = get_maze_grid(env_id)
+ else:
+ maze_grid = None
+
+ def update(frame):
+ plot_maze_layout(ax, maze_grid)
+
+ plan_history_m = plan_history[0][frame]
+ plan_history_m = plan_history_m.numpy()
+ ax.scatter(
+ plan_history_m[:, 0],
+ plan_history_m[:, 1],
+ c=np.arange(len(plan_history_m))[::-1],
+ cmap="Reds",
+ )
+
+ if plot_end_points:
+ plot_start_goal(ax, (start, goal))
+
+ frames = tqdm(range(len(plan_history[0])), desc="Making convergence animation")
+ ani = animation.FuncAnimation(fig, update, frames=frames, interval=interval)
+ prefix = wandb.run.id if wandb.run is not None else env_id
+ filename = f"/tmp/{prefix}_{namespace}_convergence.mp4"
+ ani.save(filename, writer="ffmpeg", fps=5)
+ return filename
+
+
+def prune_history(plan_history, trajectory, goal, open_loop_horizon):
+ dist = np.linalg.norm(
+ trajectory[:, :2] - np.array(goal)[None],
+ axis=-1,
+ )
+ reached = dist < 0.2
+ if reached.any():
+ cap_idx = np.argmax(reached)
+ trajectory = trajectory[: cap_idx + open_loop_horizon + 1]
+ plan_history = plan_history[: cap_idx // open_loop_horizon + 2]
+
+ pruned_plan_history = []
+ for plans in plan_history:
+ pruned_plan_history.append([])
+ for m in range(len(plans)):
+ plan = plans[m]
+ pruned_plan_history[-1].append(plan)
+ plan = pruned_plan_history[-1][-1]
+ dist = np.linalg.norm(plan.numpy()[:, :2] - np.array(goal)[None], axis=-1)
+ reached = dist < 0.2
+ if reached.any():
+ cap_idx = np.argmax(reached) + 1
+ pruned_plan_history[-1] = [p[:cap_idx] for p in pruned_plan_history[-1]]
+ return trajectory, pruned_plan_history
+
+
+def make_mpc_animation(
+ env_id,
+ plan_history,
+ trajectory,
+ start,
+ goal,
+ open_loop_horizon,
+ namespace,
+ interval=100,
+ plot_end_points=True,
+ batch_idx=0,
+):
+ # - plan_history: contains for each time step all the MPC predicted plans for each pyramid noise level.
+ # Structured as a list of length (episode_len // open_loop_horizon), where each
+ # element corresponds to a control_time_step and stores a list of length pyramid_height,
+ # where each element is a plan at a different pyramid noise level and stored as a tensor of
+ # shape (episode_len // open_loop_horizon - control_time_step,
+ # batch_size, x_stacked_shape)
+
+ # select index and prune history
+ start, goal = start[batch_idx], goal[batch_idx]
+ trajectory = trajectory[:, batch_idx]
+ plan_history = [[pm[:, batch_idx] for pm in pt] for pt in plan_history]
+ trajectory, plan_history = prune_history(plan_history, trajectory, goal, open_loop_horizon)
+
+ # animate the convergence of the plans
+ fig, ax = plt.subplots()
+ if "large" in env_id:
+ fig.set_size_inches(3.5, 5)
+ else:
+ fig.set_size_inches(3, 3)
+ ax.set_axis_off()
+ fig.subplots_adjust(left=0, bottom=0, right=1, top=1)
+ trajectory_colors = np.linspace(0, 1, len(trajectory))
+
+ if is_grid_env(env_id):
+ maze_grid = get_maze_grid(env_id)
+ else:
+ maze_grid = None
+
+ def update(frame):
+ control_time_step = 0
+ while frame >= 0:
+ frame -= len(plan_history[control_time_step])
+ control_time_step += 1
+ control_time_step -= 1
+ m = frame + len(plan_history[control_time_step])
+ num_steps_taken = 1 + open_loop_horizon * control_time_step
+ plot_maze_layout(ax, maze_grid)
+
+ plan_history_m = plan_history[control_time_step][m]
+ plan_history_m = plan_history_m.numpy()
+ ax.scatter(
+ trajectory[:num_steps_taken, 0],
+ trajectory[:num_steps_taken, 1],
+ c=trajectory_colors[:num_steps_taken],
+ cmap="Blues",
+ )
+ ax.scatter(
+ plan_history_m[:, 0],
+ plan_history_m[:, 1],
+ c=np.arange(len(plan_history_m))[::-1],
+ cmap="Reds",
+ )
+
+ if plot_end_points:
+ plot_start_goal(ax, (start, goal))
+
+ num_frames = sum([len(p) for p in plan_history])
+ frames = tqdm(range(num_frames), desc="Making MPC animation")
+ ani = animation.FuncAnimation(fig, update, frames=frames, interval=interval)
+ prefix = wandb.run.id if wandb.run is not None else env_id
+ filename = f"/tmp/{prefix}_{namespace}_mpc.mp4"
+ ani.save(filename, writer="ffmpeg", fps=5)
+
+ return filename
diff --git a/utils/print_utils.py b/utils/print_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c9052267f0390c1e0068be6f3bac453d3c4d23
--- /dev/null
+++ b/utils/print_utils.py
@@ -0,0 +1,5 @@
+from colorama import Fore
+
+
+def cyan(x: str) -> str:
+ return f"{Fore.CYAN}{x}{Fore.RESET}"
diff --git a/utils/wandb_utils.py b/utils/wandb_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c4df932c82e2b4f2dd3510e39faa7ecfee19279
--- /dev/null
+++ b/utils/wandb_utils.py
@@ -0,0 +1,175 @@
+from pathlib import Path
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Union
+from typing_extensions import override
+from functools import wraps
+import os
+from wandb_osh.hooks import TriggerWandbSyncHook
+import time
+from lightning.pytorch.loggers.wandb import WandbLogger, _scan_checkpoints, ModelCheckpoint, Tensor
+from lightning.pytorch.utilities.rank_zero import rank_zero_only
+from lightning.fabric.utilities.types import _PATH
+
+
+if TYPE_CHECKING:
+ from wandb.sdk.lib import RunDisabled
+ from wandb.wandb_run import Run
+
+
+class SpaceEfficientWandbLogger(WandbLogger):
+ """
+ A wandb logger that by default overrides artifacts to save space, instead of creating new version.
+ A variable expiration_days can be set to control how long older versions of artifacts are kept.
+ By default, the latest version is kept indefinitely, while older versions are kept for 5 days.
+ """
+
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ save_dir: _PATH = ".",
+ version: Optional[str] = None,
+ offline: bool = False,
+ dir: Optional[_PATH] = None,
+ id: Optional[str] = None,
+ anonymous: Optional[bool] = None,
+ project: Optional[str] = None,
+ log_model: Union[Literal["all"], bool] = False,
+ experiment: Union["Run", "RunDisabled", None] = None,
+ prefix: str = "",
+ checkpoint_name: Optional[str] = None,
+ expiration_days: Optional[int] = 5,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ name=name,
+ save_dir=save_dir,
+ version=version,
+ offline=False,
+ dir=dir,
+ id=id,
+ anonymous=anonymous,
+ project=project,
+ log_model=log_model,
+ experiment=experiment,
+ prefix=prefix,
+ checkpoint_name=checkpoint_name,
+ **kwargs,
+ )
+
+ super().__init__(
+ name=name,
+ save_dir=save_dir,
+ version=version,
+ offline=offline,
+ dir=dir,
+ id=id,
+ anonymous=anonymous,
+ project=project,
+ log_model=log_model,
+ experiment=experiment,
+ prefix=prefix,
+ checkpoint_name=checkpoint_name,
+ **kwargs,
+ )
+ self.expiration_days = expiration_days
+ self._last_artifacts = []
+
+ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
+ import wandb
+
+ # get checkpoints to be saved with associated score
+ checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)
+
+ # log iteratively all new checkpoints
+ artifacts = []
+ for t, p, s, tag in checkpoints:
+ metadata = {
+ "score": s.item() if isinstance(s, Tensor) else s,
+ "original_filename": Path(p).name,
+ checkpoint_callback.__class__.__name__: {
+ k: getattr(checkpoint_callback, k)
+ for k in [
+ "monitor",
+ "mode",
+ "save_last",
+ "save_top_k",
+ "save_weights_only",
+ "_every_n_train_steps",
+ ]
+ # ensure it does not break if `ModelCheckpoint` args change
+ if hasattr(checkpoint_callback, k)
+ },
+ }
+ if not self._checkpoint_name:
+ self._checkpoint_name = f"model-{self.experiment.id}"
+
+ artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
+ artifact.add_file(p, name="model.ckpt")
+ aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
+ self.experiment.log_artifact(artifact, aliases=aliases)
+ # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
+ self._logged_model_time[p] = t
+ artifacts.append(artifact)
+
+ for artifact in self._last_artifacts:
+ if not self._offline:
+ artifact.wait()
+ artifact.ttl = timedelta(days=self.expiration_days)
+ artifact.save()
+ self._last_artifacts = artifacts
+
+
+class OfflineWandbLogger(SpaceEfficientWandbLogger):
+ """
+ Wraps WandbLogger to trigger offline sync hook occasionally.
+ This is useful when running on slurm clusters, many of which
+ only has internet on login nodes, not compute nodes.
+ """
+
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ save_dir: _PATH = ".",
+ version: Optional[str] = None,
+ offline: bool = False,
+ dir: Optional[_PATH] = None,
+ id: Optional[str] = None,
+ anonymous: Optional[bool] = None,
+ project: Optional[str] = None,
+ log_model: Union[Literal["all"], bool] = False,
+ experiment: Union["Run", "RunDisabled", None] = None,
+ prefix: str = "",
+ checkpoint_name: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ name=name,
+ save_dir=save_dir,
+ version=version,
+ offline=False,
+ dir=dir,
+ id=id,
+ anonymous=anonymous,
+ project=project,
+ log_model=log_model,
+ experiment=experiment,
+ prefix=prefix,
+ checkpoint_name=checkpoint_name,
+ **kwargs,
+ )
+ self._offline = offline
+ communication_dir = Path(".wandb_osh_command_dir")
+ communication_dir.mkdir(parents=True, exist_ok=True)
+ self.trigger_sync = TriggerWandbSyncHook(communication_dir)
+ self.last_sync_time = 0.0
+ self.min_sync_interval = 60
+ self.wandb_dir = os.path.join(self._save_dir, "wandb/latest-run")
+
+ @override
+ @rank_zero_only
+ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
+ out = super().log_metrics(metrics, step)
+ if time.time() - self.last_sync_time > self.min_sync_interval:
+ self.trigger_sync(self.wandb_dir)
+ self.last_sync_time = time.time()
+ return out
diff --git a/wandb/debug-cli.zeqi001.log b/wandb/debug-cli.zeqi001.log
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/wandb/settings b/wandb/settings
new file mode 100644
index 0000000000000000000000000000000000000000..a741f65f08f6c0aad011897f7a19126787a186d5
--- /dev/null
+++ b/wandb/settings
@@ -0,0 +1,3 @@
+[default]
+mode = disabled
+