	Merge pull request #21 from borisdayma/feat-no_decay
Files changed:
- seq2seq/run_seq2seq_flax.py  +7 -1
- seq2seq/sweep.yaml  +1 -0
seq2seq/run_seq2seq_flax.py  CHANGED

@@ -162,6 +162,9 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
+    no_decay: bool = field(
+        default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
+    )
     max_target_length: Optional[int] = field(
         default=OUTPUT_LENGTH,
         metadata={
@@ -338,12 +341,14 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 
 def create_learning_rate_fn(
-    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+    if no_decay:
+        return warmup_fn
     decay_fn = optax.linear_schedule(
         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
     )
@@ -616,6 +621,7 @@ def main():
         training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
+        data_args.no_decay
     )
 
     # We use Optax's "masking" functionality to not apply weight decay
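For context, a minimal sketch of how the full schedule reads after this change. The tail of the function is not part of the diff, so the optax.join_schedules call below is an assumption based on the standard Flax seq2seq training script; only the signature, the warmup/decay schedules, and the new early return are taken from the PR.

from typing import Callable

import jax.numpy as jnp
import optax


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
) -> Callable[[int], jnp.array]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    # Linear ramp from 0 to learning_rate; optax.linear_schedule clamps at end_value
    # once transition_steps have elapsed, so on its own it is warmup then a constant rate.
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    if no_decay:
        # --no_decay: keep the learning rate flat after warmup.
        return warmup_fn
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    # Assumed unchanged tail: switch from the warmup piece to the decay piece at num_warmup_steps.
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])

In other words, with --no_decay the returned callable ramps linearly to learning_rate over num_warmup_steps and then holds it constant, instead of decaying linearly to 0 by the last training step.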
    	
seq2seq/sweep.yaml  CHANGED

@@ -37,6 +37,7 @@ command:
   - 56
   - "--preprocessing_num_workers"
   - 80
+  - "--no_decay"
   - "--do_train"
   - "--do_eval"
   - ${args}
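The sweep only has to pass --no_decay as a bare flag. A small, hypothetical check of why that is enough, assuming run_seq2seq_flax.py parses DataTrainingArguments with transformers.HfArgumentParser as the other Flax example scripts do (the dataclass below just mirrors the new field from the diff; the parser call itself is not shown in this PR):

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    no_decay: bool = field(
        default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
    )


# HfArgumentParser exposes a bool field with default=False as a store_true-style
# flag, so the bare "--no_decay" entry from sweep.yaml flips it to True.
(data_args,) = HfArgumentParser(DataTrainingArguments).parse_args_into_dataclasses(["--no_decay"])
print(data_args.no_decay)  # True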