File size: 8,827 Bytes
b92a520 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 |
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import numpy as np
import os
import psutil
import time
import torch
from tqdm import tqdm
from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import BertEmbedder
from tools.bert_embedding.utils import get_missing_blocks_by_rank
from tools.retro.db.utils import \
get_merged_train_dataset as get_db_merged_train_dataset
from tools.retro.external_libs import faiss, h5py
from tools.retro.index.factory import IndexFactory
from tools.retro.index.utils import get_index_dir
from tools.retro.utils import GPTToTextDataset
from .chunk_dataset import get_chunk_dataset_map as get_query_dataset_map
def get_index(ondisk=False):
    '''Load the populated ("added") index from disk and set search params.

    Args:
        ondisk: When true, the index file is read with
            ``faiss.IO_FLAG_MMAP``; otherwise it is read without IO flags.

    Returns:
        The index produced by ``faiss.read_index``, after its "efSearch"
        and "nprobe" parameters have been set from the retro args.
    '''

    retro_args = get_retro_args()

    # Resolve the path of the index file.
    wrapper = IndexFactory.get_index(retro_args.retro_index_type)
    # The returned directory is not used in this function; the call is
    # retained in case it has side effects -- TODO confirm.
    get_index_dir()
    index_path = wrapper.get_added_index_path()

    # Read the index, passing the mmap flag only when requested.
    io_flags = (faiss.IO_FLAG_MMAP,) if ondisk else ()
    index = faiss.read_index(index_path, *io_flags)

    # Search parameters: "efSearch" first, then "nprobe".
    for faiss_name, arg_name in (
        ("efSearch", "retro_query_ef_search"),
        ("nprobe", "retro_query_nprobe"),
    ):
        faiss.ParameterSpace().set_index_parameter(
            index, faiss_name, getattr(retro_args, arg_name))

    return index
def embed_block(gpt_dataset, block, embedder):
    '''Embed the chunks that fall inside one block.

    Args:
        gpt_dataset: Dataset wrapped in ``GPTToTextDataset`` before embedding.
        block: Mapping whose "range" entry is unpacked into ``range(...)``
            to select the chunk indices of this block.
        embedder: Object exposing ``embed_text_dataset(dataset)``.

    Returns:
        Whatever ``embedder.embed_text_dataset`` returns for the block subset.
    '''

    # Text view of the dataset, restricted to this block's indices.
    text_dataset = GPTToTextDataset(gpt_dataset)
    block_indices = range(*block["range"])
    block_subset = torch.utils.data.Subset(text_dataset, block_indices)

    return embedder.embed_text_dataset(block_subset)
def query_embeddings(db_dataset, index,
                     embeddings, chunk_id_range,
                     sample_map, n_chunks_per_sample,
                     verbose=True):
    '''Query neighbors of a block of embeddings.

    The index is searched for ``retro_query_num_neighbors_query`` neighbors
    per embedding. Each row is then filtered: ids < 0 are dropped, as is any
    neighbor whose ``db_dataset.doc_tuples`` entry equals one of the
    ``(dataset_idx, doc_id)`` pairs of the sample that owns the query chunk.
    The first ``retro_query_num_neighbors_save`` survivors are kept, and the
    remainder of the row is padded with -1.

    Args:
        db_dataset: Object whose ``doc_tuples[i]`` is read for neighbor id i.
        index: Object exposing ``ntotal`` and ``search(embeddings, k)``.
        embeddings: Query embeddings; row r belongs to chunk id
            ``chunk_id_range[0] + r``.
        chunk_id_range: ``(min_chunk_id, max_chunk_id)``; the upper bound is
            exclusive.
        sample_map: Mapping from sample id to a dict with "dataset_idx"
            (read via ``.item()``) and "doc_ids" (read via ``.tolist()``).
        n_chunks_per_sample: Divisor that maps a chunk id to its sample id.
        verbose: Print progress messages via ``print_rank_0``.

    Returns:
        Tuple ``(query_neighbor_ids, filtered_neighbor_ids)``: the raw ids
        returned by the search, and an int64 array of shape
        ``(len(query_neighbor_ids), retro_query_num_neighbors_save)``.

    Raises:
        AssertionError: If ``index.ntotal`` is not positive.
    '''

    args = get_retro_args()
    n_save = args.retro_query_num_neighbors_save

    # Query neighbor ids.
    if verbose:
        print_rank_0("search.")
    t = time.time()
    assert index.ntotal > 0, "check we don't accidentally have an empty index."
    _, query_neighbor_ids = \
        index.search(embeddings, args.retro_query_num_neighbors_query)
    if verbose:
        print_rank_0(" time : %.3f sec." % (time.time() - t))

    # Filter banned neighbor ids.
    if verbose:
        print_rank_0("filter banned neighbor ids.")
    # Pre-filled with -1 so that rows with too few survivors stay padded.
    filtered_neighbor_ids = np.full(
        shape=(len(query_neighbor_ids), n_save),
        fill_value=-1,
        dtype="int64",
    )

    # Chunks of the same sample share one banned set; build it once per
    # sample and use a set so each membership test is O(1).
    banned_by_sample = {}

    min_chunk_id, max_chunk_id = chunk_id_range
    for chunk_id in range(min_chunk_id, max_chunk_id):
        row_idx = chunk_id - min_chunk_id
        sample_id = chunk_id // n_chunks_per_sample

        banned_doc_tuples = banned_by_sample.get(sample_id)
        if banned_doc_tuples is None:
            sample = sample_map[sample_id]
            sample_dataset_idx = sample["dataset_idx"].item()
            banned_doc_tuples = {
                (sample_dataset_idx, doc_id)
                for doc_id in sample["doc_ids"].tolist()
            }
            banned_by_sample[sample_id] = banned_doc_tuples

        # Keep valid (>= 0), non-banned neighbors in search order; stop once
        # enough have been collected, since later ones would be truncated.
        kept = []
        for neighbor_id in query_neighbor_ids[row_idx]:
            if len(kept) >= n_save:
                break
            if neighbor_id < 0:
                continue
            doc_tuple = tuple(db_dataset.doc_tuples[neighbor_id].tolist())
            if doc_tuple in banned_doc_tuples:
                continue
            kept.append(neighbor_id)

        if kept:
            filtered_neighbor_ids[row_idx, :len(kept)] = kept

    return query_neighbor_ids, filtered_neighbor_ids
