File size: 270 Bytes
2568013
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import torch
from jaxtyping import Int
from torch import Tensor


def add_third_context_index(
    indices: Int[Tensor, "*batch 2"]
) -> Int[Tensor, "*batch 3"]:
    left, right = indices.unbind(dim=-1)
    return torch.stack((left, (left + right) // 2, right), dim=-1)