Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import unittest | |
| import torch | |
| from fairseq.data import MonolingualDataset | |
| from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig | |
| from tests import utils as test_utils | |
class TestLMContextWindow(unittest.TestCase):
    """Tests for ``LanguageModelingTask.eval_lm_dataloader`` with a ``context_window``."""

    def test_eval_dataloader(self):
        """Check that evaluation batches carry the previous tokens as context.

        The dataset holds three samples (lengths 4, 4 and 2) and is read one
        sample per batch with ``context_window=2``. The expected values below
        show that:

        * the first batch has no preceding context and is right-padded with
          the pad index (1) to length 6;
        * every later batch starts with the last two tokens of the preceding
          sample, and those context positions hold the pad index in
          ``target`` (presumably so they are excluded from scoring).
        """
        dictionary = test_utils.dummy_dictionary(10)
        self.assertEqual(len(dictionary), 14)  # 4 extra special symbols
        self.assertEqual(dictionary.pad(), 1)  # expected lists below use 1 as pad

        dataset = test_utils.TestDataset([
            torch.tensor([4, 5, 6, 7], dtype=torch.long),
            torch.tensor([8, 9, 10, 11], dtype=torch.long),
            torch.tensor([12, 13], dtype=torch.long),
        ])
        dataset = MonolingualDataset(dataset, sizes=[4, 4, 2], src_vocab=dictionary)

        config = LanguageModelingConfig(tokens_per_sample=4)
        task = LanguageModelingTask(config, dictionary)

        eval_dataloader = task.eval_lm_dataloader(
            dataset=dataset,
            batch_size=1,
            context_window=2,
            # Load batches in the main process rather than a worker subprocess;
            # nothing here needs parallel loading. NOTE(review): assumes
            # eval_lm_dataloader forwards this like torch DataLoader's
            # num_workers — confirm against the installed fairseq version.
            num_workers=0,
        )

        # (src_tokens, target) of the single row in each successive batch.
        expected_batches = [
            ([4, 5, 6, 7, 1, 1], [4, 5, 6, 7, 1, 1]),
            ([6, 7, 8, 9, 10, 11], [1, 1, 8, 9, 10, 11]),
            ([10, 11, 12, 13], [1, 1, 12, 13]),
        ]
        for i, (src_tokens, target) in enumerate(expected_batches):
            batch = next(eval_dataloader)
            # assertEqual rather than bare `assert`: bare asserts are stripped
            # under `python -O` (the test would pass vacuously), and assertEqual
            # reports both lists on failure.
            self.assertEqual(
                batch["net_input"]["src_tokens"][0].tolist(),
                src_tokens,
                msg=f"src_tokens of batch {i}",
            )
            self.assertEqual(
                batch["target"][0].tolist(),
                target,
                msg=f"target of batch {i}",
            )
if __name__ == "__main__":
    # Executing this file directly (rather than via a test collector) runs
    # every TestCase defined in it through unittest's command-line runner.
    unittest.main()