Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import tempfile | |
| import unittest | |
| import torch | |
| from fairseq.data.dictionary import Dictionary | |
| from fairseq.models.transformer import TransformerModel | |
| from fairseq.modules import multihead_attention, sinusoidal_positional_embedding | |
| from fairseq.tasks.fairseq_task import LegacyFairseqTask | |
# Default number of numbered dummy symbols that get_dummy_dictionary() adds.
DEFAULT_TEST_VOCAB_SIZE: int = 100
class DummyTask(LegacyFairseqTask):
    """Minimal task whose source and target share one dummy dictionary.

    Used only to give model builders (e.g. ``TransformerModel.build_model``)
    something with ``source_dictionary`` / ``target_dictionary`` attributes.
    """

    def __init__(self, args):
        super().__init__(args)
        self.dictionary = get_dummy_dictionary()
        # When args.ctc is set, reserve an extra "<ctc_blank>" symbol.
        if getattr(self.args, "ctc", False):
            self.dictionary.add_symbol("<ctc_blank>")
        # Source and target deliberately share the same dictionary object.
        self.src_dict = self.dictionary
        self.tgt_dict = self.dictionary

    # These must be properties: callers access them as attributes (and take
    # len() of the result), so plain methods would hand back bound methods.
    @property
    def source_dictionary(self):
        """Dictionary for the source side (same object as the target's)."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Dictionary for the target side (same object as the source's)."""
        return self.tgt_dict
def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
    """Build a fairseq ``Dictionary`` padded with numbered dummy symbols.

    Args:
        vocab_size: number of dummy symbols to add; they are named
            ``"0"``, ``"1"``, ..., ``str(vocab_size - 1)``.

    Returns:
        Dictionary: a fresh dictionary containing the dummy symbols (in
        addition to whatever symbols ``Dictionary()`` starts with).
    """
    dummy_dict = Dictionary()
    # Each dummy symbol is added with a count of 1000.
    for idx in range(vocab_size):
        dummy_dict.add_symbol(str(idx), 1000)
    return dummy_dict
def get_dummy_task_and_parser():
    """
    Create a ``DummyTask`` plus the argparse parser its arguments were
    registered on; callers can add model arguments to the same parser and
    then build a model/criterion against the task.

    Returns:
        tuple: ``(task, parser)``.
    """
    arg_parser = argparse.ArgumentParser(
        description="test_dummy_s2s_task",
        argument_default=argparse.SUPPRESS,
    )
    DummyTask.add_args(arg_parser)
    parsed_args = arg_parser.parse_args([])
    dummy_task = DummyTask.setup_task(parsed_args)
    return dummy_task, arg_parser
| def _test_save_and_load(scripted_module): | |
| with tempfile.NamedTemporaryFile() as f: | |
| scripted_module.save(f.name) | |
| torch.jit.load(f.name) | |
class TestExportModels(unittest.TestCase):
    """Checks that fairseq modules/models can be scripted and saved/reloaded."""

    def _export_transformer(self, **arg_overrides):
        """Build a TransformerModel from default CLI args, apply
        ``arg_overrides`` as attributes on the parsed args, script the model
        and round-trip it through save/load."""
        task, parser = get_dummy_task_and_parser()
        TransformerModel.add_args(parser)
        args = parser.parse_args([])
        for name, value in arg_overrides.items():
            setattr(args, name, value)
        model = TransformerModel.build_model(args, task)
        scripted = torch.jit.script(model)
        _test_save_and_load(scripted)

    def test_export_multihead_attention(self):
        module = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
        scripted = torch.jit.script(module)
        _test_save_and_load(scripted)

    def test_incremental_state_multihead_attention(self):
        module1 = torch.jit.script(
            multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
        )
        module2 = torch.jit.script(
            multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
        )

        # Both modules write under the same "key" into one shared state dict;
        # the assertions check that each module reads back its own value.
        state = {}
        state = module1.set_incremental_state(state, "key", {"a": torch.tensor([1])})
        state = module2.set_incremental_state(state, "key", {"a": torch.tensor([2])})
        v1 = module1.get_incremental_state(state, "key")["a"]
        v2 = module2.get_incremental_state(state, "key")["a"]

        # Compare Python scalars rather than relying on tensor truthiness.
        self.assertEqual(v1.item(), 1)
        self.assertEqual(v2.item(), 2)

    def test_positional_embedding(self):
        module = sinusoidal_positional_embedding.SinusoidalPositionalEmbedding(
            embedding_dim=8, padding_idx=1
        )
        scripted = torch.jit.script(module)
        _test_save_and_load(scripted)

    def test_export_transformer(self):
        self._export_transformer()

    def test_export_transformer_no_token_pos_emb(self):
        self._export_transformer(no_token_positional_embeddings=True)
if __name__ == "__main__":
    # Run this module's test cases when the file is executed as a script.
    unittest.main(module="__main__")