def query_embedding_block(db_dataset, index,
                          embeddings, chunk_id_range,
                          sample_map, n_chunks_per_sample):
    '''Query neighbors for a block of embeddings, 1000 rows at a time.

    Each slice of ``embeddings`` is passed to ``query_embeddings`` (with
    ``verbose=False``) together with the matching slice of the chunk id
    range, and the per-slice results are concatenated along axis 0.

    Args:
        db_dataset, index, sample_map, n_chunks_per_sample: Forwarded
            unchanged to ``query_embeddings``.
        embeddings: Sliceable sequence of query embeddings.
        chunk_id_range: Only element 0 (the first chunk id of the block) is
            read; sub-ranges are derived from it and the slice offsets.

    Returns:
        Tuple ``(query_neighbor_ids, filtered_neighbor_ids)`` covering all
        rows of ``embeddings``.
    '''

    sub_block_size = 1000
    n_embeddings = len(embeddings)

    raw_parts = []
    kept_parts = []

    # Query in sub-blocks.
    for lo in tqdm(range(0, n_embeddings, sub_block_size), "search"):
        hi = min(n_embeddings, lo + sub_block_size)
        first_chunk_id = chunk_id_range[0]
        raw_ids, kept_ids = query_embeddings(
            db_dataset, index,
            embeddings[lo:hi],
            (first_chunk_id + lo, first_chunk_id + hi),
            sample_map, n_chunks_per_sample,
            verbose=False,
        )
        raw_parts.append(raw_ids)
        kept_parts.append(kept_ids)

    # Stitch the per-sub-block results back together.
    return (
        np.concatenate(raw_parts, axis=0),
        np.concatenate(kept_parts, axis=0),
    )
def query_block_neighbors(db_dataset, query_dataset,
                          index, embedder,
                          block):
    '''Query neighbors of a dataset block (i.e., range) and save them.

    Args:
        db_dataset: Forwarded to ``query_embedding_block``.
        query_dataset: Dataset exposing ``n_chunks_per_sample`` and
            ``sample_dataset``; each sample provides "dataset_idx" and
            "doc_ids".
        index: Forwarded to ``query_embedding_block``.
        embedder: Forwarded to ``embed_block``.
        block: Mapping with "range" (chunk id bounds, unpacked into
            ``range(...)``) and "path" (output HDF5 file path).

    Returns:
        None. The filtered neighbor ids are written to ``block["path"]``
        as the HDF5 dataset "neighbors".
    '''

    n_chunks_per_sample = query_dataset.n_chunks_per_sample

    # Sample map: one entry per sample touched by this block's chunk ids.
    sample_ids = sorted({chunk_id // n_chunks_per_sample
                         for chunk_id in range(*block["range"])})
    sample_map = {}
    for sample_id in sample_ids:
        sample = query_dataset.sample_dataset[sample_id]
        sample_map[sample_id] = {
            "dataset_idx": sample["dataset_idx"],
            "doc_ids": sample["doc_ids"],
        }

    # Embed block.
    embeddings = embed_block(query_dataset, block, embedder)

    # Query embeddings.
    _, filtered_neighbor_ids = query_embedding_block(
        db_dataset, index,
        embeddings, block["range"],
        sample_map, n_chunks_per_sample)

    # Save neighbors.
    print_rank_0("save neighbors.")
    block_dir = os.path.dirname(block["path"])
    if block_dir:
        # os.makedirs("") raises, so skip it for a bare file name.
        os.makedirs(block_dir, exist_ok=True)
    # Context manager closes the file even if the write fails.
    with h5py.File(block["path"], "w") as f:
        f.create_dataset("neighbors", data=filtered_neighbor_ids)
