lym0302 committed
Commit 77dc150 · Parent(s): d3e19f4

load_8bit=True

third_party/VideoLLaMA2/videollama2/model/__init__.py CHANGED
@@ -52,7 +52,7 @@ VLLMConfigs = {
 
 
 
-def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=True, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
     print("00000000000000000000000000: ", device, use_flash_attn)
     if 'token' in kwargs:
         token = kwargs['token']
@@ -76,8 +76,8 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
             bnb_4bit_quant_type='nf4'
         )
     else:
-        # kwargs['torch_dtype'] = torch.float16
-        kwargs['torch_dtype'] = torch.bfloat16
+        kwargs['torch_dtype'] = torch.float16
+        # kwargs['torch_dtype'] = torch.bfloat16
 
     if use_flash_attn:
         kwargs['attn_implementation'] = 'flash_attention_2'
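
For context, a minimal sketch of how these flags are typically consumed when loading a model through transformers. The diff only shows a fragment of the function, so the load_in_8bit wiring, the function name, and the AutoModelForCausalLM.from_pretrained call below are assumptions modeled on the visible bnb_4bit_quant_type='nf4' branch, not the file's actual code:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Hypothetical condensation of the branching visible in the diff; everything
# except the keyword arguments shown in the hunks is illustrative.
def load_quantized_sketch(model_path, load_8bit=True, load_4bit=False,
                          device_map="auto", use_flash_attn=False, **kwargs):
    if load_8bit:
        # With the new default, every caller takes this path unless it
        # explicitly passes load_8bit=False.
        kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
    elif load_4bit:
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
        )
    else:
        # The commit flips this unquantized path from bfloat16 to float16.
        kwargs['torch_dtype'] = torch.float16
    if use_flash_attn:
        kwargs['attn_implementation'] = 'flash_attention_2'
    return AutoModelForCausalLM.from_pretrained(model_path, device_map=device_map, **kwargs)

Note that when load_8bit or load_4bit is set, bitsandbytes controls the weight dtype and the torch_dtype branch is never reached, so the float16/bfloat16 swap only affects unquantized loads; float16 is presumably preferred here because bfloat16 is unsupported on pre-Ampere GPUs.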