File size: 1,371 Bytes
2568013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from einops import repeat
from jaxtyping import Int
from torch import Tensor

Index = Int[Tensor, "n n-1"]


def generate_heterogeneous_index(
    n: int,
    device: torch.device = torch.device("cpu"),
) -> tuple[Index, Index]:
    """Build paired indices that cover every ordered pair (i, k) with i != k.

    Both returned tensors have shape (n, n - 1). Every entry in row i of the
    first tensor equals i. Row i of the second tensor lists every k in
    range(n) except i, in ascending order.

    Args:
        n: Number of items.
        device: Device on which the index tensors are created.

    Returns:
        A tuple ``(index_self, index_other)`` of integer index tensors.
    """
    positions = torch.arange(n, device=device)
    index_self = repeat(positions, "h -> h w", w=n - 1)

    # Column j of row i points at item j, shifted by one from the diagonal
    # onward (j >= i) so that the item itself is skipped.
    skip_diagonal = torch.ones((n, n), device=device, dtype=torch.int64).triu()
    index_other = repeat(positions, "w -> h w", h=n) + skip_diagonal

    # The final column only ever holds the overflow value n, so drop it.
    return index_self, index_other[:, :-1]


def generate_heterogeneous_index_transpose(
    n: int,
    device: torch.device = torch.device("cpu"),
) -> tuple[Index, Index]:
    """Generate an index that can be used to "transpose" the heterogeneous index.

    In the (n, n - 1) layout produced by ``generate_heterogeneous_index``,
    position (i, j) holds the pair (i, k) with k = j + (j >= i). For every
    position (i, j), the two tensors returned here give the row and column at
    which the reversed pair (k, i) is stored in that same layout. Indexing a
    tensor ``x`` in that layout as ``x[index_self, index_other]`` therefore
    yields, at (i, j), the entry for (k, i). Applying the index a second time
    inverts the "transpose."

    Args:
        n: Number of items.
        device: Device on which the index tensors are created.

    Returns:
        A tuple ``(index_self, index_other)`` of int64 tensors of shape
        (n, n - 1): the row and column indices of the reversed pairs.
    """
    arange = torch.arange(n, device=device)

    # Upper-triangular mask, diagonal included: 1 where column >= row.
    # Computed once and reused for both indices.
    upper = torch.ones((n, n), device=device, dtype=torch.int64).triu()

    # Row of the reversed pair is the "other" item k = j + (j >= i).
    # The out-of-place addition already allocates a fresh tensor, so the
    # broadcast result from `repeat` does not need to be cloned first.
    index_self = repeat(arange, "w -> h w", h=n) + upper

    # Column of the reversed pair is i - 1 when j < i (then k < i, and row k
    # skips its own diagonal slot before reaching i); otherwise it is i.
    index_other = repeat(arange, "h -> h w", w=n) - (1 - upper)

    # Drop the last column to obtain the (n, n - 1) heterogeneous layout.
    return index_self[:, :-1], index_other[:, :-1]