Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import unittest | |
| from argparse import ArgumentParser | |
| from dataclasses import dataclass, field | |
| from fairseq.dataclass import FairseqDataclass | |
| from fairseq.dataclass.utils import gen_parser_from_dataclass | |
@dataclass
class A(FairseqDataclass):
    """Leaf test config holding two plain fields.

    The tests in this module expect ``num_layers`` to be settable via
    ``--num-layers`` and ``data`` to be supplied as a positional argument.
    """

    # @dataclass is required: without it these annotated fields are never
    # registered, so the parser generator never sees them.
    data: str = field(default="test", metadata={"help": "the data input"})
    num_layers: int = field(default=200, metadata={"help": "more layers is better?"})
@dataclass
class B(FairseqDataclass):
    """Test config nesting an ``A`` under ``bar`` alongside a plain ``foo`` field.

    ``test_argparse_recursive`` parses ``A``'s options and ``--foo`` from a
    single flat command line built from this config.
    """

    # default_factory instead of default=A(): a default *instance* of a
    # non-frozen dataclass is unhashable, which Python 3.11+ rejects as a
    # mutable default. A factory also gives each B its own A.
    bar: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
@dataclass
class D(FairseqDataclass):
    """Intermediate test config nesting an ``A`` under ``arch`` plus a ``foo`` field.

    Used as ``C.encoder``, so its options are expected to appear as
    ``--encoder-arch-*`` and ``--encoder-foo`` when prefixing is enabled.
    """

    # default_factory instead of default=A(): Python 3.11+ rejects an
    # unhashable dataclass instance as a field default.
    arch: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
@dataclass
class C(FairseqDataclass):
    """Root test config mixing top-level fields with two nested configs.

    ``test_argparse_recursive_prefixing`` expects nested options to be exposed
    as ``--encoder-arch-*``, ``--encoder-foo`` and ``--decoder-*``, while
    ``data`` and ``lr`` remain at the root level.
    """

    data: str = field(default="test", metadata={"help": "root level data input"})
    # default_factory instead of default=D()/A(): a default instance of a
    # non-frozen dataclass is unhashable, which Python 3.11+ rejects as a
    # mutable default. Factories also avoid sharing one instance across C's.
    encoder: D = field(default_factory=D)
    decoder: A = field(default_factory=A)
    lr: int = field(default=0, metadata={"help": "learning rate"})
class TestDataclassUtils(unittest.TestCase):
    """Check that dataclass configs are turned into the expected argparse options."""

    @staticmethod
    def _make_parser(config, *prefix):
        """Build a fresh ArgumentParser and populate it from ``config``.

        ``True`` is always passed as the third argument of
        ``gen_parser_from_dataclass``. Any extra positional value is
        forwarded unchanged as its trailing argument.
        """
        arg_parser = ArgumentParser()
        gen_parser_from_dataclass(arg_parser, config, True, *prefix)
        return arg_parser

    def _check(self, namespace, expected):
        """Assert that each ``namespace.<name>`` equals ``expected[name]``, in order."""
        for attr_name, wanted in expected.items():
            self.assertEqual(getattr(namespace, attr_name), wanted)

    def test_argparse_convert_basic(self):
        argv = ["--num-layers", "10", "the/data/path"]
        namespace = self._make_parser(A()).parse_args(argv)
        self._check(namespace, {"num_layers": 10, "data": "the/data/path"})

    def test_argparse_recursive(self):
        argv = ["--num-layers", "10", "--foo", "10", "the/data/path"]
        namespace = self._make_parser(B()).parse_args(argv)
        self._check(
            namespace,
            {"num_layers": 10, "foo": 10, "data": "the/data/path"},
        )

    def test_argparse_recursive_prefixing(self):
        self.maxDiff = None
        argv = [
            "--encoder-arch-data", "ENCODER_ARCH_DATA",
            "--encoder-arch-num-layers", "10",
            "--encoder-foo", "10",
            "--decoder-data", "DECODER_DATA",
            "--decoder-num-layers", "10",
            "--lr", "10",
            "the/data/path",
        ]
        namespace = self._make_parser(C(), "").parse_args(argv)
        self._check(
            namespace,
            {
                "encoder_arch_data": "ENCODER_ARCH_DATA",
                "encoder_arch_num_layers": 10,
                "encoder_foo": 10,
                "decoder_data": "DECODER_DATA",
                "decoder_num_layers": 10,
                "lr": 10,
                "data": "the/data/path",
            },
        )
if __name__ == "__main__":
    # Executing this file directly runs the TestCase classes it defines.
    unittest.main(module="__main__")