SecureLLMSys committed on
Commit
444ccdb
·
1 Parent(s): 3a7a5c6
src/attribution/attntrace.py CHANGED
@@ -41,7 +41,7 @@ class AttnTraceAttribution(Attribution):
41
  if self.llm.model!=None:
42
  self.model = self.llm.model
43
  else:
44
- self.model = self.llm._load_model_if_needed()
45
  self.layers = range(len(self.model.model.layers))
46
  model = self.model
47
  tokenizer = self.tokenizer
 
41
  if self.llm.model!=None:
42
  self.model = self.llm.model
43
  else:
44
+ self.model = self.llm._load_model_if_needed().to("cuda")
45
  self.layers = range(len(self.model.model.layers))
46
  model = self.model
47
  tokenizer = self.tokenizer
src/models/Llama.py CHANGED
@@ -24,17 +24,18 @@ class Llama(Model):
24
  ]
25
 
26
  def _load_model_if_needed(self):
27
- if self.model is None:
28
- self.model = AutoModelForCausalLM.from_pretrained(
29
  self.name,
30
  torch_dtype=torch.bfloat16,
31
- device_map=self.device,
32
- token=self.hf_token
33
  )
34
- return self.model
 
35
 
36
  def query(self, msg, max_tokens=128000):
37
- model = self._load_model_if_needed()
38
  messages = self.messages
39
  messages[1]["content"] = msg
40
 
 
24
  ]
25
 
26
  def _load_model_if_needed(self):
27
+ if self._model is None:
28
+ model = AutoModelForCausalLM.from_pretrained(
29
  self.name,
30
  torch_dtype=torch.bfloat16,
31
+ token=self.hf_token,
32
+ device_map="auto", # or omit entirely to default to CPU
33
  )
34
+ self._model = model
35
+ return self._model
36
 
37
  def query(self, msg, max_tokens=128000):
38
+ model = self._load_model_if_needed().to("cuda")
39
  messages = self.messages
40
  messages[1]["content"] = msg
41