SecureLLMSys committed on
Commit
adc8fc7
·
1 Parent(s): dff74c4
Files changed (1) hide show
  1. src/models/Llama.py +4 -4
src/models/Llama.py CHANGED
@@ -17,21 +17,21 @@ class Llama(Model):
17
  api_pos = int(config["api_key_info"]["api_key_use"])
18
  self.hf_token = config["api_key_info"]["api_keys"][api_pos] or os.getenv("HF_TOKEN")
19
  self.tokenizer = AutoTokenizer.from_pretrained(self.name, use_auth_token=self.hf_token)
20
- self._model = None # Delayed init
21
  self.terminators = [
22
  self.tokenizer.eos_token_id,
23
  self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
24
  ]
25
 
26
  def _load_model_if_needed(self):
27
- if self._model is None:
28
- self._model = AutoModelForCausalLM.from_pretrained(
29
  self.name,
30
  torch_dtype=torch.bfloat16,
31
  device_map=self.device,
32
  token=self.hf_token
33
  )
34
- return self._model
35
 
36
  def query(self, msg, max_tokens=128000):
37
  model = self._load_model_if_needed()
 
17
  api_pos = int(config["api_key_info"]["api_key_use"])
18
  self.hf_token = config["api_key_info"]["api_keys"][api_pos] or os.getenv("HF_TOKEN")
19
  self.tokenizer = AutoTokenizer.from_pretrained(self.name, use_auth_token=self.hf_token)
20
+ self.model = None # Delayed init
21
  self.terminators = [
22
  self.tokenizer.eos_token_id,
23
  self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
24
  ]
25
 
26
  def _load_model_if_needed(self):
27
+ if self.model is None:
28
+ self.model = AutoModelForCausalLM.from_pretrained(
29
  self.name,
30
  torch_dtype=torch.bfloat16,
31
  device_map=self.device,
32
  token=self.hf_token
33
  )
34
+ return self.model
35
 
36
  def query(self, msg, max_tokens=128000):
37
  model = self._load_model_if_needed()