Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	
		Chongruo Wu
		
	commited on
		
		
					Commit 
							
							·
						
						f1020dc
	
1
								Parent(s):
							
							c899f8b
								
Fix a bug related to displaying ce_loss
Browse filesFormer-commit-id: 2fd1438861a0ef29d82b049836c62160402e8bb7
- model/LISA.py +1 -2
    	
        model/LISA.py
    CHANGED
    
    | @@ -306,7 +306,6 @@ class LISAForCausalLM(LlavaLlamaForCausalLM): | |
| 306 |  | 
| 307 | 
             
                    ce_loss = model_output.loss
         | 
| 308 | 
             
                    ce_loss = ce_loss * self.ce_loss_weight
         | 
| 309 | 
            -
                    loss = ce_loss
         | 
| 310 | 
             
                    mask_bce_loss = 0
         | 
| 311 | 
             
                    mask_dice_loss = 0
         | 
| 312 | 
             
                    num_masks = 0
         | 
| @@ -333,7 +332,7 @@ class LISAForCausalLM(LlavaLlamaForCausalLM): | |
| 333 | 
             
                    mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
         | 
| 334 | 
             
                    mask_loss = mask_bce_loss + mask_dice_loss
         | 
| 335 |  | 
| 336 | 
            -
                    loss  | 
| 337 |  | 
| 338 | 
             
                    return {
         | 
| 339 | 
             
                        "loss": loss,
         | 
|  | |
| 306 |  | 
| 307 | 
             
                    ce_loss = model_output.loss
         | 
| 308 | 
             
                    ce_loss = ce_loss * self.ce_loss_weight
         | 
|  | |
| 309 | 
             
                    mask_bce_loss = 0
         | 
| 310 | 
             
                    mask_dice_loss = 0
         | 
| 311 | 
             
                    num_masks = 0
         | 
|  | |
| 332 | 
             
                    mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
         | 
| 333 | 
             
                    mask_loss = mask_bce_loss + mask_dice_loss
         | 
| 334 |  | 
| 335 | 
            +
                    loss = ce_loss + mask_loss
         | 
| 336 |  | 
| 337 | 
             
                    return {
         | 
| 338 | 
             
                        "loss": loss,
         |