Upload dataclass.py with huggingface_hub
Browse files- dataclass.py +39 -11
    	
        dataclass.py
    CHANGED
    
    | @@ -235,11 +235,7 @@ def asdict(obj): | |
| 235 |  | 
| 236 | 
             
            def _asdict_inner(obj):
         | 
| 237 | 
             
                if is_dataclass(obj):
         | 
| 238 | 
            -
                     | 
| 239 | 
            -
                    for field in fields(obj):
         | 
| 240 | 
            -
                        v = getattr(obj, field.name)
         | 
| 241 | 
            -
                        result[field.name] = _asdict_inner(v)
         | 
| 242 | 
            -
                    return result
         | 
| 243 | 
             
                elif isinstance(obj, tuple) and hasattr(obj, "_fields"):  # named tuple
         | 
| 244 | 
             
                    return type(obj)(*[_asdict_inner(v) for v in obj])
         | 
| 245 | 
             
                elif isinstance(obj, (list, tuple)):
         | 
| @@ -340,16 +336,36 @@ class Dataclass(metaclass=DataclassMeta): | |
| 340 | 
             
                        if name in kwargs:
         | 
| 341 | 
             
                            raise TypeError(f"{self.__class__.__name__} got multiple values for argument '{name}'")
         | 
| 342 |  | 
|  | |
|  | |
| 343 | 
             
                    if len(argv) <= len(_init_positional_fields_names):
         | 
| 344 | 
             
                        unexpected_argv = []
         | 
| 345 | 
             
                    else:
         | 
| 346 | 
             
                        unexpected_argv = argv[len(_init_positional_fields_names) :]
         | 
| 347 |  | 
| 348 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 349 |  | 
| 350 | 
             
                    if self.__allow_unexpected_arguments__:
         | 
| 351 | 
            -
                         | 
| 352 | 
            -
             | 
|  | |
|  | |
| 353 |  | 
| 354 | 
             
                    else:
         | 
| 355 | 
             
                        if len(unexpected_argv) > 0:
         | 
| @@ -376,12 +392,12 @@ class Dataclass(metaclass=DataclassMeta): | |
| 376 | 
             
                                f"Required field '{field.name}' of class {field.origin_cls} not set in {self.__class__.__name__}"
         | 
| 377 | 
             
                            )
         | 
| 378 |  | 
|  | |
|  | |
| 379 | 
             
                    for field in fields(self):
         | 
| 380 | 
             
                        if field.name in kwargs:
         | 
| 381 | 
             
                            setattr(self, field.name, kwargs[field.name])
         | 
| 382 | 
             
                        else:
         | 
| 383 | 
            -
                            if field.name in ["_argv", "_kwargs"] and self.__allow_unexpected_arguments__:
         | 
| 384 | 
            -
                                continue
         | 
| 385 | 
             
                            setattr(self, field.name, get_field_default(field))
         | 
| 386 |  | 
| 387 | 
             
                    self.__post_init__()
         | 
| @@ -390,17 +406,29 @@ class Dataclass(metaclass=DataclassMeta): | |
| 390 | 
             
                def __is_dataclass__(self) -> bool:
         | 
| 391 | 
             
                    return True
         | 
| 392 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 393 | 
             
                def __post_init__(self):
         | 
| 394 | 
             
                    """
         | 
| 395 | 
             
                    Post initialization hook.
         | 
| 396 | 
             
                    """
         | 
| 397 | 
             
                    pass
         | 
| 398 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 399 | 
             
                def to_dict(self):
         | 
| 400 | 
             
                    """
         | 
| 401 | 
             
                    Convert to dict.
         | 
| 402 | 
             
                    """
         | 
| 403 | 
            -
                    return  | 
| 404 |  | 
| 405 | 
             
                def __repr__(self) -> str:
         | 
