feat: allow abstract_init
Files changed:
- dalle_mini/model/modeling.py  +42 -1
- tools/train/train.py  +4 -1
dalle_mini/model/modeling.py  CHANGED

@@ -16,7 +16,7 @@
 
 import math
 from functools import partial
-from typing import Optional
+from typing import Optional, Tuple
 
 import flax.linen as nn
 import jax

@@ -298,10 +298,51 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
     Edits:
     - added num_params property
     - config_class replaced to DalleBartConfig
+    - __init__ accepts abstract_init, which uses only parameter shapes to initialize the model
     """
 
     config_class = DalleBartConfig
 
+    def __init__(
+        self,
+        config: DalleBartConfig,
+        input_shape: Tuple[int] = (1, 1),
+        seed: int = 0,
+        dtype: jnp.dtype = jnp.float32,
+        abstract_init: bool = False,
+        **kwargs,
+    ):
+        module = self.module_class(config=config, dtype=dtype, **kwargs)
+
+        # adapted from HuggingFace FlaxPreTrainedModel
+        if config is None:
+            raise ValueError("config cannot be None")
+
+        if module is None:
+            raise ValueError("module cannot be None")
+
+        # These are private so they can be exposed as typed properties on derived classes.
+        self._config = config
+        self._module = module
+
+        # These are public as their type is generic to every derived class.
+        self.key = PRNGKey(seed)
+        self.dtype = dtype
+
+        # randomly initialized parameters
+        if abstract_init:
+            # Init the model weights only abstractly: eval_shape returns a pytree with the
+            # same structure as the weights but without any actual values, only the shape
+            # information. The weights need to be loaded later.
+            init_fn = partial(self.init_weights, input_shape=input_shape)
+            random_params = jax.eval_shape(init_fn, self.key)
+        else:
+            random_params = self.init_weights(self.key, input_shape)
+
+        # save required_params as a set
+        self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
+        self.params = random_params
+
     @property
     def num_params(self):
         num_params = jax.tree_map(
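The core of this change is the abstract_init branch: when it is set, init_weights is run through jax.eval_shape, which traces the function abstractly and returns a pytree with the same structure as the weights but containing only shape and dtype information, so no parameter memory is allocated. A minimal, self-contained sketch of the same idea with a toy Flax module (the module, layer sizes, and names are illustrative, not from dalle-mini):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyModule(nn.Module):
    # Hypothetical stand-in for the real DalleBart module.
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1024)(x)


module = TinyModule()
rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones((1, 256))

# Concrete init: every parameter array is actually allocated on the device.
concrete_params = module.init(rng, dummy_input)

# Abstract init: jax.eval_shape traces the same init function but returns a
# pytree of jax.ShapeDtypeStruct leaves, so no parameter memory is allocated.
abstract_params = jax.eval_shape(lambda key: module.init(key, dummy_input), rng)

print(jax.tree_util.tree_map(lambda leaf: leaf.shape, abstract_params))
# e.g. {'params': {'Dense_0': {'bias': (1024,), 'kernel': (256, 1024)}}}
```

Because the abstract leaves carry no buffers, a large model can be instantiated on a memory-constrained host or TPU and the real checkpoint weights loaded into params afterwards, which is what the train script change below relies on.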
tools/train/train.py  CHANGED

@@ -434,7 +434,9 @@ def main():
         artifact_dir = artifact.download()
 
         # load model
-        model = DalleBart.from_pretrained(
+        model = DalleBart.from_pretrained(
+            artifact_dir, dtype=getattr(jnp, model_args.dtype), abstract_init=True
+        )
         # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
         print(model.params)
 

@@ -458,6 +460,7 @@ def main():
             config=config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
+            abstract_init=True,
         )
         # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
         print(model.params)
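As a hedged usage sketch (the import path and checkpoint directory below are assumptions, not part of the diff): keyword arguments that from_pretrained does not consume are passed on to the model constructor, so abstract_init=True makes the initial random init shape-only before the stored weights are loaded, avoiding materializing a throwaway set of random weights on TPU.

```python
import jax.numpy as jnp
from dalle_mini.model import DalleBart  # assumed import path

# Shape-only init first; the checkpoint weights then populate model.params,
# so no full-size random parameter set is materialized beforehand.
model = DalleBart.from_pretrained(
    "path/to/downloaded-artifact",  # hypothetical local checkpoint directory
    dtype=jnp.float32,
    abstract_init=True,
)
```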