SecureLLMSys committed
Commit dff74c4 · 1 Parent(s): cb1ecf3
Files changed (1)
  1. src/models/Llama.py +38 -33
src/models/Llama.py CHANGED
@@ -1,37 +1,40 @@
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from .Model import Model
 import os
 import signal
+from functools import lru_cache

 def handle_timeout(sig, frame):
     raise TimeoutError('took too long')
 signal.signal(signal.SIGALRM, handle_timeout)

 class Llama(Model):
-    def __init__(self, config, device = "cuda:0"):
+    def __init__(self, config, device="cuda:0"):
         super().__init__(config)
+        self.device = device
         self.max_output_tokens = int(config["params"]["max_output_tokens"])
-
         api_pos = int(config["api_key_info"]["api_key_use"])
-        hf_token = config["api_key_info"]["api_keys"][api_pos]
-        if hf_token is None or len(hf_token) == 0:
-            hf_token = os.getenv("HF_TOKEN")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, use_auth_token=hf_token)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            self.name,
-            torch_dtype=torch.bfloat16,
-            device_map=device,
-            token=hf_token
-        )
+        self.hf_token = config["api_key_info"]["api_keys"][api_pos] or os.getenv("HF_TOKEN")
+        self.tokenizer = AutoTokenizer.from_pretrained(self.name, use_auth_token=self.hf_token)
+        self._model = None  # Delayed init
         self.terminators = [
             self.tokenizer.eos_token_id,
             self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
         ]
-        torch.set_default_tensor_type(torch.cuda.HalfTensor)
+
+    def _load_model_if_needed(self):
+        if self._model is None:
+            self._model = AutoModelForCausalLM.from_pretrained(
+                self.name,
+                torch_dtype=torch.bfloat16,
+                device_map=self.device,
+                token=self.hf_token
+            )
+        return self._model

     def query(self, msg, max_tokens=128000):
+        model = self._load_model_if_needed()
         messages = self.messages
         messages[1]["content"] = msg

@@ -39,12 +42,15 @@ class Llama(Model):
             messages,
             add_generation_prompt=True,
             return_tensors="pt",
-        ).to(self.model.device)
-        attention_mask = torch.ones(input_ids.shape, device=self.model.device)
+            padding=True,
+            truncation=True
+        ).to(model.device)
+
+        attention_mask = torch.ones(input_ids.shape, device=model.device)
+
         try:
             signal.alarm(60)
-
-            output_tokens = self.model.generate(
+            output_tokens = model.generate(
                 input_ids,
                 max_length=max_tokens,
                 attention_mask=attention_mask,
@@ -53,28 +59,27 @@ class Llama(Model):
                 do_sample=False
             )
             signal.alarm(0)
-        except TimeoutError as exc:
+        except TimeoutError:
             print("time out")
-            return("time out")
-        # Decode the generated tokens back to text
-        result = self.tokenizer.decode(output_tokens[0][input_ids.shape[-1]:], skip_special_tokens=True)
-        return result
+            return "time out"
+
+        return self.tokenizer.decode(output_tokens[0][input_ids.shape[-1]:], skip_special_tokens=True)

-    def get_prompt_length(self,msg):
+    def get_prompt_length(self, msg):
+        model = self._load_model_if_needed()
         messages = self.messages
         messages[1]["content"] = msg
         input_ids = self.tokenizer.apply_chat_template(
             messages,
             add_generation_prompt=True,
-            return_tensors="pt"
-        ).to(self.model.device)
+            return_tensors="pt",
+            padding=True,
+            truncation=True
+        ).to(model.device)
         return len(input_ids[0])

-    def cut_context(self,msg,max_length):
-        tokens = self.tokenizer.encode(msg, add_special_tokens=True)
-
-        # Truncate the tokens to a maximum length
+    def cut_context(self, msg, max_length):
+        tokens = self.tokenizer.encode(msg, add_special_tokens=True)
         truncated_tokens = tokens[:max_length]
-
-        # Decode the truncated tokens back to text
         truncated_text = self.tokenizer.decode(truncated_tokens, skip_special_tokens=True)
-        return truncated_text
+        return truncated_text
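
The diff replaces the eager AutoModelForCausalLM.from_pretrained call in __init__ with a cached, on-demand load in _load_model_if_needed, so constructing a Llama object only builds the tokenizer and the weights are pulled the first time query or get_prompt_length runs. A minimal, self-contained sketch of that delayed-init pattern follows; LazyLlama and the placeholder string it "loads" are hypothetical stand-ins, not code from this commit.

# Hypothetical sketch of the delayed-init pattern behind _load_model_if_needed
# (stand-in names; the real class loads HF weights instead of a placeholder string).
class LazyLlama:
    def __init__(self, name):
        self.name = name
        self._model = None  # nothing heavy is loaded at construction time

    def _load_model_if_needed(self):
        # The first call pays the load cost; later calls reuse the cached object.
        if self._model is None:
            self._model = f"<weights for {self.name}>"  # stands in for from_pretrained(...)
        return self._model

    def query(self, msg):
        model = self._load_model_if_needed()
        return f"{model} -> {msg}"

llm = LazyLlama("some-org/llama-model")
print(llm.query("hello"))  # weights are materialized on this first call
print(llm.query("again"))  # the cached model is reused

The trade-off is that loading failures (a bad HF token, missing weights) now surface on the first query rather than at construction time; the newly added from functools import lru_cache is not referenced in the hunks shown.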