File size: 4,903 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from typing import Tuple

try:
    import faiss
    from faiss import Index
except ImportError as e:
    raise ImportError(
        "You need to install FAISS library to perform ANN/KNN. Please check the official doc: "
        "https://github.com/facebookresearch/faiss/blob/main/INSTALL.md"
    )


def determine_devices(max_gpu_devices: int = 0) -> Tuple[int, bool]:
    """
    Pick the number of devices to run on and whether they are GPUs
    Args:
        max_gpu_devices: upper bound on the number of GPUs to use.
            A positive value caps the amount of GPUs, -1 takes every GPU
            reported by FAISS, any other value (0 by default) selects CPU.

    Returns: a pair `(num_devices, is_gpu)`. When GPUs are not requested or
        FAISS reports none of them, the pair is `(1, False)`.
    """
    gpus_available = faiss.get_num_gpus()

    if gpus_available > 0:
        if max_gpu_devices > 0:
            # never ask for more GPUs than FAISS can see
            return min(gpus_available, max_gpu_devices), True
        if max_gpu_devices == -1:
            return gpus_available, True

    # CPU fallback: a single device, CUDA is not allowed
    return 1, False


def _get_brute_index(emb_dim: int, dist_type: str) -> Index:
    """
    Build a flat FAISS index for the requested distance
    Args:
        emb_dim: size of each embedding
        dist_type: `IP` (inner product) or `L2`. Case insensitive.

    Returns: `faiss.IndexFlatIP` or `faiss.IndexFlatL2` of dimension `emb_dim`

    Raises:
        ValueError: if `dist_type` is neither `IP` nor `L2`
    """
    metric_name = dist_type.lower()

    if metric_name == 'ip':
        return faiss.IndexFlatIP(emb_dim)
    if metric_name == 'l2':
        return faiss.IndexFlatL2(emb_dim)

    raise ValueError(f'Wrong distance type for FAISS Flat Index: {dist_type}')


def _get_ivf_index(
    emb_dim: int,
    n_objects: int,
    in_list_dist_type: str,
    centroid_dist_type: str,
    encode_residuals: bool
) -> Index:
    """
    Build an untrained IVF index that stores vectors as fp16 scalar-quantized codes
    Args:
        emb_dim: size of each embedding
        n_objects: expected size of the trainset. The number of IVF lists is
            derived from it as `int(4 * sqrt(n_objects))`.
        in_list_dist_type: distance between a query and the vectors stored inside
            the probed lists, `IP` or `L2` (case insensitive). It becomes the
            metric of the IVF index itself.
        centroid_dist_type: distance between a query and cluster centroids,
            `IP` or `L2` (case insensitive). It selects the flat coarse quantizer.
        encode_residuals: forwarded to `faiss.IndexIVFScalarQuantizer` as its last
            positional argument (whether residuals are encoded).

    Returns: untrained `faiss.IndexIVFScalarQuantizer`

    Raises:
        ValueError: if any of the distance types is neither `IP` nor `L2`
    """
    # according to the FAISS doc, this should be OK
    n_list = int(4 * (n_objects ** 0.5))

    # The metric of the IVF index is the one applied while scanning the vectors
    # of the probed lists, so it has to come from `in_list_dist_type`.
    in_list_name = in_list_dist_type.lower()
    if in_list_name == 'ip':
        in_list_metric = faiss.METRIC_INNER_PRODUCT
    elif in_list_name == 'l2':
        in_list_metric = faiss.METRIC_L2
    else:
        raise ValueError(f'Wrong distance type for FAISS index: {in_list_dist_type}')

    # The coarse quantizer is the index that ranks centroids for a query,
    # so its metric has to come from `centroid_dist_type`.
    centroid_name = centroid_dist_type.lower()
    if centroid_name == 'ip':
        quantizer = faiss.IndexFlatIP(emb_dim)
    elif centroid_name == 'l2':
        quantizer = faiss.IndexFlatL2(emb_dim)
    else:
        raise ValueError(f'Wrong distance type for FAISS quantizer: {centroid_dist_type}')

    index = faiss.IndexIVFScalarQuantizer(
        quantizer,
        emb_dim,
        n_list,
        faiss.ScalarQuantizer.QT_fp16,  # TODO: should be optional?
        in_list_metric,
        encode_residuals
    )
    return index


def create_faiss_index(
    emb_dim: int,
    n_objects: int,
    n_probe: int = 10,
    max_gpu_devices: int = 0,
    encode_residuals: bool = True,
    in_list_dist_type: str = 'L2',
    centroid_dist_type: str = 'L2'
) -> Index:
    """
    Create a FAISS index (flat or IVF, with IP or L2 distance) without adding data or training
    Args:
        emb_dim: size of each embedding
        n_objects: size of a trainset for the index. Selects the index type:
            a bruteforce (flat) index if `n_objects` is less than 20_000,
            an IVF index otherwise.
        n_probe: number of closest IVF-clusters to check for neighbours.
            It is assigned to `index.nprobe` for both index types.
        max_gpu_devices: maximum amount of GPUs to use for the index,
            passed to `determine_devices`. 0 to run on CPU.
        encode_residuals: forwarded to the IVF index builder, whether or not
            to encode residuals. Not used by the bruteforce index.
        in_list_dist_type: `IP` (for inner product) or `L2`, case insensitive.
            Intended as the distance used to compare vectors within one IVF list.
            If the index type is bruteforce (`n_objects` < 20_000), it defines
            the distance type of that bruteforce index.
        centroid_dist_type: `IP` (for inner product) or `L2`, case insensitive.
            Intended as the distance between a query and cluster centroids.
            Ignored for the bruteforce index.
    Returns: untrained FAISS-index, sharded between GPUs when they are allowed
        and available
    """
    # if less than 20_000 / (4 * sqrt(20_000)) ~= 35 points per cluster - make bruteforce
    # https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index#if-below-1m-vectors-ivfk
    too_small_for_ivf = n_objects < 20_000

    if too_small_for_ivf:
        cpu_index = _get_brute_index(emb_dim=emb_dim, dist_type=in_list_dist_type)
    else:
        cpu_index = _get_ivf_index(
            emb_dim=emb_dim,
            n_objects=n_objects,
            in_list_dist_type=in_list_dist_type,
            centroid_dist_type=centroid_dist_type,
            encode_residuals=encode_residuals
        )

    cpu_index.nprobe = n_probe

    gpu_count, gpu_allowed = determine_devices(max_gpu_devices)
    if not gpu_allowed:
        return cpu_index

    sharding = faiss.GpuMultipleClonerOptions()
    sharding.shard = True  # split (not replicate) one index between GPUs
    gpu_ids = list(range(gpu_count))
    return faiss.index_cpu_to_gpus_list(cpu_index, sharding, gpu_ids)