AnySplat / src /misc /collation.py
alexnasa's picture
Upload 243 files
2568013 verified
raw
history blame contribute delete
433 Bytes
from typing import Callable, Dict, Union
from torch import Tensor
# A Tree is either a single Tensor (a leaf) or a dict that maps string keys to
# further Trees (a branch); the alias is recursive, so nesting depth is arbitrary.
Tree = Union[Dict[str, "Tree"], Tensor]
def collate(trees: list[Tree], merge_fn: Callable[[list[Tensor]], Tensor]) -> Tree:
    """Combine a list of nested tensor dictionaries into a single tree.

    The structure of the first tree decides how the list is processed: if it
    is a tensor, the entire list is passed to ``merge_fn``; otherwise each of
    its keys is looked up in every tree and the resulting lists are collated
    recursively.

    Args:
        trees: Non-empty list of trees to combine. Every tree must provide the
            keys found in the first tree; keys that exist only in later trees
            are not visited.
        merge_fn: Callable that receives the list of tensors found at one leaf
            position and returns the merged tensor.

    Returns:
        The value returned by ``merge_fn`` when the first tree is a tensor,
        otherwise a dict with the first tree's keys (in the same order) whose
        values are the collated sub-trees.

    Raises:
        IndexError: If ``trees`` is empty.
        KeyError: If a later dict tree lacks a key that the first tree has.
    """
    reference = trees[0]

    # Leaf level: the caller-supplied function merges the tensors directly.
    if isinstance(reference, Tensor):
        return merge_fn(trees)

    # Branch level: gather the same-named child from every tree and recurse.
    merged = {}
    for name in reference:
        children = [node[name] for node in trees]
        merged[name] = collate(children, merge_fn)
    return merged