Upload dataclass.py with huggingface_hub
Browse files- dataclass.py +39 -11
dataclass.py
CHANGED
|
@@ -235,11 +235,7 @@ def asdict(obj):
|
|
| 235 |
|
| 236 |
def _asdict_inner(obj):
|
| 237 |
if is_dataclass(obj):
|
| 238 |
-
|
| 239 |
-
for field in fields(obj):
|
| 240 |
-
v = getattr(obj, field.name)
|
| 241 |
-
result[field.name] = _asdict_inner(v)
|
| 242 |
-
return result
|
| 243 |
elif isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple
|
| 244 |
return type(obj)(*[_asdict_inner(v) for v in obj])
|
| 245 |
elif isinstance(obj, (list, tuple)):
|
|
@@ -340,16 +336,36 @@ class Dataclass(metaclass=DataclassMeta):
|
|
| 340 |
if name in kwargs:
|
| 341 |
raise TypeError(f"{self.__class__.__name__} got multiple values for argument '{name}'")
|
| 342 |
|
|
|
|
|
|
|
| 343 |
if len(argv) <= len(_init_positional_fields_names):
|
| 344 |
unexpected_argv = []
|
| 345 |
else:
|
| 346 |
unexpected_argv = argv[len(_init_positional_fields_names) :]
|
| 347 |
|
| 348 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
if self.__allow_unexpected_arguments__:
|
| 351 |
-
|
| 352 |
-
|
|
|
|
|
|
|
| 353 |
|
| 354 |
else:
|
| 355 |
if len(unexpected_argv) > 0:
|
|
@@ -376,12 +392,12 @@ class Dataclass(metaclass=DataclassMeta):
|
|
| 376 |
f"Required field '{field.name}' of class {field.origin_cls} not set in {self.__class__.__name__}"
|
| 377 |
)
|
| 378 |
|
|
|
|
|
|
|
| 379 |
for field in fields(self):
|
| 380 |
if field.name in kwargs:
|
| 381 |
setattr(self, field.name, kwargs[field.name])
|
| 382 |
else:
|
| 383 |
-
if field.name in ["_argv", "_kwargs"] and self.__allow_unexpected_arguments__:
|
| 384 |
-
continue
|
| 385 |
setattr(self, field.name, get_field_default(field))
|
| 386 |
|
| 387 |
self.__post_init__()
|
|
@@ -390,17 +406,29 @@ class Dataclass(metaclass=DataclassMeta):
|
|
| 390 |
def __is_dataclass__(self) -> bool:
|
| 391 |
return True
|
| 392 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
def __post_init__(self):
|
| 394 |
"""
|
| 395 |
Post initialization hook.
|
| 396 |
"""
|
| 397 |
pass
|
| 398 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
def to_dict(self):
|
| 400 |
"""
|
| 401 |
Convert to dict.
|
| 402 |
"""
|
| 403 |
-
return
|
| 404 |
|
| 405 |
def __repr__(self) -> str:
|
| 406 |
"""
|
|
|
|
| 235 |
|
| 236 |
def _asdict_inner(obj):
|
| 237 |
if is_dataclass(obj):
|
| 238 |
+
return obj.to_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
elif isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple
|
| 240 |
return type(obj)(*[_asdict_inner(v) for v in obj])
|
| 241 |
elif isinstance(obj, (list, tuple)):
|
|
|
|
| 336 |
if name in kwargs:
|
| 337 |
raise TypeError(f"{self.__class__.__name__} got multiple values for argument '{name}'")
|
| 338 |
|
| 339 |
+
expected_unexpected_argv = kwargs.pop("_argv", None)
|
| 340 |
+
|
| 341 |
if len(argv) <= len(_init_positional_fields_names):
|
| 342 |
unexpected_argv = []
|
| 343 |
else:
|
| 344 |
unexpected_argv = argv[len(_init_positional_fields_names) :]
|
| 345 |
|
| 346 |
+
if expected_unexpected_argv is not None:
|
| 347 |
+
assert (
|
| 348 |
+
len(unexpected_argv) == 0
|
| 349 |
+
), f"Cannot specify both _argv and unexpected positional arguments. Got {unexpected_argv}"
|
| 350 |
+
unexpected_argv = tuple(expected_unexpected_argv)
|
| 351 |
+
|
| 352 |
+
expected_unexpected_kwargs = kwargs.pop("_kwargs", None)
|
| 353 |
+
unexpected_kwargs = {
|
| 354 |
+
k: v for k, v in kwargs.items() if k not in _init_fields_names and k not in ["_argv", "_kwargs"]
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
if expected_unexpected_kwargs is not None:
|
| 358 |
+
intersection = set(unexpected_kwargs.keys()) & set(expected_unexpected_kwargs.keys())
|
| 359 |
+
assert (
|
| 360 |
+
len(intersection) == 0
|
| 361 |
+
), f"Cannot specify the same arguments in both _kwargs and in unexpected keyword arguments. Got {intersection} in both."
|
| 362 |
+
unexpected_kwargs = {**unexpected_kwargs, **expected_unexpected_kwargs}
|
| 363 |
|
| 364 |
if self.__allow_unexpected_arguments__:
|
| 365 |
+
if len(unexpected_argv) > 0:
|
| 366 |
+
kwargs["_argv"] = unexpected_argv
|
| 367 |
+
if len(unexpected_kwargs) > 0:
|
| 368 |
+
kwargs["_kwargs"] = unexpected_kwargs
|
| 369 |
|
| 370 |
else:
|
| 371 |
if len(unexpected_argv) > 0:
|
|
|
|
| 392 |
f"Required field '{field.name}' of class {field.origin_cls} not set in {self.__class__.__name__}"
|
| 393 |
)
|
| 394 |
|
| 395 |
+
self.__pre_init__(**kwargs)
|
| 396 |
+
|
| 397 |
for field in fields(self):
|
| 398 |
if field.name in kwargs:
|
| 399 |
setattr(self, field.name, kwargs[field.name])
|
| 400 |
else:
|
|
|
|
|
|
|
| 401 |
setattr(self, field.name, get_field_default(field))
|
| 402 |
|
| 403 |
self.__post_init__()
|
|
|
|
| 406 |
def __is_dataclass__(self) -> bool:
|
| 407 |
return True
|
| 408 |
|
| 409 |
+
def __pre_init__(self, **kwargs):
|
| 410 |
+
"""
|
| 411 |
+
Pre initialization hook.
|
| 412 |
+
"""
|
| 413 |
+
pass
|
| 414 |
+
|
| 415 |
def __post_init__(self):
|
| 416 |
"""
|
| 417 |
Post initialization hook.
|
| 418 |
"""
|
| 419 |
pass
|
| 420 |
|
| 421 |
+
def _to_raw_dict(self):
|
| 422 |
+
"""
|
| 423 |
+
Convert to raw dict
|
| 424 |
+
"""
|
| 425 |
+
return {field.name: getattr(self, field.name) for field in fields(self)}
|
| 426 |
+
|
| 427 |
def to_dict(self):
|
| 428 |
"""
|
| 429 |
Convert to dict.
|
| 430 |
"""
|
| 431 |
+
return _asdict_inner(self._to_raw_dict())
|
| 432 |
|
| 433 |
def __repr__(self) -> str:
|
| 434 |
"""
|