File size: 433 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
from collections.abc import Mapping
from typing import Callable, Dict, Union

from torch import Tensor
# A tree is either a leaf Tensor or a string-keyed dict of nested trees.
Tree = Union[Dict[str, "Tree"], Tensor]


def collate(trees: list[Tree], merge_fn: Callable[[list[Tensor]], Tensor]) -> Tree:
    """Merge a list of identically-structured nested dicts of tensors.

    The trees are walked in lockstep. Where the trees hold mappings, the
    result is a dict with the same keys (in the key order of the first tree)
    whose values are the recursively collated sub-trees. Where the trees hold
    leaves (anything that is not a mapping, normally a ``Tensor``), the list
    of corresponding leaves is passed to ``merge_fn``, and its return value
    becomes the merged leaf.

    Args:
        trees: Non-empty list of trees to merge. Every tree must have the
            same keys at every mapping node.
        merge_fn: Called once per leaf position with the list of leaves at
            that position, in the order of ``trees`` (e.g. ``torch.stack``).

    Returns:
        The merged tree.

    Raises:
        ValueError: If ``trees`` is empty.
        TypeError: If a tree holds a non-mapping where the first tree holds
            a mapping.
        KeyError: If the trees do not share the same keys at a mapping node.
    """
    if not trees:
        raise ValueError("collate() requires at least one tree")

    first = trees[0]
    # Recurse only into mappings and treat everything else as a leaf. Testing
    # for Mapping (rather than for Tensor) stops non-tensor leaves from being
    # iterated as if they were dicts.
    if not isinstance(first, Mapping):
        return merge_fn(trees)

    keys = first.keys()
    for tree in trees[1:]:
        if not isinstance(tree, Mapping):
            raise TypeError(
                f"cannot collate a mapping with a {type(tree).__name__}"
            )
        # Without this check, keys missing from the first tree would be
        # silently dropped from the result.
        if tree.keys() != keys:
            raise KeyError(
                f"trees have mismatched keys: {list(keys)} vs {list(tree.keys())}"
            )
    return {key: collate([tree[key] for tree in trees], merge_fn) for key in keys}
|