File size: 270 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 |
import torch
from jaxtyping import Int
from torch import Tensor
def add_third_context_index(
    indices: Int[Tensor, "*batch 2"],
) -> Int[Tensor, "*batch 3"]:
    """Insert a midpoint index between each (left, right) index pair.

    Args:
        indices: Integer tensor whose last dimension has size 2 and holds a
            ``(left, right)`` pair; any leading batch dimensions are kept.

    Returns:
        Integer tensor with the same leading dimensions and a last dimension
        of size 3 holding ``(left, (left + right) // 2, right)``.
    """
    # Unpacking into exactly two names rejects a last dimension that isn't 2.
    start, end = torch.unbind(indices, dim=-1)
    # Floor division of the summed pair gives the middle index.
    midpoint = (start + end) // 2
    return torch.stack([start, midpoint, end], dim=-1)
|