Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Merge branch 'main' of https://github.com/borisdayma/dalle-mini into main
Browse files
    	
        tools/train/distributed_shampoo.py
    CHANGED
    
    | 
         @@ -1,3 +1,5 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 1 | 
         
             
            # coding=utf-8
         
     | 
| 2 | 
         
             
            # Copyright 2022 The Google Research Authors.
         
     | 
| 3 | 
         
             
            #
         
     | 
| 
         @@ -235,7 +237,7 @@ class GraftingType(enum.IntEnum): 
     | 
|
| 235 | 
         
             
                RMSPROP = 3
         
     | 
| 236 | 
         
             
                RMSPROP_NORMALIZED = 4
         
     | 
| 237 | 
         
             
                SQRT_N = 5
         
     | 
| 238 | 
         
            -
                ADAGRAD_NORMALIZED =  
     | 
| 239 | 
         | 
| 240 | 
         | 
| 241 | 
         
             
            def power_iteration(
         
     | 
| 
         | 
|
| 1 | 
         
            +
            # file from: https://github.com/google-research/google-research/blob/master/scalable_shampoo/optax/distributed_shampoo.py
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
             
            # coding=utf-8
         
     | 
| 4 | 
         
             
            # Copyright 2022 The Google Research Authors.
         
     | 
| 5 | 
         
             
            #
         
     | 
| 
         | 
|
| 237 | 
         
             
                RMSPROP = 3
         
     | 
| 238 | 
         
             
                RMSPROP_NORMALIZED = 4
         
     | 
| 239 | 
         
             
                SQRT_N = 5
         
     | 
| 240 | 
         
            +
                ADAGRAD_NORMALIZED = 6
         
     | 
| 241 | 
         | 
| 242 | 
         | 
| 243 | 
         
             
            def power_iteration(
         
     |