Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update llava/model/builder.py
Browse files- llava/model/builder.py +13 -12
    	
        llava/model/builder.py
    CHANGED
    
    | @@ -29,18 +29,19 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l | |
| 29 | 
             
                # if device != "cuda":
         | 
| 30 | 
             
                #     kwargs['device_map'] = {"": device}
         | 
| 31 |  | 
| 32 | 
            -
                 | 
| 33 | 
            -
                 | 
| 34 | 
            -
             | 
| 35 | 
            -
                 | 
| 36 | 
            -
             | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
            -
             | 
| 43 | 
            -
                 | 
|  | |
| 44 |  | 
| 45 | 
             
                if use_flash_attn:
         | 
| 46 | 
             
                    kwargs['attn_implementation'] = 'flash_attention_2'
         | 
|  | |
| 29 | 
             
                # if device != "cuda":
         | 
| 30 | 
             
                #     kwargs['device_map'] = {"": device}
         | 
| 31 |  | 
| 32 | 
            +
                load_8bit = True
         | 
| 33 | 
            +
                if load_8bit:
         | 
| 34 | 
            +
                    kwargs['load_in_8bit'] = True
         | 
| 35 | 
            +
                elif load_4bit:
         | 
| 36 | 
            +
                    kwargs['load_in_4bit'] = True
         | 
| 37 | 
            +
                    kwargs['quantization_config'] = BitsAndBytesConfig(
         | 
| 38 | 
            +
                        load_in_4bit=True,
         | 
| 39 | 
            +
                        bnb_4bit_compute_dtype=torch.float16,
         | 
| 40 | 
            +
                        bnb_4bit_use_double_quant=True,
         | 
| 41 | 
            +
                        bnb_4bit_quant_type='nf4'
         | 
| 42 | 
            +
                    )
         | 
| 43 | 
            +
                else:
         | 
| 44 | 
            +
                    kwargs['torch_dtype'] = torch.float16
         | 
| 45 |  | 
| 46 | 
             
                if use_flash_attn:
         | 
| 47 | 
             
                    kwargs['attn_implementation'] = 'flash_attention_2'
         |