SecureLLMSys commited on
Commit
1ca09f6
·
1 Parent(s): 444ccdb
Files changed (1) hide show
  1. src/models/Llama.py +3 -3
src/models/Llama.py CHANGED
@@ -24,15 +24,15 @@ class Llama(Model):
24
  ]
25
 
26
  def _load_model_if_needed(self):
27
- if self._model is None:
28
  model = AutoModelForCausalLM.from_pretrained(
29
  self.name,
30
  torch_dtype=torch.bfloat16,
31
  token=self.hf_token,
32
  device_map="auto", # or omit entirely to default to CPU
33
  )
34
- self._model = model
35
- return self._model
36
 
37
  def query(self, msg, max_tokens=128000):
38
  model = self._load_model_if_needed().to("cuda")
 
24
  ]
25
 
26
  def _load_model_if_needed(self):
27
+ if self.model is None:
28
  model = AutoModelForCausalLM.from_pretrained(
29
  self.name,
30
  torch_dtype=torch.bfloat16,
31
  token=self.hf_token,
32
  device_map="auto", # or omit entirely to default to CPU
33
  )
34
+ self.model = model
35
+ return self.model
36
 
37
  def query(self, msg, max_tokens=128000):
38
  model = self._load_model_if_needed().to("cuda")