from typing import Callable, Dict, Union

from torch import Tensor

# A tree is either a tensor (leaf) or a string-keyed dict of sub-trees.
Tree = Union[Dict[str, "Tree"], Tensor]


def collate(trees: list[Tree], merge_fn: Callable[[list[Tensor]], Tensor]) -> Tree:
    """Merge a list of nested tensor dictionaries into a single tree.

    The structure is dictated by the first tree: if it is a ``Tensor`` the
    whole list is handed to ``merge_fn``; otherwise every key of the first
    tree is collated recursively across all trees.

    Args:
        trees: Non-empty list of trees to merge. Keys are read from
            ``trees[0]`` only, so keys that appear solely in later trees
            are ignored.
        merge_fn: Callable that combines a list of leaves into one result;
            it is called once per leaf position with the leaves in the same
            order as ``trees``.

    Returns:
        The result of ``merge_fn`` when the first tree is a tensor, otherwise
        a dict with the keys of the first tree and collated values.

    Raises:
        IndexError: If ``trees`` is empty.
        KeyError: If a later dict lacks a key that the first tree has.
    """
    if not trees:
        # Fail with an explicit message instead of the opaque
        # "list index out of range" that ``trees[0]`` would produce. The
        # exception type stays IndexError so existing handlers still match.
        raise IndexError("collate() requires at least one tree")

    first = trees[0]
    if isinstance(first, Tensor):
        return merge_fn(trees)

    return {
        key: collate([tree[key] for tree in trees], merge_fn)
        for key in first
    }