File size: 7,273 Bytes
144d5a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""
Data Loader for Google NQ dataset
"""
import ast
import csv
from abc import ABC
from collections import OrderedDict

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, BatchSampler

from megatron import print_rank_0, get_args, get_tokenizer
from megatron.data.biencoder_dataset_utils import make_attention_mask
from deepspeed.accelerator import get_accelerator
def get_nq_dataset(qa_data, split):
    """Build the ``NQDataset`` for one split of the Google NQ QA data.

    Args:
        qa_data: path of the QA file; forwarded to ``NQDataset`` as its
            ``datapath``.
        split: split name; only used to compose the dataset's task name.

    Returns:
        An ``NQDataset`` built with the global tokenizer and sequence length
        ``args.retriever_seq_length``.
    """
    args = get_args()
    tokenizer = get_tokenizer()
    task_name = 'Google NQ {} Split'.format(split)
    seq_length = args.retriever_seq_length
    return NQDataset(task_name,
                     'Google Natural Questions',
                     qa_data,
                     tokenizer,
                     seq_length)
def process_nq_batch(batch):
    """Move the tensors of a collated NQ batch onto the accelerator device.

    Args:
        batch: mapping produced by ``CustomDataLoader._collate_fn`` with keys
            'token_ids', 'token_mask', 'token_types', 'seq_len', 'reference'.

    Returns:
        Tuple ``(query_tokens, query_mask, query_types, query_len, reference)``.
        ``query_mask`` is a boolean tensor that is True where
        ``batch['token_mask'] < 0.5`` (presumably the positions a 0/1 attention
        mask excludes -- confirm against make_attention_mask). ``reference`` is
        returned unchanged and stays on the host.
    """
    def on_device(tensor):
        # Device name is resolved per tensor, just like the inline form did.
        return tensor.to(get_accelerator().device_name())

    tokens = on_device(batch['token_ids'].long())
    mask = on_device(batch['token_mask'] < 0.5)
    types = on_device(batch['token_types'].long())
    lengths = on_device(batch['seq_len'].long())
    return tokens, mask, types, lengths, batch['reference']
class CustomDataLoader(DataLoader):
    """``DataLoader`` that collates NQ sample dicts into batched tensors.

    Unless the caller supplies its own ``collate_fn``, batches are built by
    :meth:`_collate_fn`.

    Args:
        dataset: dataset whose items are sample dicts (see ``build_sample``).
        eval: stored as ``self.eval``; not read anywhere in this class.
        **kwargs: forwarded unchanged to ``torch.utils.data.DataLoader``.
    """

    # Sample keys converted to int64 tensors; any other key (e.g. 'reference')
    # is left as a plain Python list of per-sample values.
    _TENSOR_KEYS = ('token_ids', 'token_mask', 'token_types', 'seq_len')

    def __init__(self, dataset, eval=False, **kwargs):
        if kwargs.get('collate_fn', None) is None:
            kwargs['collate_fn'] = self._collate_fn
        self.eval = eval
        super().__init__(dataset, **kwargs)

    def _collate_fn(self, batch_data):
        """Group a list of sample dicts into one batch.

        Args:
            batch_data: list of sample dicts, each with the five keys produced
                by ``build_sample``.

        Returns:
            ``OrderedDict`` keyed like the samples (first-seen order). The
            values for ``_TENSOR_KEYS`` are int64 tensors with a leading batch
            dimension; the other values are lists with one entry per sample.
        """
        tensorized = OrderedDict()
        for sample in batch_data:
            for key, value in sample.items():
                tensorized.setdefault(key, []).append(value)
        assert len(tensorized) == 5, \
            'expected 5 sample keys, got {}'.format(list(tensorized))
        for key in self._TENSOR_KEYS:
            # Stack in numpy first: building a tensor from a list of ndarrays
            # element by element is very slow (torch warns about it).
            tensorized[key] = torch.as_tensor(np.stack(tensorized[key]),
                                              dtype=torch.long)
        return tensorized
def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None):
    """Build a loader that walks ``dataset`` once, in order.

    The batch size is the local (per GPU) batch size and falls back to
    ``args.micro_batch_size`` when ``micro_batch_size`` is None.
    NOTE: This dataloader is not distributed !!!
    """
    args = get_args()
    batch_size = (args.micro_batch_size if micro_batch_size is None
                  else micro_batch_size)
    workers = args.num_workers
    # drop_last stays False so the trailing partial batch is still yielded.
    batches = BatchSampler(torch.utils.data.SequentialSampler(dataset),
                           batch_size=batch_size,
                           drop_last=False)
    return CustomDataLoader(dataset,
                            batch_sampler=batches,
                            num_workers=workers,
                            pin_memory=True)
def build_tokens_types_paddings_from_text(src_text, tokenizer, max_seq_length):
    """Tokenize ``src_text``, then build ids/token types, trimming and padding.

    Uses the tokenizer's ``cls``, ``sep`` and ``pad`` ids and returns the
    ``(ids, token_types, num_tokens)`` triple of
    ``build_tokens_types_paddings_from_ids``.
    """
    token_ids = tokenizer.tokenize(src_text)
    cls_id = tokenizer.cls
    sep_id = tokenizer.sep
    pad_id = tokenizer.pad
    return build_tokens_types_paddings_from_ids(
        token_ids, max_seq_length, cls_id, sep_id, pad_id)
def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id,
                                         sep_id, pad_id):
    """
    Build token types and paddings, trim if needed, and pad if needed.

    The encoded sequence is ``[CLS] src_ids [SEP]``, with ``src_ids`` trimmed
    so the sequence fits in ``max_seq_length``, then right-padded with
    ``pad_id`` up to ``max_seq_length``.

    Args:
        src_ids: iterable of token ids for the text.
        max_seq_length: target length of the returned lists.
        cls_id, sep_id, pad_id: ids of the [CLS], [SEP] and padding tokens.

    Returns:
        ``(enc_ids, tokentypes_enc, num_tokens_enc)``: padded ids, token-type
        ids (all 0: single segment, padding included), and the number of
        tokens before padding.

    TODO: Design modular interface to reuse this function. This is getting
    repeated multiple times in different tasks
    """
    # [CLS] + text, capped so there is still room for the trailing [SEP].
    enc_ids = [cls_id] + list(src_ids)
    enc_ids = enc_ids[:max_seq_length - 1]
    enc_ids.append(sep_id)
    num_tokens_enc = len(enc_ids)
    tokentypes_enc = [0] * num_tokens_enc

    # Padding. Token types are segment ids, so padded positions get segment 0;
    # filling them with pad_id (as before) only worked when pad_id == 0.
    padding_length = max_seq_length - num_tokens_enc
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([0] * padding_length)
    return enc_ids, tokentypes_enc, num_tokens_enc
def build_sample(token_ids, token_types, num_tokens, reference):
    """Package one encoded question as the dict consumed by the batch producer.

    ``token_ids`` and ``token_types`` are converted to int64 numpy arrays and
    ``token_mask`` is ``make_attention_mask(ids, ids)``. ``num_tokens`` is
    stored under 'seq_len', and ``reference`` is passed through unchanged.
    """
    ids = np.array(token_ids, dtype=np.int64)
    types = np.array(token_types, dtype=np.int64)
    return {
        'token_ids': ids,
        'token_mask': make_attention_mask(ids, ids),
        'token_types': types,
        'seq_len': num_tokens,
        'reference': reference,
    }
class NQDataset(ABC, Dataset):
    """
    Open Retrieval Question Answering evaluation using Google NQ dataset.

    Samples are read from a tab-separated file: column 0 is the question text
    and column 1 is a Python literal holding the answers. Each item returned
    by ``__getitem__`` is the dict built by ``build_sample``, with the answers
    stored under 'reference'.
    """

    def __init__(self, task_name, dataset_name, datapath,
                 tokenizer, max_seq_length):
        """Store the inputs and eagerly load every sample from ``datapath``."""
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        print_rank_0(datapath)
        self.samples = self.process_samples_from_single_path(datapath)
        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize, trim and pad question ``idx`` and return its sample dict."""
        raw_sample = self.samples[idx]
        ques_tokens, tokentypes_enc, num_tokens_ques = \
            build_tokens_types_paddings_from_text(raw_sample['question'],
                                                  self.tokenizer,
                                                  self.max_seq_length)
        return build_sample(ques_tokens,
                            tokentypes_enc,
                            num_tokens_ques,
                            raw_sample['answers'])

    @staticmethod
    def process_samples_from_single_path(filename):
        """Parse the tab-separated QA file into a list of raw sample dicts.

        Returns:
            List of ``{'question': str, 'answers': <parsed column 1>}`` dicts,
            in file order.
        """
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        # newline='' is what the csv module requires for correct parsing;
        # an explicit utf-8 avoids depending on the locale's default encoding.
        with open(filename, 'r', encoding='utf-8', newline='') as ifile:
            reader = csv.reader(ifile, delimiter='\t')
            for row in reader:
                question = row[0]
                # Column 1 is a Python literal (presumably a list of answer
                # strings). ast.literal_eval accepts only literals; the former
                # eval() would execute arbitrary code found in the data file.
                answers = ast.literal_eval(row[1])
                samples.append({'question': question, 'answers': answers})
                if len(samples) % 1000 == 0:
                    print_rank_0(' > processed {} so far ...'.format(
                        len(samples)))
        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
|