def query_dataset_neighbors(db_dataset, query_dataset,
                            prefix, neighbor_dir,
                            index, embedder):
    '''Query neighbors of each chunk within a dataset.

    Asks ``get_missing_blocks_by_rank`` which neighbor blocks under
    ``neighbor_dir`` still need to be produced, then runs
    ``query_block_neighbors`` on each one.

    Args:
        db_dataset: Forwarded to ``query_block_neighbors``.
        query_dataset: Dataset to query; its length determines the blocks.
        prefix: Label used only in the progress message.
        neighbor_dir: Directory passed to ``get_missing_blocks_by_rank``.
        index: Forwarded to ``query_block_neighbors``.
        embedder: Forwarded to ``query_block_neighbors``.

    Returns:
        None.
    '''

    args = get_retro_args()

    def validate(f):
        '''Check that a neighbor file has the expected neighbor count.'''
        # Report the same value that is checked, so a failed check raises
        # AssertionError with an accurate message.
        n_expected = args.retro_query_num_neighbors_save
        assert f["neighbors"].shape[1] == n_expected, \
            "neighbors.shape == %s; retro_query_num_neighbors_save == %d." % (
                str(f["neighbors"].shape),
                n_expected,
            )

    _, missing_neighbor_blocks = get_missing_blocks_by_rank(
        neighbor_dir,
        len(query_dataset),
        args.retro_block_size,
        validate=validate,
    )

    # Query each block.
    for block_index, block in enumerate(missing_neighbor_blocks):

        if block is not None:

            # Progress. One memory snapshot keeps both figures consistent.
            mem = psutil.virtual_memory()
            print_rank_0("query '%s' block %d / %d ... %s ... mem %.3f gb, %.1f%%." % (
                prefix,
                block_index,
                len(missing_neighbor_blocks),
                os.path.basename(block["path"]),
                mem.used / 1024**3,
                mem.percent,
            ))

            # Query block neighbors.
            query_block_neighbors(db_dataset, query_dataset,
                                  index, embedder,
                                  block)

        # Synchronize progress across all ranks. (for easier observation)
        # Runs on every iteration, including ``None`` entries -- presumably
        # padding so that all ranks hit the barrier equally often; TODO
        # confirm against get_missing_blocks_by_rank.
        print_rank_0(" > waiting for other ranks to finish block.")
        torch.distributed.barrier()
def query_pretraining_neighbors():
    '''Query neighbors for every dataset in the query dataset map.

    Sets the faiss OpenMP thread count, loads the merged chunk database
    (including its doc tuples) and the index, builds a ``BertEmbedder``
    from the retro args, and then runs ``query_dataset_neighbors`` on each
    entry of the map returned by ``get_query_dataset_map``.
    '''

    retro_args = get_retro_args()

    # Num threads.
    faiss.omp_set_num_threads(64)

    # Chunk database, with doc tuples loaded.
    print_rank_0("load chunk db dataset.")
    chunk_db = get_db_merged_train_dataset()
    chunk_db.load_doc_tuples()

    # Index.
    print_rank_0(" > get index.")
    index = get_index()

    # Datasets to query.
    print_rank_0(" > get dataset map.")
    dataset_map = get_query_dataset_map()

    # Bert embedder.
    embedder = BertEmbedder(
        retro_args.retro_bert_batch_size,
        retro_args.retro_bert_max_chunk_length,
        retro_args.bert_embedder_type,
    )

    # Query every dataset in the map.
    print_rank_0(" > query.")
    for prefix in dataset_map:
        info = dataset_map[prefix]
        n_samples = len(info["data"])
        print_rank_0(" > query '%s' dataset ... %d samples." %
                     (prefix, n_samples))
        query_dataset_neighbors(
            db_dataset=chunk_db,
            query_dataset=info["data"],
            prefix=prefix,
            neighbor_dir=info["neighbor_dir"],
            index=index,
            embedder=embedder,
        )
|