quantumiracle commited on
Commit
aa68c4e
·
1 Parent(s): ccfcf8d
Files changed (1) hide show
  1. llava/llava_agent.py +9 -3
llava/llava_agent.py CHANGED
@@ -22,10 +22,16 @@ class LLavaAgent:
22
  device_map = {'model': torch.device(self.device).index, 'lm_head': torch.device(self.device).index}
23
  else:
24
  device_map = 'auto'
25
- model_path = os.path.expanduser(model_path)
26
- model_name = get_model_name_from_path(model_path)
 
 
 
 
 
 
27
  tokenizer, model, image_processor, context_len = load_pretrained_model(
28
- model_path, None, model_name, device=self.device, device_map=device_map,
29
  load_8bit=load_8bit, load_4bit=load_4bit)
30
  self.model = model
31
  self.image_processor = image_processor
 
22
  device_map = {'model': torch.device(self.device).index, 'lm_head': torch.device(self.device).index}
23
  else:
24
  device_map = 'auto'
25
+
26
+ # Directly use HF repo if not local
27
+ if os.path.exists(model_path):
28
+ resolved_path = model_path
29
+ else:
30
+ resolved_path = model_path # treat as HF model ID
31
+
32
+ model_name = get_model_name_from_path(resolved_path)
33
  tokenizer, model, image_processor, context_len = load_pretrained_model(
34
+ resolved_path, None, model_name, device=self.device, device_map=device_map,
35
  load_8bit=load_8bit, load_4bit=load_4bit)
36
  self.model = model
37
  self.image_processor = image_processor