import torch from jaxtyping import Int from torch import Tensor def add_third_context_index( indices: Int[Tensor, "*batch 2"] ) -> Int[Tensor, "*batch 3"]: left, right = indices.unbind(dim=-1) return torch.stack((left, (left + right) // 2, right), dim=-1)