Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,994 Bytes
f214f36 dff74c4 f214f36 dff74c4 f214f36 dff74c4 f214f36 dff74c4 f214f36 dff74c4 adc8fc7 f214f36 dff74c4 444ccdb dff74c4 444ccdb dff74c4 444ccdb f214f36 444ccdb f214f36 dff74c4 f214f36 dff74c4 f214f36 dff74c4 f214f36 dff74c4 f214f36 dff74c4 f214f36 dff74c4 f214f36 dff74c4 f214f36 dff74c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from .Model import Model
import os
import signal
from functools import lru_cache
def handle_timeout(sig, frame):
    """Signal handler that aborts the interrupted code with a TimeoutError.

    Args:
        sig: Signal number delivered by the OS (not used).
        frame: Stack frame that was interrupted (not used).

    Raises:
        TimeoutError: unconditionally, with the message 'took too long'.
    """
    del sig, frame  # required by the signal-handler signature, otherwise unused
    raise TimeoutError('took too long')


# Registered at import time so that any later signal.alarm(n) in this module
# turns an overrun into a TimeoutError.
signal.signal(signal.SIGALRM, handle_timeout)
class Llama(Model):
    """Chat wrapper around a Hugging Face causal LM using a Llama-3 style chat template.

    The tokenizer is loaded eagerly in ``__init__``; the model weights are loaded
    lazily on first use by ``_load_model_if_needed``.
    """

    def __init__(self, config, device="cuda:0"):
        """Read generation/auth settings from ``config`` and load the tokenizer.

        Args:
            config: Nested mapping. This method reads
                ``config["params"]["max_output_tokens"]``,
                ``config["api_key_info"]["api_key_use"]`` and
                ``config["api_key_info"]["api_keys"]``; it is also passed to
                ``Model.__init__`` (presumably where ``self.name`` and
                ``self.messages`` are set -- confirm in ``Model``).
            device: Device that ``query`` moves the model to before generating.
        """
        super().__init__(config)
        self.device = device
        # NOTE(review): not used by query() below (which caps total length via
        # max_tokens); confirm whether it is consumed elsewhere.
        self.max_output_tokens = int(config["params"]["max_output_tokens"])
        api_pos = int(config["api_key_info"]["api_key_use"])
        # An empty/None configured key falls back to the HF_TOKEN environment variable.
        self.hf_token = config["api_key_info"]["api_keys"][api_pos] or os.getenv("HF_TOKEN")
        # ``token=`` replaces the deprecated ``use_auth_token=`` and matches the
        # keyword used when loading the model below.
        self.tokenizer = AutoTokenizer.from_pretrained(self.name, token=self.hf_token)
        # Delayed init: ``_model`` is the cache read by ``_load_model_if_needed``
        # (previously only ``model`` was set, so the first load raised
        # AttributeError). ``model`` is kept as a public alias for existing callers.
        self._model = None
        self.model = None
        # Stop generation at either the regular EOS token or Llama-3's
        # end-of-turn marker. A tokenizer without "<|eot_id|>" may map it to
        # None, which generate() cannot accept, so such entries are dropped.
        terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        self.terminators = [tok for tok in terminators if tok is not None]

    def _load_model_if_needed(self):
        """Load the model weights on the first call and return the cached model."""
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.name,
                torch_dtype=torch.bfloat16,
                token=self.hf_token,
                device_map="auto",  # or omit entirely to default to CPU
            )
            self.model = self._model  # keep the public alias in sync
        return self._model

    def _build_messages(self, msg):
        """Return a copy of ``self.messages`` with ``msg`` as the content of entry 1.

        Copying keeps the shared template in ``self.messages`` unchanged between
        calls. Assumes ``self.messages`` is a list of dicts with at least two
        entries, as the indexing implies -- confirm in ``Model``.
        """
        messages = [dict(message) for message in self.messages]
        messages[1]["content"] = msg
        return messages

    def _encode_prompt(self, msg):
        """Render ``msg`` through the chat template; return a batch-of-one id tensor.

        The tensor is left on the tokenizer's default device (CPU); callers move
        it if needed. No padding is requested: it is a no-op for one sequence and
        makes the tokenizer raise when no pad token is defined.
        """
        return self.tokenizer.apply_chat_template(
            self._build_messages(msg),
            add_generation_prompt=True,
            return_tensors="pt",
            truncation=True,
        )

    def query(self, msg, max_tokens=128000):
        """Generate a greedy reply to ``msg`` and return only the newly generated text.

        Args:
            msg: User message placed into the chat template.
            max_tokens: Passed to ``generate`` as ``max_length``, i.e. a cap on
                prompt + generated tokens combined, not on new tokens alone.

        Returns:
            The decoded completion with special tokens stripped, or the string
            ``"time out"`` if generation exceeds the 60-second SIGALRM deadline.
        """
        # Honour the device given to the constructor instead of a hard-coded "cuda".
        # NOTE(review): .to() can fail if device_map="auto" split the model
        # across several devices -- confirm single-device deployment.
        model = self._load_model_if_needed().to(self.device)
        input_ids = self._encode_prompt(msg).to(model.device)
        # Single unpadded sequence: every position is a real token.
        attention_mask = torch.ones_like(input_ids)
        # NOTE(review): Python runs signal handlers only in the main thread, so
        # if query() is called from a worker thread the TimeoutError will not
        # interrupt this generate() call -- confirm how this is invoked.
        try:
            signal.alarm(60)
            output_tokens = model.generate(
                input_ids,
                max_length=max_tokens,
                attention_mask=attention_mask,
                eos_token_id=self.terminators,
                do_sample=False,  # greedy decoding (top_k would be ignored)
            )
        except TimeoutError:
            print("time out")
            return "time out"
        finally:
            # Always disarm the alarm: if generate() raised some other error the
            # pending alarm would otherwise fire later in unrelated code.
            signal.alarm(0)
        prompt_len = input_ids.shape[-1]
        # generate() returns prompt + completion; decode only the completion.
        return self.tokenizer.decode(output_tokens[0][prompt_len:], skip_special_tokens=True)

    def get_prompt_length(self, msg):
        """Return how many tokens ``msg`` occupies after applying the chat template.

        Only the tokenizer is used; the model weights are not loaded.
        """
        return len(self._encode_prompt(msg)[0])

    def cut_context(self, msg, max_length):
        """Keep at most ``max_length`` tokens of ``msg`` and return them decoded as text.

        The budget is counted after ``encode(..., add_special_tokens=True)``, so
        any special tokens the tokenizer adds (such as a BOS token, if it adds
        one) use part of it; they are then removed by ``skip_special_tokens=True``.
        """
        tokens = self.tokenizer.encode(msg, add_special_tokens=True)
        return self.tokenizer.decode(tokens[:max_length], skip_special_tokens=True)
|