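"""Llama chat-model wrapper: lazy weight loading plus a SIGALRM generation timeout."""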
import os
import signal

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from .Model import Model


def handle_timeout(sig, frame):
    raise TimeoutError("took too long")


# SIGALRM is Unix-only; it lets query() abort generation that runs too long.
signal.signal(signal.SIGALRM, handle_timeout)
class Llama(Model):
    def __init__(self, config, device="cuda:0"):
        super().__init__(config)
        self.device = device
        self.max_output_tokens = int(config["params"]["max_output_tokens"])
        api_pos = int(config["api_key_info"]["api_key_use"])
        # Fall back to the HF_TOKEN environment variable when no key is configured.
        self.hf_token = config["api_key_info"]["api_keys"][api_pos] or os.getenv("HF_TOKEN")
        # `token` replaces the deprecated `use_auth_token` argument.
        self.tokenizer = AutoTokenizer.from_pretrained(self.name, token=self.hf_token)
        self.model = None  # Loaded lazily; see _load_model_if_needed().
        # Stop on either the model's EOS token or Llama 3's end-of-turn token.
        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
    def _load_model_if_needed(self):
        """Load model weights on first use so tokenizer-only calls stay cheap."""
        if self.model is None:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.name,
                torch_dtype=torch.bfloat16,
                device_map=self.device,
                token=self.hf_token,
            )
        return self.model
    def query(self, msg, max_tokens=128000):
        model = self._load_model_if_needed()
        messages = self.messages
        messages[1]["content"] = msg
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(model.device)
        # A single unpadded sequence, so the attention mask is all ones.
        attention_mask = torch.ones(input_ids.shape, device=model.device)
        try:
            signal.alarm(60)  # Abort if generation takes longer than 60 seconds.
            output_tokens = model.generate(
                input_ids,
                max_length=max_tokens,  # Budget for prompt plus generated tokens.
                attention_mask=attention_mask,
                eos_token_id=self.terminators,
                top_k=50,  # No effect while do_sample=False (greedy decoding).
                do_sample=False,
            )
        except TimeoutError:
            print("time out")
            return "time out"
        finally:
            signal.alarm(0)  # Always cancel the pending alarm, even on failure.
        # Decode only the newly generated tokens, skipping the prompt.
        return self.tokenizer.decode(output_tokens[0][input_ids.shape[-1]:], skip_special_tokens=True)
    def get_prompt_length(self, msg):
        """Return the token length of the chat-templated prompt for msg."""
        model = self._load_model_if_needed()
        messages = self.messages
        messages[1]["content"] = msg
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(model.device)
        return len(input_ids[0])
    def cut_context(self, msg, max_length):
        tokens = self.tokenizer.encode(msg, add_special_tokens=True)
        truncated_tokens = tokens[:max_length]
        truncated_text = self.tokenizer.decode(truncated_tokens, skip_special_tokens=True)
        return truncated_text
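

# --- Usage sketch (illustrative; not part of the original file) ---
# The config layout below is inferred from the reads in __init__; the
# "model_info" key and how the Model base class populates self.name and
# self.messages are assumptions. Because of the relative import at the top,
# run this as a module (python -m <package>.<this_module>), not as a script.
if __name__ == "__main__":
    example_config = {  # hypothetical layout
        "model_info": {"name": "meta-llama/Meta-Llama-3-8B-Instruct"},
        "params": {"max_output_tokens": 256},
        "api_key_info": {"api_key_use": 0, "api_keys": [""]},  # empty -> HF_TOKEN env var
    }
    llm = Llama(example_config)
    print(llm.get_prompt_length("How long is this prompt?"))
    print(llm.query("Name three uses of a SIGALRM timeout."))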