	feat: handle gradient checkpointing
Files changed:
- src/dalle_mini/model/modeling.py (+2 -2)
- tools/train/train.py (+23 -1)
src/dalle_mini/model/modeling.py

@@ -144,7 +144,7 @@ class FlaxBartEncoderLayerCollection(FlaxBartEncoderLayerCollection):
 
     def setup(self):
         layer_module = (
-            nn.remat(FlaxBartEncoderLayer)
+            nn.remat(FlaxBartEncoderLayer, concrete=True)
             if self.config.gradient_checkpointing
             else FlaxBartEncoderLayer
         )
@@ -211,7 +211,7 @@ class FlaxBartDecoderLayerCollection(FlaxBartDecoderLayerCollection):
 
     def setup(self):
         layer_module = (
-            nn.remat(FlaxBartDecoderLayer)
+            nn.remat(FlaxBartDecoderLayer, concrete=True)
             if self.config.gradient_checkpointing
             else FlaxBartDecoderLayer
         )
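For context, here is a minimal, self-contained sketch of the toggle pattern this diff applies: wrap the layer class in nn.remat only when gradient checkpointing is enabled. Block and Stack are illustrative names, not from the repo, and the concrete=True flag (forwarded to jax.checkpoint at the time of this commit) has since been removed in newer JAX releases, so the sketch omits it.

import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.relu(nn.Dense(self.features)(x))


class Stack(nn.Module):
    features: int
    gradient_checkpointing: bool = False

    def setup(self):
        # Same pattern as the diff: pick the remat-wrapped class when
        # checkpointing is on, so activations are recomputed in the
        # backward pass instead of being stored.
        block_cls = nn.remat(Block) if self.gradient_checkpointing else Block
        self.layers = [block_cls(self.features) for _ in range(4)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# Usage: init and apply work identically with or without remat.
model = Stack(features=16, gradient_checkpointing=True)
params = model.init(jax.random.PRNGKey(0), jnp.ones((2, 16)))
y = model.apply(params, jnp.ones((2, 16)))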
    	
tools/train/train.py

@@ -18,6 +18,7 @@ Training DALL·E Mini.
 Script adapted from run_summarization_flax.py
 """
 
+import copy
 import io
 import logging
 import os
@@ -531,6 +532,8 @@ def main():
     # Set up our new model config
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
+        # initializing params with gradient checkpointing create issues
+        config.gradient_checkpointing = False
     else:
         config = None
 
@@ -553,8 +556,27 @@ def main():
         )
 
     # update model config per training args
+    # Done after initialization of weights to avoid issues with remat
+    # This is still considered correctly during training as function is pjitted
     model.config.gradient_checkpointing = training_args.gradient_checkpointing
 
+    # eval model cannot use remat
+    eval_config = copy.deepcopy(model.config)
+    eval_config.gradient_checkpointing = False
+
+    if training_args.gradient_checkpointing:
+        eval_model = DalleBart(
+            eval_config,
+            seed=training_args.seed_model,
+            dtype=getattr(jnp, model_args.dtype),
+            abstract_init=True,
+            load_on_cpu=True,
+        )
+        del eval_model._params
+        eval_fn = eval_model.__call__
+    else:
+        eval_fn = model.__call__
+
     # get model metadata
     model_metadata = model_args.get_metadata()
 
@@ -967,7 +989,7 @@ def main():
 
         def compute_eval_loss(batch):
             batch, labels = batch.pop("labels")
-            logits = model(**batch, params=state.params, train=False)[0]
+            logits = eval_fn(**batch, params=state.params, train=False)[0]
             return loss_fn(logits, labels)
 
         # calculate loss independently per dp_device

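The eval_fn split works because nn.remat is a lifted transform that leaves the parameter pytree unchanged: a checkpointed model and its remat-free twin share the same parameter structure, so the eval model can consume state.params directly and skip the backward-pass recomputation that only pays off during training. Below is a hedged sketch of that property in plain Flax; Tower is an illustrative name, and the DalleBart-specific constructor arguments from the diff (abstract_init, load_on_cpu, seed_model) are not modeled here.

import jax
import jax.numpy as jnp
import flax.linen as nn


class Tower(nn.Module):
    gradient_checkpointing: bool = False

    @nn.compact
    def __call__(self, x):
        # remat wraps the class but keeps its name, so parameter paths
        # (Dense_0, Dense_1, ...) are identical either way.
        dense_cls = nn.remat(nn.Dense) if self.gradient_checkpointing else nn.Dense
        for _ in range(3):
            x = nn.relu(dense_cls(16)(x))
        return x


x = jnp.ones((2, 16))
train_model = Tower(gradient_checkpointing=True)   # remat: lower backward memory
eval_model = Tower(gradient_checkpointing=False)   # no remat: no recompute overhead

params = train_model.init(jax.random.PRNGKey(0), x)
# Identical pytree structure, so eval reuses the training parameters,
# mirroring eval_fn(**batch, params=state.params, train=False) above.
out = eval_model.apply(params, x)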