feat(model): clean way to load on cpu
Changed files:
- src/dalle_mini/model/modeling.py +10 -2
- tools/train/train.py +0 -4
src/dalle_mini/model/modeling.py
CHANGED

```diff
@@ -300,6 +300,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
     - added num_params property
     - config_class replaced to DalleBartConfig
     - __init__ accepts abstract_init which does uses parameter shape to initialize the model
+    - init weights on CPU
     """
 
     config_class = DalleBartConfig
@@ -311,6 +312,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
         abstract_init: bool = False,
+        load_on_cpu: bool = True,
         **kwargs,
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
@@ -330,15 +332,21 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
         self.key = PRNGKey(seed)
         self.dtype = dtype
 
+        # init weights on CPU
+        if load_on_cpu:
+            init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
+        else:
+            init_fn = self.init_weights
+
         # randomly initialized parameters
         if abstract_init:
             # init the model weights only abstractly, eval_shape will return a pytree
             # with the structure as weights but without any actual values, this will just contain
             # the shape information. Weights need to be loaded later.
-            init_fn = partial(self.init_weights, input_shape=input_shape)
+            init_fn = partial(init_fn, input_shape=input_shape)
             random_params = jax.eval_shape(init_fn, self.key)
         else:
-            random_params = self.init_weights(self.key, input_shape)
+            random_params = init_fn(self.key, input_shape)
 
         # save required_params as set
         self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
```
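The hunks above add two independent switches: `load_on_cpu` jits the init function with `backend="cpu"` so the freshly created parameters live in host memory, and `abstract_init` traces that same function with `jax.eval_shape` so only shapes are produced. Below is a minimal, self-contained sketch of that pattern; the `init_weights` function, its shapes, and the plain booleans standing in for the constructor flags are illustrative assumptions, not the repo's API.

```python
from functools import partial

import jax
import jax.numpy as jnp


def init_weights(rng, input_shape):
    # Stand-in for the model's init_weights: a tiny parameter pytree whose
    # shapes depend on input_shape (kept static under jit below).
    k1, k2 = jax.random.split(rng)
    return {
        "dense": {
            "kernel": jax.random.normal(k1, (input_shape[-1], 16)),
            "bias": jnp.zeros((16,)),
        },
        "ln_scale": jax.random.normal(k2, (16,)),
    }


load_on_cpu = True      # mirrors the new constructor flag
abstract_init = False   # mirrors the existing constructor flag
input_shape = (1, 8)
key = jax.random.PRNGKey(0)

if load_on_cpu:
    # backend="cpu" pins the jitted init to the host, so the random params are
    # created in CPU memory even when a GPU/TPU is present. input_shape is a
    # static (hashable) argument because it is used to build array shapes.
    init_fn = jax.jit(init_weights, static_argnums=(1,), backend="cpu")
else:
    init_fn = init_weights

if abstract_init:
    # eval_shape only traces init_fn: it returns a pytree with the same
    # structure but ShapeDtypeStruct leaves, allocating no weight memory.
    init_fn = partial(init_fn, input_shape=input_shape)
    random_params = jax.eval_shape(init_fn, key)
else:
    random_params = init_fn(key, input_shape)

print(jax.tree_util.tree_map(lambda x: (x.shape, x.dtype), random_params))
```

Because `load_on_cpu` defaults to `True`, the parameters never land on the accelerator during construction, which is what makes the device-to-host copy removed from train.py below redundant. (Recent JAX versions deprecate jit's `backend=` argument in favour of `jax.default_device`, but the mechanism is the same.)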
tools/train/train.py
CHANGED

```diff
@@ -702,10 +702,6 @@ def main():
         )
         return state
 
-    # hack: move the inital params to CPU to free up device memory
-    # TODO: allow loading weights on CPU in pre-trained model
-    model.params = jax.tree_map(lambda x: np.asarray(x), model.params)
-
     with maps.mesh(mesh.devices, mesh.axis_names):
         state = pjit(
             init_state,
```