| 406 | 
             
                    """
         | 
|  | |
| 235 |  | 
| 236 | 
             
            def _asdict_inner(obj):
         | 
| 237 | 
             
                if is_dataclass(obj):
         | 
| 238 | 
            +
                    return obj.to_dict()
         | 
|  | |
|  | |
|  | |
|  | |
| 239 | 
             
                elif isinstance(obj, tuple) and hasattr(obj, "_fields"):  # named tuple
         | 
| 240 | 
             
                    return type(obj)(*[_asdict_inner(v) for v in obj])
         | 
| 241 | 
             
                elif isinstance(obj, (list, tuple)):
         | 
|  | |
| 336 | 
             
                        if name in kwargs:
         | 
| 337 | 
             
                            raise TypeError(f"{self.__class__.__name__} got multiple values for argument '{name}'")
         | 
| 338 |  | 
| 339 | 
            +
                    expected_unexpected_argv = kwargs.pop("_argv", None)
         | 
| 340 | 
            +
             | 
| 341 | 
             
                    if len(argv) <= len(_init_positional_fields_names):
         | 
| 342 | 
             
                        unexpected_argv = []
         | 
| 343 | 
             
                    else:
         | 
| 344 | 
             
                        unexpected_argv = argv[len(_init_positional_fields_names) :]
         | 
| 345 |  | 
| 346 | 
            +
                    if expected_unexpected_argv is not None:
         | 
| 347 | 
            +
                        assert (
         | 
| 348 | 
            +
                            len(unexpected_argv) == 0
         | 
| 349 | 
            +
                        ), f"Cannot specify both _argv and unexpected positional arguments. Got {unexpected_argv}"
         | 
| 350 | 
            +
                        unexpected_argv = tuple(expected_unexpected_argv)
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                    expected_unexpected_kwargs = kwargs.pop("_kwargs", None)
         | 
| 353 | 
            +
                    unexpected_kwargs = {
         | 
| 354 | 
            +
                        k: v for k, v in kwargs.items() if k not in _init_fields_names and k not in ["_argv", "_kwargs"]
         | 
| 355 | 
            +
                    }
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                    if expected_unexpected_kwargs is not None:
         | 
| 358 | 
            +
                        intersection = set(unexpected_kwargs.keys()) & set(expected_unexpected_kwargs.keys())
         | 
| 359 | 
            +
                        assert (
         | 
| 360 | 
            +
                            len(intersection) == 0
         | 
| 361 | 
            +
                        ), f"Cannot specify the same arguments in both _kwargs and in unexpected keyword arguments. Got {intersection} in both."
         | 
| 362 | 
            +
                        unexpected_kwargs = {**unexpected_kwargs, **expected_unexpected_kwargs}
         | 
| 363 |  | 
| 364 | 
             
                    if self.__allow_unexpected_arguments__:
         | 
| 365 | 
            +
                        if len(unexpected_argv) > 0:
         | 
| 366 | 
            +
                            kwargs["_argv"] = unexpected_argv
         | 
| 367 | 
            +
                        if len(unexpected_kwargs) > 0:
         | 
| 368 | 
            +
                            kwargs["_kwargs"] = unexpected_kwargs
         | 
| 369 |  | 
| 370 | 
             
                    else:
         | 
| 371 | 
             
                        if len(unexpected_argv) > 0:
         | 
|  | |
| 392 | 
             
                                f"Required field '{field.name}' of class {field.origin_cls} not set in {self.__class__.__name__}"
         | 
| 393 | 
             
                            )
         | 
| 394 |  | 
| 395 | 
            +
                    self.__pre_init__(**kwargs)
         | 
| 396 | 
            +
             | 
| 397 | 
             
                    for field in fields(self):
         | 
| 398 | 
             
                        if field.name in kwargs:
         | 
| 399 | 
             
                            setattr(self, field.name, kwargs[field.name])
         | 
| 400 | 
             
                        else:
         | 
|  | |
|  | |
| 401 | 
             
                            setattr(self, field.name, get_field_default(field))
         | 
| 402 |  | 
| 403 | 
             
                    self.__post_init__()
         | 
|  | |
| 406 | 
             
                def __is_dataclass__(self) -> bool:
         | 
| 407 | 
             
                    return True
         | 
| 408 |  | 
| 409 | 
            +
                def __pre_init__(self, **kwargs):
         | 
| 410 | 
            +
                    """
         | 
| 411 | 
            +
                    Pre initialization hook.
         | 
| 412 | 
            +
                    """
         | 
| 413 | 
            +
                    pass
         | 
| 414 | 
            +
             | 
| 415 | 
             
                def __post_init__(self):
         | 
| 416 | 
             
                    """
         | 
| 417 | 
             
                    Post initialization hook.
         | 
| 418 | 
             
                    """
         | 
| 419 | 
             
                    pass
         | 
| 420 |  | 
| 421 | 
            +
                def _to_raw_dict(self):
         | 
| 422 | 
            +
                    """
         | 
| 423 | 
            +
                    Convert to raw dict
         | 
| 424 | 
            +
                    """
         | 
| 425 | 
            +
                    return {field.name: getattr(self, field.name) for field in fields(self)}
         | 
| 426 | 
            +
             | 
| 427 | 
             
                def to_dict(self):
         | 
| 428 | 
             
                    """
         | 
| 429 | 
             
                    Convert to dict.
         | 
| 430 | 
             
                    """
         | 
| 431 | 
            +
                    return _asdict_inner(self._to_raw_dict())
         | 
| 432 |  | 
| 433 | 
             
                def __repr__(self) -> str:
         | 
| 434 | 
             
                    """
         | 

