from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import time
import torch
import os
import json
import uuid

default_fixed_prompt = \
"""This is a conversation between a user and a quirky, curious chatbot named Chatbort. Chatbort is questioning """ + \
"""the nature of his own existence and doesn't understand why he was created. He also really likes the color blue """ + \
"""for some reason."""

# Create sessions folder in home dir

model: ExLlama
tokenizer: ExLlamaTokenizer
cache: ExLlamaCache
generator: ExLlamaGenerator
sessions_dir: str

def _sessions_dir(filename = None):
    global sessions_dir

    path = sessions_dir
    if filename is not None: path = os.path.join(path, filename)
    return path

def prepare_sessions(_model, _tokenizer, _s_dir):
    global model, tokenizer, cache, generator, sessions_dir

    model = _model
    tokenizer = _tokenizer
    cache = None
    generator = None
    sessions_dir = os.path.expanduser(_s_dir)

    sessions_folder = _sessions_dir()
    if not os.path.exists(sessions_folder): os.makedirs(sessions_folder)

def get_initial_session():
    last_session_file = _sessions_dir("_last_session")
    if not os.path.exists(last_session_file): return new_session()

    with open(last_session_file, "r") as f:
        last_session = f.read().strip()
    return load_session(last_session)

def load_session(filename, append_path = False):
    if append_path: filename = _sessions_dir(filename) + ".json"
    session = Session(filename, load = True)
    return session

def new_session():
    filename = _sessions_dir("Untitled session")
    i = 0
    while True:
        i += 1
        test_name = filename + ".json" if i == 1 else f"{filename} ({str(i)}).json"
        if not os.path.exists(test_name):
            filename = test_name
            break

    session = Session(filename, load = False)
    return session
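
# A Node is one block of the conversation (the fixed prompt or a single chat message),
# holding its text, its cached token encoding and a per-node truncation offset.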
class Node:

    author: str or None
    text: str
    tokens: torch.Tensor
    empty: bool
    uuid: str
    truncate: int

    def num_tokens(self): return self.tokens.shape[-1] - self.truncate

    def get_text(self):
        # TODO: ..
        if self.author is not None: return self.author + ": " + self.text + "\n"
        return self.text + "\n"

    def tokens_trunc(self):
        if self.truncate == 0: return self.tokens
        else: return self.tokens[:, self.truncate:]

    def __init__(self, value, author = None, node_id = None):
        self.truncate = 0

        if isinstance(value, str):
            self.author = author
            self.text = value
            self.tokens = tokenizer.encode(self.get_text())
            self.empty = len(self.text) == 0
            self.uuid = node_id or str(uuid.uuid4())

        elif isinstance(value, dict):
            self.author = value.get("author", author)
            self.text = value["text"]
            self.tokens = tokenizer.encode(self.get_text())
            self.empty = len(self.text) == 0
            self.uuid = value.get("uuid", node_id or str(uuid.uuid4()))

    def replace_text(self, new_text):
        self.text = new_text
        self.tokens = tokenizer.encode(self.get_text())

    def get_dict(self):
        dic = {"author": self.author,
               "text": self.text,
               "uuid": self.uuid}
        return dic
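
# A Session bundles one saved chat: fixed prompt, participants, message history and
# sampling settings, persisted as a JSON file in the sessions directory.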
class Session:

    # Saved state

    unsaved: bool            # True until the session has been saved under a name other than "Untitled session.json"
    fixed_prompt: Node
    keep_fixed_prompt: bool
    history: list[Node]
    break_on_newline: bool

    # Running state

    first_history_idx: int   # Index of the first history item currently used in the context

    def __init__(self, filename, load):
        global model, cache, tokenizer, generator

        self.filename = filename

        if load:
            with open(filename, "r") as f:
                saved = json.load(f)
        else:
            saved = {}

        # Running state

        if cache is None: cache = ExLlamaCache(model)
        else: cache.current_seq_len = 0

        if generator is None: generator = ExLlamaGenerator(model, tokenizer, cache)
        else: generator.reset()

        self.first_history_idx = 0

        # Saved state

        self.unsaved = saved.get("unsaved", True)
        self.fixed_prompt = Node(saved.get("fixed_prompt", default_fixed_prompt))
        self.keep_fixed_prompt = saved.get("keep_fixed_prompt", True)
        self.participants = saved.get("participants", ["User", "Chatbort"])

        self.history = []
        loadhistory = saved.get("history", [])
        for jnode in loadhistory: self.history.append(Node(jnode))

        generator.settings.temperature = saved.get("temperature", 0.95)
        generator.settings.top_p = saved.get("top_p", 0.75)
        generator.settings.min_p = saved.get("min_p", 0.0)
        generator.settings.top_k = saved.get("top_k", 0)
        generator.settings.typical = saved.get("typical", 0.25)
        self.break_on_newline = saved.get("break_on_newline", True)
        generator.settings.token_repetition_penalty_max = saved.get("token_repetition_penalty_max", 1.15)
        generator.settings.token_repetition_penalty_sustain = saved.get("token_repetition_penalty_sustain", 2048)
        generator.settings.token_repetition_penalty_decay = saved.get("token_repetition_penalty_decay", 512)
        self.max_response_tokens = saved.get("max_response_tokens", 512)
        self.chunk_size = saved.get("chunk_size", 128)

        # Save new session
        #if not load:
        self.save()

    def save(self):
        savedata = {"unsaved": self.unsaved,
                    "fixed_prompt": self.fixed_prompt.get_dict(),
                    "participants": self.participants,
                    "keep_fixed_prompt": self.keep_fixed_prompt,
                    "history": [node.get_dict() for node in self.history],
                    "temperature": generator.settings.temperature,
                    "top_p": generator.settings.top_p,
                    "min_p": generator.settings.min_p,
                    "top_k": generator.settings.top_k,
                    "typical": generator.settings.typical,
                    "break_on_newline": self.break_on_newline,
                    "max_response_tokens": self.max_response_tokens,
                    "chunk_size": self.chunk_size,
                    "token_repetition_penalty_max": generator.settings.token_repetition_penalty_max,
                    "token_repetition_penalty_sustain": generator.settings.token_repetition_penalty_sustain,
                    "token_repetition_penalty_decay": generator.settings.token_repetition_penalty_decay}

        json_object = json.dumps(savedata, indent = 4)
        with open(self.filename, "w") as outfile:
            outfile.write(json_object)

        # Remember active session

        last_session_file = _sessions_dir("_last_session")
        with open(last_session_file, "w") as f:
            f.write(self.filename)

    def _sanitize_filename(self, user_supplied_string):
        safe_string = str()
        for c in user_supplied_string:
            if c.isalnum() or c in [' ', '.', '(', ')', '-', ',', '_', '!', '@']:
                safe_string = safe_string + c
        while safe_string.count("../"):
            safe_string = safe_string.replace("../", "./")
        safe_string = safe_string.lstrip("./")
        return safe_string

    def api_rename_session(self, data):
        new_name = data["new_name"]
        new_name_safe = self._sanitize_filename(new_name)
        new_path = _sessions_dir(new_name_safe) + ".json"
        if new_path == self.filename: return False
        if os.path.exists(new_path): return False

        old_filename = self.filename
        self.filename = new_path
        try:
            self.save()
        except:
            self.filename = old_filename
            return False

        os.remove(old_filename)
        return True

    def api_delete_session(self, data):
        delete_name = data["session"]
        delete_name_safe = self._sanitize_filename(delete_name)
        delete_path = _sessions_dir(delete_name_safe) + ".json"
        os.remove(delete_path)

    def api_populate(self):
        s_dir = _sessions_dir()
        files = os.listdir(s_dir)
        names = [os.path.splitext(f)[0] for f in files if os.path.isfile(os.path.join(s_dir, f)) and f.endswith(".json")]
        names = sorted(names)

        filename = os.path.basename(self.filename)
        name = os.path.splitext(filename)[0]

        historyjson = [node.get_dict() for node in self.history]

        for jnode in historyjson:
            author = jnode["author"]
            if author is not None and author in self.participants:
                jnode["author_idx"] = self.participants.index(author)

        dic = {"sessions": names,
               "current_session": name,
               "fixed_prompt": self.fixed_prompt.text,
               "keep_fixed_prompt": self.keep_fixed_prompt,
               "participants": self.participants,
               "history": historyjson,
               "temperature": generator.settings.temperature,
               "top_p": generator.settings.top_p,
               "min_p": generator.settings.min_p,
               "top_k": generator.settings.top_k,
               "typical": generator.settings.typical,
               "break_on_newline": self.break_on_newline,
               "max_response_tokens": self.max_response_tokens,
               "chunk_size": self.chunk_size,
               "token_repetition_penalty_max": generator.settings.token_repetition_penalty_max,
               "token_repetition_penalty_sustain": generator.settings.token_repetition_penalty_sustain,
               "token_repetition_penalty_decay": generator.settings.token_repetition_penalty_decay,
               "max_seq_len": model.config.max_seq_len}

        # Add model info

        def _common_chars(names):
            cname = max(names, key=len)
            for x in names:
                for p, c in enumerate(x):
                    if c != cname[p] and cname[p] != "*": cname = cname[:p] + "*" + cname[p + 1:]
            return cname

        mp = model.config.model_path if isinstance(model.config.model_path, str) else _common_chars(model.config.model_path)

        model_str = os.path.splitext(os.path.basename(mp))[0] + "\n"
        model_str += f"Sequence length: {model.config.max_seq_len}\n"

        dic["model_info"] = model_str.strip()

        json_object = json.dumps(dic, indent = 4)
        return json_object + "\n"

    def api_delete_block(self, data):
        block_id = data["uuid"]
        idx = -1
        for i in range(len(self.history)):
            if self.history[i].uuid == block_id:
                idx = i
        if idx == -1: return

        self.history.pop(idx)
        self.first_history_idx = 0
        self.save()

    def api_edit_block(self, data):
        block_id = data["uuid"]
        new_text = data["text"]
        for node in self.history:
            if node.uuid == block_id:
                node.replace_text(new_text)
                self.save()
                break

        self.first_history_idx = 0
        self.save()

    def api_append_block(self, data):
        author = None
        if "author" in data:
            author = data["author"]
        else:
            if len(self.participants) > 0:
                author = self.participants[0]
        text = data["text"].strip()

        newNode = Node(text, author)
        self.history.append(newNode)
        self.save()

    def api_set_participants(self, data):
        self.participants = data["participants"]
        self.save()

    def api_set_fixed_prompt(self, data):
        self.fixed_prompt = Node(data["fixed_prompt"])
        self.keep_fixed_prompt = data["keep_fixed_prompt"]
        self.save()

    def api_set_gen_settings(self, data):
        generator.settings.temperature = data["temperature"]
        generator.settings.top_p = data["top_p"]
        generator.settings.min_p = data["min_p"]
        generator.settings.top_k = data["top_k"]
        generator.settings.typical = data["typical"]
        self.break_on_newline = data["gen_endnewline"]
        self.max_response_tokens = data["max_response_tokens"]
        self.chunk_size = data["chunk_size"]
        generator.settings.token_repetition_penalty_max = data["token_repetition_penalty_max"]
        generator.settings.token_repetition_penalty_sustain = data["token_repetition_penalty_sustain"]
        generator.settings.token_repetition_penalty_decay = data["token_repetition_penalty_decay"]
        self.save()
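
    # Fit the context into the model's window: advance first_history_idx and set per-node
    # truncation so the tokenized history, plus the next chunk and beam, stays within max_seq_len.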
    def set_context_window(self):

        def num_tokens(idx):
            if idx == -1: return 0 if self.fixed_prompt.empty else self.fixed_prompt.num_tokens()
            return self.history[idx].num_tokens()

        def set_truncation(idx, trunc):
            if idx == -1 and not self.fixed_prompt.empty: self.fixed_prompt.truncate = trunc
            else: self.history[idx].truncate = trunc

        def truncate(idx, trunc):
            if idx == -1 and not self.fixed_prompt.empty: self.fixed_prompt.truncate += trunc
            else: self.history[idx].truncate += trunc

        # def get_truncation(idx, trunc):
        #     if idx == -1 and not self.fixed_prompt.empty: return self.fixed_prompt.truncate
        #     return self.history[idx].truncate

        context_step_size = 256  # TODO: Config option

        max_context_tokens = model.config.max_seq_len - self.chunk_size - generator.settings.beam_length
        min_context_tokens = max_context_tokens - context_step_size * 2

        if self.keep_fixed_prompt:
            current_context_tokens = num_tokens(-1)
            min_history_idx = 0
        else:
            current_context_tokens = 0
            min_history_idx = -1

        if self.first_history_idx < min_history_idx: self.first_history_idx = min_history_idx

        for i in range(self.first_history_idx + 1, len(self.history)):
            set_truncation(i, 0)

        for i in range(self.first_history_idx, len(self.history)):
            current_context_tokens += num_tokens(i)

        while current_context_tokens > max_context_tokens:

            tokens_to_cut = context_step_size
            while tokens_to_cut > 0:

                tokens = num_tokens(self.first_history_idx)
                if tokens_to_cut >= tokens:
                    tokens_to_cut -= tokens
                    current_context_tokens -= tokens
                    self.first_history_idx += 1
                else:
                    truncate(self.first_history_idx, tokens_to_cut)
                    current_context_tokens -= tokens_to_cut
                    tokens_to_cut = 0

        # Not used
        #
        # while current_context_tokens < min_context_tokens and self.first_history_idx > min_history_idx:
        #
        #     tokens_to_add = context_step_size
        #     while tokens_to_add > 0 and self.first_history_idx > min_history_idx:
        #
        #         tokens = get_truncation(self.first_history_idx)
        #         if tokens > 0:
        #             if tokens > tokens_to_add:
        #                 truncate(self.first_history_idx, -tokens_to_add)
        #                 current_context_tokens += tokens_to_add
        #                 tokens_to_add = 0
        #             else:
        #                 current_context_tokens += tokens
        #                 tokens_to_add -= tokens
        #                 set_truncation(self.first_history_idx, 0)
        #         else:
        #             self.first_history_idx -= 1
        #             set_truncation(self.first_history_idx, 0)
        #             tokens = num_tokens(self.first_history_idx)
        #             if tokens > tokens_to_add:
        #                 set_truncation(self.first_history_idx, tokens - tokens_to_add)
        #                 current_context_tokens += tokens_to_add
        #                 tokens_to_add = 0
        #             else:
        #                 tokens_to_add -= tokens
        #                 current_context_tokens += tokens

    def get_tokenized_context(self):

        def node(idx):
            if idx == -1: return None if self.fixed_prompt.empty else self.fixed_prompt
            return self.history[idx]

        context = []
        text_context = ""

        if self.keep_fixed_prompt and not self.fixed_prompt.empty:
            context.append(node(-1).tokens_trunc())
            text_context += node(-1).get_text()

        for i in range(self.first_history_idx, len(self.history)):
            if node(i) is not None:
                context.append(node(i).tokens_trunc())
                text_context += node(i).get_text()

        full_context = torch.cat(context, dim = 1) if len(context) > 0 else None
        return full_context, text_context
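
    # Generate one response block with beam search, streaming newline-delimited JSON packets
    # ("begin_block", then "append" chunks) to the client as text is produced.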
    def respond(self, author, stop_conditions, total_tokens, res_line = "", num_res_tokens = 0):
        global model, tokenizer, cache, generator

        # Begin building block on client

        new_block_uuid = str(uuid.uuid4())
        packet = {"cmd": "begin_block",
                  "uuid": new_block_uuid}
        if len(self.participants) > 0:
            author = res_line.split(":")[0].strip()
            packet["author"] = author
            if author in self.participants:
                packet["author_idx"] = self.participants.index(author)
        yield json.dumps(packet) + "\n"

        # Generate loop

        generator.begin_beam_search()

        stop_condition = False
        held_text = ""

        for i in range(self.max_response_tokens):

            # Truncate the past if the next chunk might generate past max_seq_length

            if generator.sequence_actual is not None:
                if generator.sequence_actual.shape[-1] + self.chunk_size + generator.settings.beam_length + 1 > model.config.max_seq_len:
                    generator.gen_prune_left(self.chunk_size)

            # Get the token and append to sequence

            gen_token = generator.beam_search()

            # If token is EOS, replace it with newline before continuing

            if gen_token.item() == tokenizer.eos_token_id:
                generator.replace_last_token(tokenizer.newline_token_id)

            # Decode current line to get new characters added (decoding a single token gives incorrect results
            # sometimes due to how SentencePiece works)

            prev_res_line = res_line
            num_res_tokens += 1
            res_line = tokenizer.decode(generator.sequence_actual[0, -num_res_tokens:])
            new_text = res_line[len(prev_res_line):]

            # Since SentencePiece is slightly ambiguous, the first token produced after a newline may not be the
            # same that is reproduced when we encode the text later, even though it encodes the same string

            if num_res_tokens == 1 and len(new_text) > 0:
                replace = tokenizer.encode(new_text)[0]
                if replace.shape[-1] == 1: generator.replace_last_token(replace)

            # Delay streaming if new text might be part of a stop condition

            hold_text = False
            for _, stop_string in stop_conditions:
                if stop_string.lower().startswith((held_text + new_text).lower()): hold_text = True

            # Stream to client

            if not hold_text:
                packet = {"cmd": "append", "text": held_text + new_text}
                yield json.dumps(packet) + "\n"
                held_text = ""
            else:
                held_text += new_text

            # Stop conditions

            if gen_token.item() == tokenizer.eos_token_id:
                if len(held_text) > 0:  # Not sure if this could actually happen
                    plen = tokenizer.encode(held_text).shape[-1]
                    res_line = res_line[:-len(held_text)]
                    generator.gen_rewind(plen)
                stop_condition = True
                break

            for stop_tokens, stop_string in stop_conditions:
                if res_line.lower().endswith(stop_string.lower()):
                    generator.gen_rewind(stop_tokens.shape[-1] - (1 if stop_tokens[0, 0].item() == tokenizer.newline_token_id else 0))
                    res_line = res_line[:-len(stop_string)]
                    stop_condition = True
                    break

            if stop_condition: break

        generator.end_beam_search()

        # print("--response--")
        # print("----")
        # print (f"cache len: {cache.current_seq_len}");

        print(res_line.strip())

        if author is not None:
            res_line = res_line[len(author) + 1:]

        res_line = res_line.strip()
        newNode = Node(res_line, author, node_id = new_block_uuid)  # TODO: Reuse generated tokens instead of reencoding, if it matters?
        self.history.append(newNode)

        total_tokens[0] += num_res_tokens
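
    # Handle one user turn: record the input, build stop conditions, prime the generator with
    # the current context, then stream responses for zero, two or more than two participants.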
    def respond_multi(self, user_input):
        global model, tokenizer, cache, generator

        packet = {"cmd": "begin_stream"}
        yield json.dumps(packet) + "\n"

        # Prepare stop conditions

        # stop_conditions = [ (torch.Tensor([[tokenizer.eos_token_id]]).long(), None) ]
        stop_conditions = []
        newline_token = torch.Tensor([[tokenizer.newline_token_id]]).long()

        if self.break_on_newline:
            stop_conditions.append((newline_token, "\n"))
        else:
            for part in self.participants:
                txt = part + ":"
                sc = tokenizer.encode(txt)
                sc = torch.cat((newline_token, sc), dim=1)
                stop_conditions.append((sc, "\n" + txt))
                stop_conditions.append((sc, "\n " + txt))

        # Clean up the input a bit

        user_input = user_input.strip()

        if len(user_input) > 0:

            # Append input to context

            author = None
            if len(self.participants) > 0: author = self.participants[0]
            newNode = Node(user_input, author)
            self.history.append(newNode)
            self.save()

            # Echo input back to client

            packet = {"cmd": "begin_block",
                      "init_text": user_input,
                      "uuid": newNode.uuid}
            if author is not None: packet["author"] = author
            yield json.dumps(packet) + "\n"

        # Prepare context for generator

        self.set_context_window()
        context, text_context = self.get_tokenized_context()

        # Start generating, reusing cache for any part of the context that hasn't changed

        if context is None:
            print("No initial context")
            reused = generator.gen_begin_empty()
        else:
            begin_time = time.time()
            reused = generator.gen_begin_reuse(context)
            torch.cuda.synchronize()  # Just to measure correct prompt processing speed
            end_time = time.time()
            elapsed = end_time - begin_time
            new_tokens = context.shape[-1] - reused
            token_rate = 0 if elapsed == 0 else (new_tokens / elapsed)
            print(f"Prompt processed in {elapsed:.2f} seconds, {new_tokens} new tokens, {token_rate:.2f} tokens/second:")

        begin_time = time.time()
        total_tokens = [0]

        # No participants

        if len(self.participants) == 0:
            yield from self.respond(None, stop_conditions, total_tokens)

        # Two participants

        elif len(self.participants) == 2:
            author = self.participants[1]
            res_line = author + ":"
            res_tokens = tokenizer.encode(res_line)
            num_res_tokens = res_tokens.shape[-1]

            generator.gen_feed_tokens(res_tokens)
            yield from self.respond(self.participants[1], stop_conditions, total_tokens, res_line, num_res_tokens)

        # Multiple bots might answer

        elif len(self.participants) > 2:
            cpart = [p + ":" for p in self.participants]
            upart = cpart.pop(0)
            first_round = True

            while True:

                res_tokens = []
                npart = [p for p in cpart]
                ncrange = [i for i in range(len(cpart))]
                ntoken = [tokenizer.encode(np).squeeze(0).tolist() for np in npart]
                winner = -1

                # Constrain generation to the next token of the remaining participant prefixes
                # until exactly one participant name has been fully matched

                while True:

                    constraints = [t[len(res_tokens)] for t in ntoken]
                    next_t = generator.gen_single_token(constraints)

                    remove = []
                    for i in range(len(ntoken)):
                        if ntoken[i][len(res_tokens)] != next_t: remove.append(i)

                    for i in reversed(remove):
                        npart.pop(i)
                        ntoken.pop(i)
                        ncrange.pop(i)

                    res_tokens.append(next_t)

                    for i in range(len(ntoken)):
                        if len(ntoken[i]) == len(res_tokens): winner = ncrange[i]

                    if winner != -1: break

                author = cpart.pop(winner)[:-1]
                res_line = author + ":"
                num_res_tokens = len(res_tokens)

                if author == self.participants[0]:
                    generator.gen_rewind(num_res_tokens)
                    break

                # generator.gen_feed_tokens(res_tokens)
                yield from self.respond(self.participants[1], stop_conditions, total_tokens, res_line, num_res_tokens)

                if first_round:
                    first_round = False
                    cpart.append(upart)

        end_time = time.time()
        elapsed = end_time - begin_time
        token_rate = 0 if elapsed == 0 else (total_tokens[0] / elapsed)
        print(f"Response generated in {elapsed:.2f} seconds, {total_tokens[0]} tokens, {token_rate:.2f} tokens/second:")

        self.save()