from model import ExLlama, ExLlamaCache, ExLlamaConfig from tokenizer import ExLlamaTokenizer from generator import ExLlamaGenerator import time import torch import os import json import uuid default_fixed_prompt = \ """This is a conversation between a user and a quirky, curious chatbot named Chatbort. Chatbort is questioning """ + \ """the nature of his own existence and doesn't understand why he was created. He also really likes the color blue """ + \ """for some reason.""" # Create sessions folder in home dir model: ExLlama tokenizer: ExLlamaTokenizer cache: ExLlamaCache generator: ExLlamaGenerator sessions_dir: str def _sessions_dir(filename = None): global sessions_dir path = sessions_dir if filename is not None: path = os.path.join(path, filename) return path def prepare_sessions(_model, _tokenizer, _s_dir): global model, tokenizer, cache, generator, sessions_dir model = _model tokenizer = _tokenizer cache = None generator = None sessions_dir = os.path.expanduser(_s_dir) sessions_folder = _sessions_dir() if not os.path.exists(sessions_folder): os.makedirs(sessions_folder) def get_initial_session(): last_session_file = _sessions_dir("_last_session") if not os.path.exists(last_session_file): return new_session() with open(last_session_file, "r") as f: last_session = f.read().strip() return load_session(last_session) def load_session(filename, append_path = False): if append_path: filename = _sessions_dir(filename) + ".json" session = Session(filename, load = True) return session def new_session(): filename = _sessions_dir("Untitled session") i = 0 while True: i += 1 test_name = filename + ".json" if i == 1 else f"{filename} ({str(i)}).json" if not os.path.exists(test_name): filename = test_name break session = Session(filename, load = False) return session class Node: author: str or None text: str tokens: torch.Tensor empty: bool uuid: str truncate: int def num_tokens(self): return self.tokens.shape[-1] - self.truncate def get_text(self): # TODO: .. if self.author is not None: return self.author + ": " + self.text + "\n" return self.text + "\n" def tokens_trunc(self): if self.truncate == 0: return self.tokens else: return self.tokens[:, self.truncate:] def __init__(self, value, author = None, node_id = None): self.truncate = 0 if isinstance(value, str): self.author = author self.text = value self.tokens = tokenizer.encode(self.get_text()) self.empty = len(self.text) == 0 self.uuid = node_id or str(uuid.uuid4()) elif isinstance(value, dict): self.author = value.get("author", author) self.text = value["text"] self.tokens = tokenizer.encode(self.get_text()) self.empty = len(self.text) == 0 self.uuid = value.get("uuid", node_id or str(uuid.uuid4())) def replace_text(self, new_text): self.text = new_text self.tokens = tokenizer.encode(self.get_text()) def get_dict(self): dic = {"author": self.author, "text": self.text, "uuid": self.uuid } return dic class Session: # Saved state unsaved: bool # True if the session has been saved to another file than "Untitled session.json" fixed_prompt: Node keep_fixed_prompt: bool history: list[Node] break_on_newline: bool # Running state first_history_idx: int # Index of the first history item currently used in the context def __init__(self, filename, load): global model, cache, tokenizer, generator self.filename = filename if load: with open(filename, "r") as f: saved = json.load(f) else: saved = {} # Running state if cache is None: cache = ExLlamaCache(model) else: cache.current_seq_len = 0 if generator is None: generator = ExLlamaGenerator(model, tokenizer, cache) else: generator.reset() self.first_history_idx = 0 # Saved state self.unsaved = saved.get("unsaved", True) self.fixed_prompt = Node(saved.get("fixed_prompt", default_fixed_prompt)) self.keep_fixed_prompt = saved.get("keep_fixed_prompt", True) self.participants = saved.get("participants", ["User", "Chatbort"]) self.history = [] loadhistory = saved.get("history", []) for jnode in loadhistory: self.history.append(Node(jnode)) generator.settings.temperature = saved.get("temperature", 0.95) generator.settings.top_p = saved.get("top_p", 0.75) generator.settings.min_p = saved.get("min_p", 0.0) generator.settings.top_k = saved.get("top_k", 0) generator.settings.typical = saved.get("typical", 0.25) self.break_on_newline = saved.get("break_on_newline", True) generator.settings.token_repetition_penalty_max = saved.get("token_repetition_penalty_max", 1.15) generator.settings.token_repetition_penalty_sustain = saved.get("token_repetition_penalty_sustain", 2048) generator.settings.token_repetition_penalty_decay = saved.get("token_repetition_penalty_decay", 512) self.max_response_tokens = saved.get("max_response_tokens", 512) self.chunk_size = saved.get("chunk_size", 128) # Save new session #if not load: self.save() def save(self): savedata = {"unsaved": self.unsaved, "fixed_prompt": self.fixed_prompt.get_dict(), "participants": self.participants, "keep_fixed_prompt": self.keep_fixed_prompt, "history": [node.get_dict() for node in self.history], "temperature": generator.settings.temperature, "top_p": generator.settings.top_p, "min_p": generator.settings.min_p, "top_k": generator.settings.top_k, "typical": generator.settings.typical, "break_on_newline": self.break_on_newline, "max_response_tokens": self.max_response_tokens, "chunk_size": self.chunk_size, "token_repetition_penalty_max": generator.settings.token_repetition_penalty_max, "token_repetition_penalty_sustain": generator.settings.token_repetition_penalty_sustain, "token_repetition_penalty_decay": generator.settings.token_repetition_penalty_decay} json_object = json.dumps(savedata, indent = 4) with open(self.filename, "w") as outfile: outfile.write(json_object) # Remember active session last_session_file = _sessions_dir("_last_session") with open(last_session_file, "w") as f: f.write(self.filename) def _sanitize_filename(self, user_supplied_string): safe_string = str() for c in user_supplied_string: if c.isalnum() or c in [' ', '.', '(', ')', '-', ',', '_', '!', '@']: safe_string = safe_string + c while safe_string.count("../"): safe_string = safe_string.replace("../", "./") safe_string = safe_string.lstrip("./") return safe_string def api_rename_session(self, data): new_name = data["new_name"] new_name_safe = self._sanitize_filename(new_name) new_path = _sessions_dir(new_name_safe) + ".json" if new_path == self.filename: return False if os.path.exists(new_path): return False old_filename = self.filename self.filename = new_path try: self.save() except: self.filename = old_filename return False os.remove(old_filename) return True def api_delete_session(self, data): delete_name = data["session"] delete_name_safe = self._sanitize_filename(delete_name) delete_path = _sessions_dir(delete_name_safe) + ".json" os.remove(delete_path) def api_populate(self): s_dir = _sessions_dir() files = os.listdir(s_dir) names = [os.path.splitext(f)[0] for f in files if os.path.isfile(os.path.join(s_dir, f)) and f.endswith(".json")] names = sorted(names) filename = os.path.basename(self.filename) name = os.path.splitext(filename)[0] historyjson = [node.get_dict() for node in self.history] for jnode in historyjson: author = jnode["author"] if author is not None and author in self.participants: jnode["author_idx"] = self.participants.index(author) dic = {"sessions": names, "current_session": name, "fixed_prompt": self.fixed_prompt.text, "keep_fixed_prompt": self.keep_fixed_prompt, "participants": self.participants, "history": historyjson, "temperature": generator.settings.temperature, "top_p": generator.settings.top_p, "min_p": generator.settings.min_p, "top_k": generator.settings.top_k, "typical": generator.settings.typical, "break_on_newline": self.break_on_newline, "max_response_tokens": self.max_response_tokens, "chunk_size": self.chunk_size, "token_repetition_penalty_max": generator.settings.token_repetition_penalty_max, "token_repetition_penalty_sustain": generator.settings.token_repetition_penalty_sustain, "token_repetition_penalty_decay": generator.settings.token_repetition_penalty_decay, "max_seq_len": model.config.max_seq_len} # Add model info def _common_chars(names): cname = max(names, key=len) for x in names: for p, c in enumerate(x): if c != cname[p] and cname[p] != "*": cname = cname[:p] + "*" + cname[p + 1:] return cname mp = model.config.model_path if isinstance(model.config.model_path, str) else _common_chars(model.config.model_path) model_str = os.path.splitext(os.path.basename(mp))[0] + "\n" model_str += f"Sequence length: {model.config.max_seq_len}\n" dic["model_info"] = model_str.strip() json_object = json.dumps(dic, indent = 4) return json_object + "\n" def api_delete_block(self, data): block_id = data["uuid"] idx = -1 for i in range(len(self.history)): if self.history[i].uuid == block_id: idx = i if idx == -1: return self.history.pop(idx) self.first_history_idx = 0 self.save() def api_edit_block(self, data): block_id = data["uuid"] new_text = data["text"] for node in self.history: if node.uuid == block_id: node.replace_text(new_text) self.save() break self.first_history_idx = 0 self.save() def api_append_block(self, data): author = None if "author" in data: author = data["author"] else: if len(self.participants) > 0: author = self.participants[0] text = data["text"].strip() newNode = Node(text, author) self.history.append(newNode) self.save() def api_set_participants(self, data): self.participants = data["participants"] self.save() def api_set_fixed_prompt(self, data): self.fixed_prompt = Node(data["fixed_prompt"]) self.keep_fixed_prompt = data["keep_fixed_prompt"] self.save() def api_set_gen_settings(self, data): generator.settings.temperature = data["temperature"] generator.settings.top_p = data["top_p"] generator.settings.min_p = data["min_p"] generator.settings.top_k = data["top_k"] generator.settings.typical = data["typical"] self.break_on_newline = data["gen_endnewline"] self.max_response_tokens = data["max_response_tokens"] self.chunk_size = data["chunk_size"] generator.settings.token_repetition_penalty_max = data["token_repetition_penalty_max"] generator.settings.token_repetition_penalty_sustain = data["token_repetition_penalty_sustain"] generator.settings.token_repetition_penalty_decay = data["token_repetition_penalty_decay"] self.save() def set_context_window(self): def num_tokens(idx): if idx == -1: return 0 if self.fixed_prompt.empty else self.fixed_prompt.num_tokens() return self.history[idx].num_tokens() def set_truncation(idx, trunc): if idx == -1 and not self.fixed_prompt.empty: self.fixed_prompt.truncate = trunc else: self.history[idx].truncate = trunc def truncate(idx, trunc): if idx == -1 and not self.fixed_prompt.empty: self.fixed_prompt.truncate += trunc else: self.history[idx].truncate += trunc # def get_truncation(idx, trunc): # if idx == -1 and not self.fixed_prompt.empty: return self.fixed_prompt.truncate # return self.history[idx].truncate context_step_size = 256 # TODO: Config option max_context_tokens = model.config.max_seq_len - self.chunk_size - generator.settings.beam_length min_context_tokens = max_context_tokens - context_step_size * 2 if self.keep_fixed_prompt: current_context_tokens = num_tokens(-1) min_history_idx = 0 else: current_context_tokens = 0 min_history_idx = -1 if self.first_history_idx < min_history_idx: self.first_history_idx = min_history_idx for i in range(self.first_history_idx + 1, len(self.history)): set_truncation(i, 0) for i in range(self.first_history_idx, len(self.history)): current_context_tokens += num_tokens(i) while current_context_tokens > max_context_tokens: tokens_to_cut = context_step_size while tokens_to_cut > 0: tokens = num_tokens(self.first_history_idx) if tokens_to_cut >= tokens: tokens_to_cut -= tokens current_context_tokens -= tokens self.first_history_idx += 1 else: truncate(self.first_history_idx, tokens_to_cut) current_context_tokens -= tokens_to_cut tokens_to_cut = 0 # Not used # # while current_context_tokens < min_context_tokens and self.first_history_idx > min_history_idx: # tokens_to_add = context_step_size # while tokens_to_add > 0 and self.first_history_idx > min_history_idx: # tokens = get_truncation(self.first_history_idx) # if tokens > 0: # if tokens > tokens_to_add: # truncate(self.first_history_idx, -tokens_to_add) # current_context_tokens += tokens_to_add # tokens_to_add = 0 # else: # current_context_tokens += tokens # tokens_to_add -= tokens # set_truncation(self.first_history_idx, 0) # else: # self.first_history_idx -= 1 # set_truncation(self.first_history_idx, 0) # tokens = num_tokens(self.first_history_idx) # if tokens > tokens_to_add: # set_truncation(self.first_history_idx, tokens - tokens_to_add) # current_context_tokens += tokens_to_add # tokens_to_add = 0 # else: # tokens_to_add -= tokens # current_context_tokens += tokens def get_tokenized_context(self): def node(idx): if idx == -1: return None if self.fixed_prompt.empty else self.fixed_prompt return self.history[idx] context = [] text_context = "" if self.keep_fixed_prompt and not self.fixed_prompt.empty: context.append(node(-1).tokens_trunc()) text_context += node(-1).get_text() for i in range(self.first_history_idx, len(self.history)): if node(i) is not None: context.append(node(i).tokens_trunc()) text_context += node(i).get_text() full_context = torch.cat(context, dim = 1) if len(context) > 0 else None return full_context, text_context def respond(self, author, stop_conditions, total_tokens, res_line = "", num_res_tokens = 0): global model, tokenizer, cache, generator # Begin building block on client new_block_uuid = str(uuid.uuid4()) packet = {"cmd": "begin_block", "uuid": new_block_uuid} if len(self.participants) > 0: author = res_line.split(":")[0].strip() packet["author"] = author if author in self.participants: packet["author_idx"] = self.participants.index(author) yield json.dumps(packet) + "\n" # Generate loop generator.begin_beam_search() stop_condition = False held_text = "" for i in range(self.max_response_tokens): # Truncate the past if the next chunk might generate past max_seq_length if generator.sequence_actual is not None: if generator.sequence_actual.shape[ -1] + self.chunk_size + generator.settings.beam_length + 1 > model.config.max_seq_len: generator.gen_prune_left(self.chunk_size) # Get the token and append to sequence gen_token = generator.beam_search() # If token is EOS, replace it with newline before continuing if gen_token.item() == tokenizer.eos_token_id: generator.replace_last_token(tokenizer.newline_token_id) # Decode current line to get new characters added (decoding a single token gives incorrect results # sometimes due to hoe SentencePiece works) prev_res_line = res_line num_res_tokens += 1 res_line = tokenizer.decode(generator.sequence_actual[0, -num_res_tokens:]) new_text = res_line[len(prev_res_line):] # Since SentencePiece is slightly ambiguous, the first token produced after a newline may not be the # same that is reproduced when we encode the text later, even though it encodes the same string if num_res_tokens == 1 and len(new_text) > 0: replace = tokenizer.encode(new_text)[0] if replace.shape[-1] == 1: generator.replace_last_token(replace) # Delay streaming if new text might be part of a stop condition hold_text = False for _, stop_string in stop_conditions: if stop_string.lower().startswith((held_text + new_text).lower()): hold_text = True # Stream to client if not hold_text: packet = {"cmd": "append", "text": held_text + new_text} yield json.dumps(packet) + "\n" held_text = "" else: held_text += new_text # Stop conditions if gen_token.item() == tokenizer.eos_token_id: if len(held_text) > 0: # Not sure if this could actually happen plen = tokenizer.encode(held_text).shape[-1] res_line = res_line[:-len(held_text)] generator.gen_rewind(plen) stop_condition = True break for stop_tokens, stop_string in stop_conditions: if res_line.lower().endswith(stop_string.lower()): generator.gen_rewind( stop_tokens.shape[-1] - (1 if stop_tokens[0, 0].item() == tokenizer.newline_token_id else 0)) res_line = res_line[:-len(stop_string)] stop_condition = True break if stop_condition: break generator.end_beam_search() # print("--response--") # print("----") # print (f"cache len: {cache.current_seq_len}"); print(res_line.strip()) if author is not None: res_line = res_line[len(author) + 1:] res_line = res_line.strip() newNode = Node(res_line, author, node_id=new_block_uuid) # TODO: Reuse generated tokens instead of reencoding, if it matters? self.history.append(newNode) total_tokens[0] += num_res_tokens def respond_multi(self, user_input): global model, tokenizer, cache, generator packet = {"cmd": "begin_stream"} yield json.dumps(packet) + "\n" # Prepare stop conditions # stop_conditions = [ (torch.Tensor([[tokenizer.eos_token_id]]).long(), None) ] stop_conditions = [] newline_token = torch.Tensor([[tokenizer.newline_token_id]]).long() if self.break_on_newline: stop_conditions.append((newline_token, "\n")) else: for part in self.participants: txt = part + ":" sc = tokenizer.encode(txt) sc = torch.cat((newline_token, sc), dim=1) stop_conditions.append((sc, "\n" + txt)) stop_conditions.append((sc, "\n " + txt)) # Clean up the input a bit user_input = user_input.strip() if len(user_input) > 0: # Append input to context author = None if len(self.participants) > 0: author = self.participants[0] newNode = Node(user_input, author) self.history.append(newNode) self.save() # Echo input back to client packet = {"cmd": "begin_block", "init_text": user_input, "uuid": newNode.uuid} if author is not None: packet["author"] = author yield json.dumps(packet) + "\n" # Prepare context for generator self.set_context_window() context, text_context = self.get_tokenized_context() # Start generating, reusing cache for any part of the context that hasn't changed if context is None: print("No initial context") reused = generator.gen_begin_empty() else: begin_time = time.time() reused = generator.gen_begin_reuse(context) torch.cuda.synchronize() # Just to measure correct prompt processing speed end_time = time.time() elapsed = end_time - begin_time new_tokens = context.shape[-1] - reused token_rate = 0 if elapsed == 0 else (new_tokens / elapsed) print(f"Prompt processed in {elapsed:.2f} seconds, {new_tokens} new tokens, {token_rate:.2f} tokens/second:") begin_time = time.time() total_tokens = [0] # No participants if len(self.participants) == 0: yield from self.respond(None, stop_conditions, total_tokens) # Two participants elif len(self.participants) == 2: author = self.participants[1] res_line = author + ":" res_tokens = tokenizer.encode(res_line) num_res_tokens = res_tokens.shape[-1] generator.gen_feed_tokens(res_tokens) yield from self.respond(self.participants[1], stop_conditions, total_tokens, res_line, num_res_tokens) # Multiple bots might answer elif len(self.participants) > 2: cpart = [p + ":" for p in self.participants] upart = cpart.pop(0) first_round = True while True: res_tokens = [] npart = [p for p in cpart] ncrange = [i for i in range(len(cpart))] ntoken = [tokenizer.encode(np).squeeze(0).tolist() for np in npart] winner = -1 while True: constraints = [t[len(res_tokens)] for t in ntoken] next_t = generator.gen_single_token(constraints) remove = [] for i in range(len(ntoken)): if ntoken[i][len(res_tokens)] != next_t: remove.append(i) for i in reversed(remove): npart.pop(i) ntoken.pop(i) ncrange.pop(i) res_tokens.append(next_t) for i in range(len(ntoken)): if len(ntoken[i]) == len(res_tokens): winner = ncrange[i] if winner != -1: break author = cpart.pop(winner)[:-1] res_line = author + ":" num_res_tokens = len(res_tokens) if author == self.participants[0]: generator.gen_rewind(num_res_tokens) break # generator.gen_feed_tokens(res_tokens) yield from self.respond(self.participants[1], stop_conditions, total_tokens, res_line, num_res_tokens) if first_round: first_round = False cpart.append(upart) end_time = time.time() elapsed = end_time - begin_time token_rate = 0 if elapsed == 0 else (total_tokens[0] / elapsed) print(f"Response generated in {elapsed:.2} seconds, {total_tokens[0]} tokens, {token_rate:.2f} tokens/second:") self.save()