Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| import unittest | |
| import numpy as np | |
| from fairseq.data.data_utils_fast import batch_by_size_fn | |
| from fairseq.data.data_utils_fast import batch_by_size_vec | |
class TestBatchBySize(unittest.TestCase):
    """Shared machinery for checking ``batch_by_size`` implementations.

    This class defines no ``test_*`` methods of its own. Subclasses pass the
    implementation under test to :meth:`_run_compare_with_baseline_sweep`,
    which compares it against :meth:`batch_by_size_baseline`.
    """

    @classmethod
    def batch_by_size_baseline(
        cls,
        indices,
        num_tokens_vec,
        max_tokens,
        max_sentences,
        bsz_mult,
    ):
        """Simple, reliable and slow implementation of batch by size.

        Batches are grown greedily one sample at a time. The cost of a
        candidate batch is ``longest sample * number of samples``.

        Args:
            indices: sequence of sample indices; each batch is a slice of it.
            num_tokens_vec: token count per sample, looked up by *position*
                in ``indices`` (not by the index value stored there).
            max_tokens: upper bound on the batch cost; ignored when ``<= 0``.
            max_sentences: upper bound on samples per batch; ignored when
                ``<= 0``.
            bsz_mult: a batch holding more than ``bsz_mult`` samples is
                trimmed down to a multiple of ``bsz_mult``.

        Returns:
            A list of consecutive slices of ``indices``.

        Note:
            If a single sample alone exceeds a positive ``max_tokens`` the
            batch size drops to zero and the loop never advances, so callers
            must keep every token count ``<= max_tokens``.
        """
        batches = []
        total = len(indices)
        start = 0
        while start < total:
            for end in range(start + 1, total + 1):
                sent_count = end - start
                max_val = max(num_tokens_vec[pos] for pos in range(start, end))
                num_tokens = max_val * sent_count
                # A limit of 0 (or less) disables the corresponding check.
                overflow = (
                    num_tokens > max_tokens > 0 or sent_count > max_sentences > 0
                )
                terminate = overflow or end == total
                if overflow:
                    # The sample that caused the overflow starts the next batch.
                    sent_count -= 1
                if terminate:
                    if sent_count > bsz_mult:
                        sent_count -= sent_count % bsz_mult
                    batches.append(indices[start : start + sent_count])
                    start += sent_count
                    break
        return batches

    @classmethod
    def _get_error_message(
        cls, max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
    ):
        """Build the assertion message describing one failing configuration.

        Args:
            max_sentences, max_tokens, bsz_mult: parameters of the failing run.
            num_tokens_vec: token counts used for the run.
            validation: batches produced by the baseline.
            results: batches produced by the implementation under test.

        Returns:
            A multi-line string listing all of the above.
        """
        return (
            "Reference batch_by_size implementation should produce\n"
            "same output as the baseline method.\n"
            "Params:\n"
            f"max_sentences={max_sentences},\n"
            f"max_tokens={max_tokens},\n"
            f"bsz_mult={bsz_mult},\n"
            f"num_tokens_vec={num_tokens_vec},\n"
            f"expected_batches={validation},\n"
            f"returned_batches={results}"
        )

    def _compare_results(
        self,
        indices_len,
        batch_by_size_impl,
        max_sentences,
        max_tokens,
        bsz_mult,
        num_tokens_vec,
    ):
        """Assert that ``batch_by_size_impl`` matches the baseline for one case.

        Args:
            indices_len: number of samples; indices are ``0 .. indices_len - 1``.
            batch_by_size_impl: callable invoked as
                ``impl(indices, num_tokens_vec, max_tokens=..., max_sentences=...,
                bsz_mult=...)`` and returning a sequence of batches.
            max_sentences, max_tokens, bsz_mult: batching parameters forwarded
                to both implementations.
            num_tokens_vec: token count per sample.
        """
        indices = np.array(list(range(indices_len)))
        validation = self.batch_by_size_baseline(
            indices,
            num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            bsz_mult=bsz_mult,
        )
        results = batch_by_size_impl(
            indices,
            num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            bsz_mult=bsz_mult,
        )
        error_msg = self._get_error_message(
            max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
        )
        # Check the batch count first: zip() below would silently truncate.
        self.assertEqual(len(validation), len(results), error_msg)
        for expected, actual in zip(validation, results):
            self.assertTrue(np.array_equal(expected, actual), error_msg)

    def _run_compare_with_baseline_sweep(self, batch_by_size_impl):
        """Compare reference batch_by_size implementation with batch_by_size_baseline
        across a dense grid of hyperparam values.

        Token counts are drawn at random from ``[0, max_tokens]``, so no single
        sample ever exceeds ``max_tokens``.
        """
        MAX_MAX_TOKENS = 10
        NUM_TOKENS_VECS_COUNT = 5
        for indices_len in [10, 11]:  # try odd and even len of indices
            for max_sentences in range(0, indices_len + 2):
                for max_tokens in range(0, MAX_MAX_TOKENS):
                    for bsz_mult in range(1, max(MAX_MAX_TOKENS, indices_len) + 2):
                        for _ in range(NUM_TOKENS_VECS_COUNT):
                            num_tokens_vec = np.random.randint(
                                0, max_tokens + 1, size=indices_len
                            )
                            self._compare_results(
                                indices_len,
                                batch_by_size_impl,
                                max_sentences,
                                max_tokens,
                                bsz_mult,
                                num_tokens_vec,
                            )
class TestBatchBySizeVec(TestBatchBySize):
    """Checks ``batch_by_size_vec``, which receives the token counts as a vector."""

    def test_compare_with_baseline(self):
        """Run the full parameter sweep against ``batch_by_size_vec``."""
        self._run_compare_with_baseline_sweep(batch_by_size_vec)
class TestBatchBySizeFn(TestBatchBySize):
    """Checks ``batch_by_size_fn``, which receives the token counts as a callable."""

    def test_compare_with_baseline(self):
        """Run the full parameter sweep against ``batch_by_size_fn``."""

        def batch_by_size_fn_wrapper(
            indices,
            num_tokens_vec,
            max_tokens,
            max_sentences,
            bsz_mult,
        ):
            # Adapt the vector-based signature used by the sweep to the
            # callback-based signature of batch_by_size_fn.
            def num_tokens_fn(idx):
                return num_tokens_vec[idx]

            return batch_by_size_fn(
                indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult
            )

        self._run_compare_with_baseline_sweep(batch_by_size_fn_wrapper)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()