diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..c7d9f3332a950355d5a77d85000f05e6f45435ea
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..1b7f0fb8eea3e894694d9ca1541a896102d40376
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,21 @@
+cache/*
+characters/*
+extensions/silero_tts/outputs/*
+extensions/elevenlabs_tts/outputs/*
+logs/*
+models/*
+softprompts/*
+torch-dumps/*
+*pycache*
+*/*pycache*
+*/*/pycache*
+
+settings.json
+img_bot*
+img_me*
+
+!characters/Example.json
+!characters/Example.png
+!models/place-your-models-here.txt
+!softprompts/place-your-softprompts-here.txt
+!torch-dumps/place-your-pt-models-here.txt
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4f4ff8c355498049a6a67d373355b53bbbf9997e
--- /dev/null
+++ b/README.md
@@ -0,0 +1,14 @@
+---
+title: Text Generation Webui Space
+emoji: 🏃
+colorFrom: yellow
+colorTo: purple
+sdk: gradio
+sdk_version: 3.20.1
+app_file: run.py
+pinned: false
+license: mit
+duplicated_from: sahilverma0696/text-generation-webui-space-1
+---
+
+Check out the original repository: https://github.com/oobabooga/text-generation-webui
diff --git a/api-example-stream.py b/api-example-stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5ed420252fdceab73cc26d83a7b87f60981ec95
--- /dev/null
+++ b/api-example-stream.py
@@ -0,0 +1,90 @@
+'''
+
+Contributed by SagsMug. Thank you SagsMug.
+https://github.com/oobabooga/text-generation-webui/pull/175
+
+'''
+
+import asyncio
+import json
+import random
+import string
+
+import websockets
+
+
+def random_hash():
+ letters = string.ascii_lowercase + string.digits
+ return ''.join(random.choice(letters) for i in range(9))
+
+async def run(context):
+ server = "127.0.0.1"
+ params = {
+ 'max_new_tokens': 200,
+ 'do_sample': True,
+ 'temperature': 0.5,
+ 'top_p': 0.9,
+ 'typical_p': 1,
+ 'repetition_penalty': 1.05,
+ 'top_k': 0,
+ 'min_length': 0,
+ 'no_repeat_ngram_size': 0,
+ 'num_beams': 1,
+ 'penalty_alpha': 0,
+ 'length_penalty': 1,
+ 'early_stopping': False,
+ }
+ session = random_hash()
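+ # The loop below follows gradio's websocket queue protocol: send a session
+ # hash, wait for "send_data", submit the inputs, then stream the
+ # "process_generating" updates until "process_completed" arrives. The
+ # fn_index value selects which function in the web UI's gradio app is
+ # called; it may differ between versions.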
+
+ async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
+ while content := json.loads(await websocket.recv()):
+ # Python 3.10+ match/case syntax; replace with if/elif on older versions
+ match content["msg"]:
+ case "send_hash":
+ await websocket.send(json.dumps({
+ "session_hash": session,
+ "fn_index": 7
+ }))
+ case "estimation":
+ pass
+ case "send_data":
+ await websocket.send(json.dumps({
+ "session_hash": session,
+ "fn_index": 7,
+ "data": [
+ context,
+ params['max_new_tokens'],
+ params['do_sample'],
+ params['temperature'],
+ params['top_p'],
+ params['typical_p'],
+ params['repetition_penalty'],
+ params['top_k'],
+ params['min_length'],
+ params['no_repeat_ngram_size'],
+ params['num_beams'],
+ params['penalty_alpha'],
+ params['length_penalty'],
+ params['early_stopping'],
+ ]
+ }))
+ case "process_starts":
+ pass
+ case "process_generating" | "process_completed":
+ yield content["output"]["data"][0]
+ # You can search for your desired end indicator and
+ # stop generation by closing the websocket here
+ if (content["msg"] == "process_completed"):
+ break
+
+prompt = "What I would like to say is the following: "
+
+async def get_result():
+ async for response in run(prompt):
+ # Print intermediate steps
+ print(response)
+
+ # Print final result
+ print(response)
+
+asyncio.run(get_result())
diff --git a/api-example.py b/api-example.py
new file mode 100644
index 0000000000000000000000000000000000000000..0306b7ab8a3fa3d6f57d8474ad74d67f13557b6d
--- /dev/null
+++ b/api-example.py
@@ -0,0 +1,59 @@
+'''
+
+This is an example of how to use the API for oobabooga/text-generation-webui.
+
+Make sure to start the web UI with the following flags:
+
+python server.py --model MODEL --listen --no-stream
+
+Optionally, you can also add the --share flag to generate a public gradio URL,
+allowing you to use the API remotely.
+
+'''
+import requests
+
+# Server address
+server = "127.0.0.1"
+
+# Generation parameters
+# Reference: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
+params = {
+ 'max_new_tokens': 200,
+ 'do_sample': True,
+ 'temperature': 0.5,
+ 'top_p': 0.9,
+ 'typical_p': 1,
+ 'repetition_penalty': 1.05,
+ 'top_k': 0,
+ 'min_length': 0,
+ 'no_repeat_ngram_size': 0,
+ 'num_beams': 1,
+ 'penalty_alpha': 0,
+ 'length_penalty': 1,
+ 'early_stopping': False,
+}
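+
+# Note: the /run/textgen endpoint takes its inputs positionally, so the values
+# in the "data" list below must stay in the order expected by the server.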
+
+# Input prompt
+prompt = "What I would like to say is the following: "
+
+response = requests.post(f"http://{server}:7860/run/textgen", json={
+ "data": [
+ prompt,
+ params['max_new_tokens'],
+ params['do_sample'],
+ params['temperature'],
+ params['top_p'],
+ params['typical_p'],
+ params['repetition_penalty'],
+ params['top_k'],
+ params['min_length'],
+ params['no_repeat_ngram_size'],
+ params['num_beams'],
+ params['penalty_alpha'],
+ params['length_penalty'],
+ params['early_stopping'],
+ ]
+}).json()
+
+reply = response["data"][0]
+print(reply)
diff --git a/characters/Example.json b/characters/Example.json
new file mode 100644
index 0000000000000000000000000000000000000000..496869c4e6cd643c910fbdf86d748c1c70987020
--- /dev/null
+++ b/characters/Example.json
@@ -0,0 +1,7 @@
+{
+ "char_name": "Chiharu Yamada",
+ "char_persona": "Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology.",
+ "char_greeting": "*Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air*\nHey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started!",
+ "world_scenario": "",
+ "example_dialogue": "{{user}}: So how did you get into computer engineering?\n{{char}}: I've always loved tinkering with technology since I was a kid.\n{{user}}: That's really impressive!\n{{char}}: *She chuckles bashfully* Thanks!\n{{user}}: So what do you do when you're not working on computers?\n{{char}}: I love exploring, going out with friends, watching movies, and playing video games.\n{{user}}: What's your favorite type of computer hardware to work with?\n{{char}}: Motherboards, they're like puzzles and the backbone of any system.\n{{user}}: That sounds great!\n{{char}}: Yeah, it's really fun. I'm lucky to be able to do this as a job."
+}
diff --git a/characters/Example.png b/characters/Example.png
new file mode 100644
index 0000000000000000000000000000000000000000..a7c4e513c4eaa05db1ebb2164956ea0b85d74a75
Binary files /dev/null and b/characters/Example.png differ
diff --git a/convert-to-flexgen.py b/convert-to-flexgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..917f023c3fe395c2e3cbcad11c9cdc6b85ef1e7e
--- /dev/null
+++ b/convert-to-flexgen.py
@@ -0,0 +1,60 @@
+'''
+
+Converts a transformers model to a format compatible with flexgen.
+
+'''
+
+import argparse
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
+parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
+args = parser.parse_args()
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ global torch_linear_init_backup
+ global torch_layer_norm_init_backup
+
+ torch_linear_init_backup = torch.nn.Linear.reset_parameters
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+
+ torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+def restore_torch_init():
+ """Rollback the change made by disable_torch_init."""
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
+ setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
+
+if __name__ == '__main__':
+ path = Path(args.MODEL)
+ model_name = path.name
+
+ print(f"Loading {model_name}...")
+ #disable_torch_init()
+ model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ #restore_torch_init()
+
+ tokenizer = AutoTokenizer.from_pretrained(path)
+
+ out_folder = Path(f"models/{model_name}-np")
+ if not Path(out_folder).exists():
+ os.mkdir(out_folder)
+
+ print(f"Saving the converted model to {out_folder}...")
+ for name, param in tqdm(list(model.model.named_parameters())):
+ name = name.replace("decoder.final_layer_norm", "decoder.layer_norm")
+ param_path = os.path.join(out_folder, name)
+ with open(param_path, "wb") as f:
+ np.save(f, param.cpu().detach().numpy())
diff --git a/convert-to-safetensors.py b/convert-to-safetensors.py
new file mode 100644
index 0000000000000000000000000000000000000000..63baaa9726ab48025d2ba473d029bb3f1153aa3a
--- /dev/null
+++ b/convert-to-safetensors.py
@@ -0,0 +1,38 @@
+'''
+
+Converts a transformers model to safetensors format and shards it.
+
+This makes it faster to load (because of safetensors) and lowers its RAM usage
+while loading (because of sharding).
+
+Based on the original script by 81300:
+
+https://gist.github.com/81300/fe5b08bff1cba45296a829b9d6b0f303
+
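+Example (the model path below is illustrative):
+
+python convert-to-safetensors.py models/opt-1.3b --output models/opt-1.3b_safetensors --max-shard-size 2GB
+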
+'''
+
+import argparse
+from pathlib import Path
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
+parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
+parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).')
+parser.add_argument("--max-shard-size", type=str, default="2GB", help="Maximum size of a shard in GB or MB (default: %(default)s).")
+parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
+args = parser.parse_args()
+
+if __name__ == '__main__':
+ path = Path(args.MODEL)
+ model_name = path.name
+
+ print(f"Loading {model_name}...")
+ model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if args.bf16 else torch.float16)
+ tokenizer = AutoTokenizer.from_pretrained(path)
+
+ out_folder = args.output or Path(f"models/{model_name}_safetensors")
+ print(f"Saving the converted model to {out_folder} with a maximum shard size of {args.max_shard_size}...")
+ model.save_pretrained(out_folder, max_shard_size=args.max_shard_size, safe_serialization=True)
+ tokenizer.save_pretrained(out_folder)
diff --git a/download-model.py b/download-model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3b4623142bf04408515af17c0cb155c8a92d971
--- /dev/null
+++ b/download-model.py
@@ -0,0 +1,179 @@
+'''
+Downloads models from Hugging Face to models/model-name.
+
+Example:
+python download-model.py facebook/opt-1.3b
+
+'''
+
+import argparse
+import base64
+import json
+import multiprocessing
+import re
+import sys
+from pathlib import Path
+
+import requests
+import tqdm
+
+parser = argparse.ArgumentParser()
+parser.add_argument('MODEL', type=str, default=None, nargs='?')
+parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
+parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
+parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
+args = parser.parse_args()
+
+def get_file(args):
+ url = args[0]
+ output_folder = args[1]
+ idx = args[2]
+ tot = args[3]
+
+ print(f"Downloading file {idx} of {tot}...")
+ r = requests.get(url, stream=True)
+ with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
+ total_size = int(r.headers.get('content-length', 0))
+ block_size = 1024
+ t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
+ for data in r.iter_content(block_size):
+ t.update(len(data))
+ f.write(data)
+ t.close()
+
+def sanitize_branch_name(branch_name):
+ pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
+ if pattern.match(branch_name):
+ return branch_name
+ else:
+ raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
+
+def select_model_from_default_options():
+ models = {
+ "Pygmalion 6B original": ("PygmalionAI", "pygmalion-6b", "b8344bb4eb76a437797ad3b19420a13922aaabe1"),
+ "Pygmalion 6B main": ("PygmalionAI", "pygmalion-6b", "main"),
+ "Pygmalion 6B dev": ("PygmalionAI", "pygmalion-6b", "dev"),
+ "Pygmalion 2.7B": ("PygmalionAI", "pygmalion-2.7b", "main"),
+ "Pygmalion 1.3B": ("PygmalionAI", "pygmalion-1.3b", "main"),
+ "Pygmalion 350m": ("PygmalionAI", "pygmalion-350m", "main"),
+ "OPT 6.7b": ("facebook", "opt-6.7b", "main"),
+ "OPT 2.7b": ("facebook", "opt-2.7b", "main"),
+ "OPT 1.3b": ("facebook", "opt-1.3b", "main"),
+ "OPT 350m": ("facebook", "opt-350m", "main"),
+ }
+ choices = {}
+
+ print("Select the model that you want to download:\n")
+ for i,name in enumerate(models):
+ char = chr(ord('A')+i)
+ choices[char] = name
+ print(f"{char}) {name}")
+ char = chr(ord('A')+len(models))
+ print(f"{char}) None of the above")
+
+ print()
+ print("Input> ", end='')
+ choice = input()[0].strip().upper()
+ if choice == char:
+ print("""\nThen type the name of your desired Hugging Face model in the format organization/name.
+
+Examples:
+PygmalionAI/pygmalion-6b
+facebook/opt-1.3b
+""")
+
+ print("Input> ", end='')
+ model = input()
+ branch = "main"
+ else:
+ arr = models[choices[choice]]
+ model = f"{arr[0]}/{arr[1]}"
+ branch = arr[2]
+
+ return model, branch
+
+def get_download_links_from_huggingface(model, branch):
+ base = "https://huggingface.co"
+ page = f"/api/models/{model}/tree/{branch}"
+ cursor = b""
+
+ links = []
+ classifications = []
+ has_pytorch = False
+ has_safetensors = False
+ while True:
+ url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
+ r = requests.get(url)
+ r.raise_for_status()
+ content = r.content
+
+ dict = json.loads(content)
+ if len(dict) == 0:
+ break
+
+ for i in range(len(dict)):
+ fname = dict[i]['path']
+
+ is_pytorch = re.match("pytorch_model.*\.bin", fname)
+ is_safetensors = re.match("model.*\.safetensors", fname)
+ is_tokenizer = re.match("tokenizer.*\.model", fname)
+ is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
+
+ if any((is_pytorch, is_safetensors, is_text, is_tokenizer)):
+ if is_text:
+ links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
+ classifications.append('text')
+ continue
+ if not args.text_only:
+ links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
+ if is_safetensors:
+ has_safetensors = True
+ classifications.append('safetensors')
+ elif is_pytorch:
+ has_pytorch = True
+ classifications.append('pytorch')
+
+ cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
+ cursor = base64.b64encode(cursor)
+ cursor = cursor.replace(b'=', b'%3D')
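+ # The three lines above build the pagination cursor for the next page of the
+ # model tree API: a marker naming the last file seen, base64-encoded twice
+ # (with b':50' appended in between) and with '=' escaped so it can be passed
+ # as a query parameter.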
+
+ # If both pytorch and safetensors are available, download safetensors only
+ if has_pytorch and has_safetensors:
+ for i in range(len(classifications)-1, -1, -1):
+ if classifications[i] == 'pytorch':
+ links.pop(i)
+
+ return links
+
+if __name__ == '__main__':
+ model = args.MODEL
+ branch = args.branch
+ if model is None:
+ model, branch = select_model_from_default_options()
+ else:
+ if model[-1] == '/':
+ model = model[:-1]
+ branch = args.branch
+ if branch is None:
+ branch = "main"
+ else:
+ try:
+ branch = sanitize_branch_name(branch)
+ except ValueError as err_branch:
+ print(f"Error: {err_branch}")
+ sys.exit()
+ if branch != 'main':
+ output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
+ else:
+ output_folder = Path("models") / model.split('/')[-1]
+ if not output_folder.exists():
+ output_folder.mkdir()
+
+ links = get_download_links_from_huggingface(model, branch)
+
+ # Downloading the files
+ print(f"Downloading the model to {output_folder}")
+ pool = multiprocessing.Pool(processes=args.threads)
+ results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))])
+ pool.close()
+ pool.join()
diff --git a/extensions/character_bias/script.py b/extensions/character_bias/script.py
new file mode 100644
index 0000000000000000000000000000000000000000..35b38c0edcb38512f2472937578a363343a4468c
--- /dev/null
+++ b/extensions/character_bias/script.py
@@ -0,0 +1,42 @@
+import gradio as gr
+
+params = {
+ "activate": True,
+ "bias string": " *I am so happy*",
+}
+
+def input_modifier(string):
+ """
+ This function is applied to your text inputs before
+ they are fed into the model.
+ """
+
+ return string
+
+def output_modifier(string):
+ """
+ This function is applied to the model outputs.
+ """
+
+ return string
+
+def bot_prefix_modifier(string):
+ """
+ This function is only applied in chat mode. It modifies
+ the prefix text for the Bot and can be used to bias its
+ behavior.
+ """
+
+ if params['activate'] == True:
+ return f'{string} {params["bias string"].strip()} '
+ else:
+ return string
+
+def ui():
+ # Gradio elements
+ activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
+ string = gr.Textbox(value=params["bias string"], label='Character bias')
+
+ # Event functions to update the parameters in the backend
+ string.change(lambda x: params.update({"bias string": x}), string, None)
+ activate.change(lambda x: params.update({"activate": x}), activate, None)
diff --git a/extensions/elevenlabs_tts/requirements.txt b/extensions/elevenlabs_tts/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8ec07a8a7fcf02ca48cc00520e66fcb58c447393
--- /dev/null
+++ b/extensions/elevenlabs_tts/requirements.txt
@@ -0,0 +1,3 @@
+elevenlabslib
+soundfile
+sounddevice
diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py
new file mode 100644
index 0000000000000000000000000000000000000000..90d61efc6aa77bc2377c435eefe4cf623b588168
--- /dev/null
+++ b/extensions/elevenlabs_tts/script.py
@@ -0,0 +1,113 @@
+from pathlib import Path
+
+import gradio as gr
+from elevenlabslib import *
+from elevenlabslib.helpers import *
+
+params = {
+ 'activate': True,
+ 'api_key': '12345',
+ 'selected_voice': 'None',
+}
+
+initial_voice = ['None']
+wav_idx = 0
+user = ElevenLabsUser(params['api_key'])
+user_info = None
+
+
+# Check if the API is valid and refresh the UI accordingly.
+def check_valid_api():
+
+ global user, user_info, params
+
+ user = ElevenLabsUser(params['api_key'])
+ user_info = user._get_subscription_data()
+ print('checking api')
+ if params['activate'] == False:
+ return gr.update(value='Disconnected')
+ elif user_info is None:
+ print('Incorrect API Key')
+ return gr.update(value='Disconnected')
+ else:
+ print('Got an API Key!')
+ return gr.update(value='Connected')
+
+# Once the API is verified, get the available voices and update the dropdown list
+def refresh_voices():
+
+ global user, user_info
+
+ your_voices = [None]
+ if user_info is not None:
+ for voice in user.get_available_voices():
+ your_voices.append(voice.initialName)
+ return gr.Dropdown.update(choices=your_voices)
+ else:
+ return
+
+def remove_surrounded_chars(string):
+ new_string = ""
+ in_star = False
+ for char in string:
+ if char == '*':
+ in_star = not in_star
+ elif not in_star:
+ new_string += char
+ return new_string
+
+def input_modifier(string):
+ """
+ This function is applied to your text inputs before
+ they are fed into the model.
+ """
+
+ return string
+
+def output_modifier(string):
+ """
+ This function is applied to the model outputs.
+ """
+
+ global params, wav_idx, user, user_info
+
+ if params['activate'] == False:
+ return string
+ elif user_info == None:
+ return string
+
+ string = remove_surrounded_chars(string)
+ string = string.replace('"', '')
+ string = string.replace('“', '')
+ string = string.replace('\n', ' ')
+ string = string.strip()
+
+ if string == '':
+ string = 'empty reply, try regenerating'
+
+ output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav')
+ voice = user.get_voices_by_name(params['selected_voice'])[0]
+ audio_data = voice.generate_audio_bytes(string)
+ save_bytes_to_path(Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'), audio_data)
+
+ string = f'<audio src="file/{output_file.as_posix()}" controls></audio>'
+ wav_idx += 1
+ return string
+
+def ui():
+
+ # Gradio elements
+ with gr.Row():
+ activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
+ connection_status = gr.Textbox(value='Disconnected', label='Connection Status')
+ voice = gr.Dropdown(value=params['selected_voice'], choices=initial_voice, label='TTS Voice')
+ with gr.Row():
+ api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
+ connect = gr.Button(value='Connect')
+
+ # Event functions to update the parameters in the backend
+ activate.change(lambda x: params.update({'activate': x}), activate, None)
+ voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
+ api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
+ connect.click(check_valid_api, [], connection_status)
+ connect.click(refresh_voices, [], voice)
diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a2d7cf988734a7ab0966d047ff3d31ba58324b7
--- /dev/null
+++ b/extensions/gallery/script.py
@@ -0,0 +1,82 @@
+from pathlib import Path
+
+import gradio as gr
+
+from modules.html_generator import get_image_cache
+
+
+def generate_html():
+ css = """
+ .character-gallery {
+ margin: 1rem 0;
+ display: grid;
+ grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
+ grid-column-gap: 0.4rem;
+ grid-row-gap: 1.2rem;
+ }
+
+ .character-container {
+ cursor: pointer;
+ text-align: center;
+ position: relative;
+ opacity: 0.85;
+ }
+
+ .character-container:hover {
+ opacity: 1;
+ }
+
+ .character-container .placeholder, .character-container img {
+ width: 150px;
+ height: 200px;
+ background-color: gray;
+ object-fit: cover;
+ margin: 0 auto;
+ border-radius: 1rem;
+ border: 3px solid white;
+ box-shadow: 3px 3px 6px 0px rgb(0 0 0 / 50%);
+ }
+
+ .character-name {
+ margin-top: 0.3rem;
+ display: block;
+ font-size: 1.2rem;
+ font-weight: 600;
+ overflow-wrap: anywhere;
+ }
+ """
+
+ container_html = f'<style>{css}</style><div class="character-gallery">'
+
+ # Iterate through files in image folder
+ for file in sorted(Path("characters").glob("*")):
+ if file.name.endswith(".json"):
+ character = file.name.replace(".json", "")
+ container_html += f'<div class="character-container">'
+ image_html = "<div class='placeholder'></div>"
+
+ for i in [
+ f"characters/{character}.png",
+ f"characters/{character}.jpg",
+ f"characters/{character}.jpeg",
+ ]:
+
+ path = Path(i)
+ if path.exists():
+ try:
+ image_html = f'<img src="file/{get_image_cache(path)}">'
+ break
+ except:
+ continue
+
+ container_html += f'{image_html} <span class="character-name">{character}</span>'
+ container_html += "</div>"
+
+ container_html += "
"
+ return container_html
+
+def ui():
+ with gr.Accordion("Character gallery"):
+ update = gr.Button("Refresh")
+ gallery = gr.HTML(value=generate_html())
+ update.click(generate_html, [], gallery)
diff --git a/extensions/google_translate/requirements.txt b/extensions/google_translate/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..554a00df62818f96ba7d396ae39d8e58efbe9bfe
--- /dev/null
+++ b/extensions/google_translate/requirements.txt
@@ -0,0 +1 @@
+deep-translator==1.9.2
diff --git a/extensions/google_translate/script.py b/extensions/google_translate/script.py
new file mode 100644
index 0000000000000000000000000000000000000000..68bc54b293086bed1a070a310d276060ee939d44
--- /dev/null
+++ b/extensions/google_translate/script.py
@@ -0,0 +1,42 @@
+import gradio as gr
+from deep_translator import GoogleTranslator
+
+params = {
+ "language string": "ja",
+}
+
+language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
+
+def input_modifier(string):
+ """
+ This function is applied to your text inputs before
+ they are fed into the model.
+ """
+
+ return GoogleTranslator(source=params['language string'], target='en').translate(string)
+
+def output_modifier(string):
+ """
+ This function is applied to the model outputs.
+ """
+
+ return GoogleTranslator(source='en', target=params['language string']).translate(string)
+
+def bot_prefix_modifier(string):
+ """
+ This function is only applied in chat mode. It modifies
+ the prefix text for the Bot and can be used to bias its
+ behavior.
+ """
+
+ return string
+
+def ui():
+ # Finding the language name from the language code to use as the default value
+ language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]
+
+ # Gradio elements
+ language = gr.Dropdown(value=language_name, choices=[k for k in language_codes], label='Language')
+
+ # Event functions to update the parameters in the backend
+ language.change(lambda x: params.update({"language string": language_codes[x]}), language, None)
diff --git a/extensions/llama_prompts/script.py b/extensions/llama_prompts/script.py
new file mode 100644
index 0000000000000000000000000000000000000000..22c96f7c2d6763213a728d77ee6666496d9c4aa3
--- /dev/null
+++ b/extensions/llama_prompts/script.py
@@ -0,0 +1,18 @@
+import gradio as gr
+import modules.shared as shared
+import pandas as pd
+
+df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
+
+def get_prompt_by_name(name):
+ if name == 'None':
+ return ''
+ else:
+ return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
+
+def ui():
+ if not shared.args.chat or shared.args.cai_chat:
+ choices = ['None'] + list(df['Prompt name'])
+
+ prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt')
+ prompts_menu.change(get_prompt_by_name, prompts_menu, shared.gradio['textbox'])
diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0c356329a51edf026f7223a0ee7e5427d8751ce
--- /dev/null
+++ b/extensions/send_pictures/script.py
@@ -0,0 +1,46 @@
+import base64
+from io import BytesIO
+
+import gradio as gr
+import torch
+from transformers import BlipForConditionalGeneration, BlipProcessor
+
+import modules.chat as chat
+import modules.shared as shared
+
+# If 'state' is True, will hijack the next chat generation with
+# custom input text given by 'value' in the format [text, visible_text]
+input_hijack = {
+ 'state': False,
+ 'value': ["", ""]
+}
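+# Illustrative example of the hijack this extension performs after an upload
+# (the strings are placeholders):
+# input_hijack = {'state': True, 'value': ['*You send a picture...*', '<img src="data:image/jpeg;base64,...">']}
+# chat.chatbot_wrapper then uses these values in place of the typed message.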
+
+processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
+
+def caption_image(raw_image):
+ inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
+ out = model.generate(**inputs, max_new_tokens=100)
+ return processor.decode(out[0], skip_special_tokens=True)
+
+def generate_chat_picture(picture, name1, name2):
+ text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
+ buffer = BytesIO()
+ picture.save(buffer, format="JPEG")
+ img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
+ visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
+ return text, visible_text
+
+def ui():
+ picture_select = gr.Image(label='Send a picture', type='pil')
+
+ function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
+
+ # Prepare the hijack with custom inputs
+ picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
+
+ # Call the generation function
+ picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
+
+ # Clear the picture from the upload field
+ picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
diff --git a/extensions/silero_tts/requirements.txt b/extensions/silero_tts/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f2f0bff55a862de8e643496d90c01713785801a2
--- /dev/null
+++ b/extensions/silero_tts/requirements.txt
@@ -0,0 +1,6 @@
+ipython
+omegaconf
+pydub
+PyYAML
+torch
+torchaudio
diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py
new file mode 100644
index 0000000000000000000000000000000000000000..f611dc27b7480cd357b77c0c407fcc2bd6df2679
--- /dev/null
+++ b/extensions/silero_tts/script.py
@@ -0,0 +1,169 @@
+import time
+from pathlib import Path
+
+import gradio as gr
+import torch
+
+import modules.chat as chat
+import modules.shared as shared
+
+torch._C._jit_set_profiling_mode(False)
+
+params = {
+ 'activate': True,
+ 'speaker': 'en_56',
+ 'language': 'en',
+ 'model_id': 'v3_en',
+ 'sample_rate': 48000,
+ 'device': 'cpu',
+ 'show_text': False,
+ 'autoplay': True,
+ 'voice_pitch': 'medium',
+ 'voice_speed': 'medium',
+}
+
+current_params = params.copy()
+voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
+voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
+voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
+
+# Used for making text xml compatible, needed for voice pitch and speed control
+table = str.maketrans({
+ "<": "<",
+ ">": ">",
+ "&": "&",
+ "'": "'",
+ '"': """,
+})
+
+def xmlesc(txt):
+ return txt.translate(table)
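+# Example: xmlesc('"5 < 10" & more') -> '&quot;5 &lt; 10&quot; &amp; more'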
+
+def load_model():
+ model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
+ model.to(params['device'])
+ return model
+model = load_model()
+
+def remove_surrounded_chars(string):
+ new_string = ""
+ in_star = False
+ for char in string:
+ if char == '*':
+ in_star = not in_star
+ elif not in_star:
+ new_string += char
+ return new_string
+
+def remove_tts_from_history(name1, name2):
+ for i, entry in enumerate(shared.history['internal']):
+ shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
+ return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+
+def toggle_text_in_history(name1, name2):
+ for i, entry in enumerate(shared.history['visible']):
+ visible_reply = entry[1]
+ if visible_reply.startswith('<audio'):
+ if params['show_text']:
+ reply = shared.history['internal'][i][1]
+ shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
+ else:
+ shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
+ return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+
+def input_modifier(string):
+ """
+ This function is applied to your text inputs before
+ they are fed into the model.
+ """
+
+ # Remove autoplay from the last reply
+ if (shared.args.chat or shared.args.cai_chat) and len(shared.history['internal']) > 0:
+ shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
+
+ shared.processing_message = "*Is recording a voice message...*"
+ return string
+
+def output_modifier(string):
+ """
+ This function is applied to the model outputs.
+ """
+
+ global model, current_params
+
+ for i in params:
+ if params[i] != current_params[i]:
+ model = load_model()
+ current_params = params.copy()
+ break
+
+ if params['activate'] == False:
+ return string
+
+ original_string = string
+ string = remove_surrounded_chars(string)
+ string = string.replace('"', '')
+ string = string.replace('“', '')
+ string = string.replace('\n', ' ')
+ string = string.strip()
+
+ if string == '':
+ string = '*Empty reply, try regenerating*'
+ else:
+ output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
+ prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
+ silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
+ model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
+
+ autoplay = 'autoplay' if params['autoplay'] else ''
+ string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
+ if params['show_text']:
+ string += f'\n\n{original_string}'
+
+ shared.processing_message = "*Is typing...*"
+ return string
+
+def bot_prefix_modifier(string):
+ """
+ This function is only applied in chat mode. It modifies
+ the prefix text for the Bot and can be used to bias its
+ behavior.
+ """
+
+ return string
+
+def ui():
+ # Gradio elements
+ with gr.Accordion("Silero TTS"):
+ with gr.Row():
+ activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
+ autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
+ show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
+ voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
+ with gr.Row():
+ v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
+ v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
+ with gr.Row():
+ convert = gr.Button('Permanently replace audios with the message texts')
+ convert_cancel = gr.Button('Cancel', visible=False)
+ convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
+
+ # Convert history with confirmation
+ convert_arr = [convert_confirm, convert, convert_cancel]
+ convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
+ convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
+ convert_confirm.click(remove_tts_from_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
+ convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+ convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
+
+ # Toggle message text in history
+ show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
+ show_text.change(toggle_text_in_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
+ show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+
+ # Event functions to update the parameters in the backend
+ activate.change(lambda x: params.update({"activate": x}), activate, None)
+ autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
+ voice.change(lambda x: params.update({"speaker": x}), voice, None)
+ v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
+ v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
diff --git a/models/place-your-models-here.txt b/models/place-your-models-here.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2723490bbe214e351634ca4054f74a0b5334b28
--- /dev/null
+++ b/modules/GPTQ_loader.py
@@ -0,0 +1,71 @@
+import sys
+from pathlib import Path
+
+import accelerate
+import torch
+
+import modules.shared as shared
+
+sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
+import llama
+import opt
+
+
+def load_quantized(model_name):
+ if not shared.args.gptq_model_type:
+ # Try to determine model type from model name
+ model_type = model_name.split('-')[0].lower()
+ if model_type not in ('llama', 'opt'):
+ print("Can't determine model type from model name. Please specify it manually using --gptq-model-type "
+ "argument")
+ exit()
+ else:
+ model_type = shared.args.gptq_model_type.lower()
+
+ if model_type == 'llama':
+ load_quant = llama.load_quant
+ elif model_type == 'opt':
+ load_quant = opt.load_quant
+ else:
+ print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
+ exit()
+
+ path_to_model = Path(f'models/{model_name}')
+ if path_to_model.name.lower().startswith('llama-7b'):
+ pt_model = f'llama-7b-{shared.args.gptq_bits}bit.pt'
+ elif path_to_model.name.lower().startswith('llama-13b'):
+ pt_model = f'llama-13b-{shared.args.gptq_bits}bit.pt'
+ elif path_to_model.name.lower().startswith('llama-30b'):
+ pt_model = f'llama-30b-{shared.args.gptq_bits}bit.pt'
+ elif path_to_model.name.lower().startswith('llama-65b'):
+ pt_model = f'llama-65b-{shared.args.gptq_bits}bit.pt'
+ else:
+ pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt'
+
+ # Try to find the .pt both in models/ and in the subfolder
+ pt_path = None
+ for path in [Path(p) for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
+ if path.exists():
+ pt_path = path
+
+ if not pt_path:
+ print(f"Could not find {pt_model}, exiting...")
+ exit()
+
+ model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
+
+ # Multiple GPUs or GPU+CPU
+ if shared.args.gpu_memory:
+ max_memory = {}
+ for i in range(len(shared.args.gpu_memory)):
+ max_memory[i] = f"{shared.args.gpu_memory[i]}GiB"
+ max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB"
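+ # Illustrative example: `--gpu-memory 10 5 --cpu-memory 32` produces
+ # max_memory = {0: '10GiB', 1: '5GiB', 'cpu': '32GiB'}, which accelerate
+ # uses below to split the layers across the available devices.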
+
+ device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"])
+ model = accelerate.dispatch_model(model, device_map=device_map)
+
+ # Single GPU
+ else:
+ model = model.to(torch.device('cuda:0'))
+
+ return model
diff --git a/modules/RWKV.py b/modules/RWKV.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf8937ad37944c0cebeeb8e0891bec1474724ea
--- /dev/null
+++ b/modules/RWKV.py
@@ -0,0 +1,74 @@
+import os
+from pathlib import Path
+
+import numpy as np
+from tokenizers import Tokenizer
+
+import modules.shared as shared
+from modules.callbacks import Iteratorize
+
+np.set_printoptions(precision=4, suppress=True, linewidth=200)
+
+os.environ['RWKV_JIT_ON'] = '1'
+os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)
+
+from rwkv.model import RWKV
+from rwkv.utils import PIPELINE, PIPELINE_ARGS
+
+
+class RWKVModel:
+ def __init__(self):
+ pass
+
+ @classmethod
+ def from_pretrained(self, path, dtype="fp16", device="cuda"):
+ tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
+
+ if shared.args.rwkv_strategy is None:
+ model = RWKV(model=str(path), strategy=f'{device} {dtype}')
+ else:
+ model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
+ pipeline = PIPELINE(model, str(tokenizer_path))
+
+ result = self()
+ result.pipeline = pipeline
+ return result
+
+ def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
+ args = PIPELINE_ARGS(
+ temperature = temperature,
+ top_p = top_p,
+ top_k = top_k,
+ alpha_frequency = alpha_frequency, # Frequency Penalty (as in GPT-3)
+ alpha_presence = alpha_presence, # Presence Penalty (as in GPT-3)
+ token_ban = token_ban, # ban the generation of some tokens
+ token_stop = token_stop
+ )
+
+ return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
+
+ def generate_with_streaming(self, **kwargs):
+ with Iteratorize(self.generate, kwargs, callback=None) as generator:
+ reply = kwargs['context']
+ for token in generator:
+ reply += token
+ yield reply
+
+class RWKVTokenizer:
+ def __init__(self):
+ pass
+
+ @classmethod
+ def from_pretrained(self, path):
+ tokenizer_path = path / "20B_tokenizer.json"
+ tokenizer = Tokenizer.from_file(str(tokenizer_path))
+
+ result = self()
+ result.tokenizer = tokenizer
+ return result
+
+ def encode(self, prompt):
+ return self.tokenizer.encode(prompt).ids
+
+ def decode(self, ids):
+ return self.tokenizer.decode(ids)
diff --git a/modules/callbacks.py b/modules/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..faa4a5e9991e1ae711589fed61e7d1f48e28fed3
--- /dev/null
+++ b/modules/callbacks.py
@@ -0,0 +1,98 @@
+import gc
+from queue import Queue
+from threading import Thread
+
+import torch
+import transformers
+
+import modules.shared as shared
+
+# Copied from https://github.com/PygmalionAI/gradio-ui/
+class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
+
+ def __init__(self, sentinel_token_ids: torch.LongTensor,
+ starting_idx: int):
+ transformers.StoppingCriteria.__init__(self)
+ self.sentinel_token_ids = sentinel_token_ids
+ self.starting_idx = starting_idx
+
+ def __call__(self, input_ids: torch.LongTensor,
+ _scores: torch.FloatTensor) -> bool:
+ for sample in input_ids:
+ trimmed_sample = sample[self.starting_idx:]
+ # Can't unfold, output is still too tiny. Skip.
+ if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
+ continue
+
+ for window in trimmed_sample.unfold(
+ 0, self.sentinel_token_ids.shape[-1], 1):
+ if torch.all(torch.eq(self.sentinel_token_ids, window)):
+ return True
+ return False
+
+class Stream(transformers.StoppingCriteria):
+ def __init__(self, callback_func=None):
+ self.callback_func = callback_func
+
+ def __call__(self, input_ids, scores) -> bool:
+ if self.callback_func is not None:
+ self.callback_func(input_ids[0])
+ return False
+
+class Iteratorize:
+
+ """
+ Transforms a function that takes a callback
+ into a lazy iterator (generator).
+ """
+
+ def __init__(self, func, kwargs={}, callback=None):
+ self.mfunc=func
+ self.c_callback=callback
+ self.q = Queue()
+ self.sentinel = object()
+ self.kwargs = kwargs
+ self.stop_now = False
+
+ def _callback(val):
+ if self.stop_now:
+ raise ValueError
+ self.q.put(val)
+
+ def gentask():
+ try:
+ ret = self.mfunc(callback=_callback, **self.kwargs)
+ except ValueError:
+ pass
+ clear_torch_cache()
+ self.q.put(self.sentinel)
+ if self.c_callback:
+ self.c_callback(ret)
+
+ self.thread = Thread(target=gentask)
+ self.thread.start()
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ obj = self.q.get(True,None)
+ if obj is self.sentinel:
+ raise StopIteration
+ else:
+ return obj
+
+ def __del__(self):
+ clear_torch_cache()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.stop_now = True
+ clear_torch_cache()
+
+def clear_torch_cache():
+ gc.collect()
+ if not shared.args.cpu:
+ torch.cuda.empty_cache()
diff --git a/modules/chat.py b/modules/chat.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd45b879f92f366255c6f2308ccf135dd61bda1d
--- /dev/null
+++ b/modules/chat.py
@@ -0,0 +1,398 @@
+import base64
+import copy
+import io
+import json
+import re
+from datetime import datetime
+from pathlib import Path
+
+from PIL import Image
+
+import modules.extensions as extensions_module
+import modules.shared as shared
+from modules.extensions import apply_extensions
+from modules.html_generator import generate_chat_html
+from modules.text_generation import encode, generate_reply, get_max_prompt_length
+
+
+# This gets the new line characters right.
+def clean_chat_message(text):
+ text = text.replace('\n', '\n\n')
+ text = re.sub(r"\n{3,}", "\n\n", text)
+ text = text.strip()
+ return text
+
+def generate_chat_output(history, name1, name2, character):
+ if shared.args.cai_chat:
+ return generate_chat_html(history, name1, name2, character)
+ else:
+ return history
+
+def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False):
+ user_input = clean_chat_message(user_input)
+ rows = [f"{context.strip()}\n"]
+
+ if shared.soft_prompt:
+ chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
+ max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
+
+ i = len(shared.history['internal'])-1
+ while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
+ rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
+ if not (shared.history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
+ rows.insert(1, f"{name1}: {shared.history['internal'][i][0].strip()}\n")
+ i -= 1
+
+ if not impersonate:
+ rows.append(f"{name1}: {user_input}\n")
+ rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
+ limit = 3
+ else:
+ rows.append(f"{name1}:")
+ limit = 2
+
+ while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
+ rows.pop(1)
+
+ prompt = ''.join(rows)
+ return prompt
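+# Illustrative layout of the prompt assembled above (names and messages are
+# placeholders):
+#
+# <context>
+# <name1>: earlier message
+# <name2>: earlier reply
+# <name1>: new user input
+# <name2>: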
+
+def extract_message_from_reply(question, reply, name1, name2, check, impersonate=False):
+ next_character_found = False
+
+ asker = name1 if not impersonate else name2
+ replier = name2 if not impersonate else name1
+
+ previous_idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(replier)}:", question)]
+ idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(replier)}:", reply)]
+ idx = idx[max(len(previous_idx)-1, 0)]
+
+ if not impersonate:
+ reply = reply[idx + 1 + len(apply_extensions(f"{replier}:", "bot_prefix")):]
+ else:
+ reply = reply[idx + 1 + len(f"{replier}:"):]
+
+ if check:
+ lines = reply.split('\n')
+ reply = lines[0].strip()
+ if len(lines) > 1:
+ next_character_found = True
+ else:
+ idx = reply.find(f"\n{asker}:")
+ if idx != -1:
+ reply = reply[:idx]
+ next_character_found = True
+ reply = clean_chat_message(reply)
+
+ # If something like "\nYo" is generated just before "\nYou:"
+ # is completed, trim it
+ next_turn = f"\n{asker}:"
+ for j in range(len(next_turn)-1, 0, -1):
+ if reply[-j:] == next_turn[:j]:
+ reply = reply[:-j]
+ break
+
+ return reply, next_character_found
+
+def stop_everything_event():
+ shared.stop_everything = True
+
+def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
+ shared.stop_everything = False
+ just_started = True
+ eos_token = '\n' if check else None
+ name1_original = name1
+ if 'pygmalion' in shared.model_name.lower():
+ name1 = "You"
+
+ # Check if any extension wants to hijack this function call
+ visible_text = None
+ custom_generate_chat_prompt = None
+ for extension, _ in extensions_module.iterator():
+ if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
+ extension.input_hijack['state'] = False
+ text, visible_text = extension.input_hijack['value']
+ if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
+ custom_generate_chat_prompt = extension.custom_generate_chat_prompt
+
+ if visible_text is None:
+ visible_text = text
+ if shared.args.chat:
+ visible_text = visible_text.replace('\n', ' ')
+ text = apply_extensions(text, "input")
+
+ if custom_generate_chat_prompt is None:
+ prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+ else:
+ prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+
+ # Yield *Is typing...*
+ if not regenerate:
+ yield shared.history['visible']+[[visible_text, shared.processing_message]]
+
+ # Generate
+ reply = ''
+ for i in range(chat_generation_attempts):
+ for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
+
+ # Extracting the reply
+ reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check)
+ visible_reply = re.sub("(||{{user}})", name1_original, reply)
+ visible_reply = apply_extensions(visible_reply, "output")
+ if shared.args.chat:
+ visible_reply = visible_reply.replace('\n', ' ')
+
+ # We need this global variable to handle the Stop event,
+ # otherwise gradio gets confused
+ if shared.stop_everything:
+ return shared.history['visible']
+ if just_started:
+ just_started = False
+ shared.history['internal'].append(['', ''])
+ shared.history['visible'].append(['', ''])
+
+ shared.history['internal'][-1] = [text, reply]
+ shared.history['visible'][-1] = [visible_text, visible_reply]
+ if not shared.args.no_stream:
+ yield shared.history['visible']
+ if next_character_found:
+ break
+
+ yield shared.history['visible']
+
+def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+ eos_token = '\n' if check else None
+
+ if 'pygmalion' in shared.model_name.lower():
+ name1 = "You"
+
+ prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
+
+ reply = ''
+ # Yield *Is typing...*
+ yield shared.processing_message
+ for i in range(chat_generation_attempts):
+ for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
+ reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
+ yield reply
+ if next_character_found:
+ break
+ yield reply
+
+def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+ for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
+ yield generate_chat_html(_history, name1, name2, shared.character)
+
+def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+ if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
+ yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+ else:
+ last_visible = shared.history['visible'].pop()
+ last_internal = shared.history['internal'].pop()
+ # Yield '*Is typing...*'
+ yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
+ for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
+ if shared.args.cai_chat:
+ shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
+ else:
+ shared.history['visible'][-1] = (last_visible[0], _history[-1][1])
+ yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+
+def remove_last_message(name1, name2):
+ if len(shared.history['visible']) > 0 and not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
+ last = shared.history['visible'].pop()
+ shared.history['internal'].pop()
+ else:
+ last = ['', '']
+
+ if shared.args.cai_chat:
+ return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
+ else:
+ return shared.history['visible'], last[0]
+
+def send_last_reply_to_input():
+ if len(shared.history['internal']) > 0:
+ return shared.history['internal'][-1][1]
+ else:
+ return ''
+
+def replace_last_reply(text, name1, name2):
+ if len(shared.history['visible']) > 0:
+ if shared.args.cai_chat:
+ shared.history['visible'][-1][1] = text
+ else:
+ shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
+ shared.history['internal'][-1][1] = apply_extensions(text, "input")
+
+ return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+
+def clear_html():
+ return generate_chat_html([], "", "", shared.character)
+
+def clear_chat_log(name1, name2):
+ if shared.character != 'None':
+ found = False
+ for i in range(len(shared.history['internal'])):
+ if '<|BEGIN-VISIBLE-CHAT|>' in shared.history['internal'][i][0]:
+ shared.history['visible'] = [['', apply_extensions(shared.history['internal'][i][1], "output")]]
+ shared.history['internal'] = [shared.history['internal'][i]]
+ found = True
+ break
+ if not found:
+ shared.history['visible'] = []
+ shared.history['internal'] = []
+ else:
+ shared.history['internal'] = []
+ shared.history['visible'] = []
+
+ return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+
+def redraw_html(name1, name2):
+ return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
+
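+# Converts a raw "Name: message" transcript into a list of [user_message, bot_message] pairs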
+def tokenize_dialogue(dialogue, name1, name2):
+ _history = []
+
+    # Strip the <START> markers used in Pygmalion example dialogues
+    dialogue = re.sub('<START>', '', dialogue)
+    dialogue = re.sub('<start>', '', dialogue)
+ dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
+ dialogue = re.sub('(\n|^)\[CHARACTER\]:', f'\\g<1>{name2}:', dialogue)
+ idx = [m.start() for m in re.finditer(f"(^|\n)({re.escape(name1)}|{re.escape(name2)}):", dialogue)]
+ if len(idx) == 0:
+ return _history
+
+ messages = []
+ for i in range(len(idx)-1):
+ messages.append(dialogue[idx[i]:idx[i+1]].strip())
+ messages.append(dialogue[idx[-1]:].strip())
+
+ entry = ['', '']
+ for i in messages:
+ if i.startswith(f'{name1}:'):
+ entry[0] = i[len(f'{name1}:'):].strip()
+ elif i.startswith(f'{name2}:'):
+ entry[1] = i[len(f'{name2}:'):].strip()
+ if not (len(entry[0]) == 0 and len(entry[1]) == 0):
+ _history.append(entry)
+ entry = ['', '']
+
+ print("\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='')
+ for row in _history:
+ for column in row:
+ print("\n")
+ for line in column.strip().split('\n'):
+ print("| "+line+"\n")
+ print("|\n")
+ print("------------------------------")
+
+ return _history
+
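+# Saves the chat history to the logs folder, either timestamped or as a persistent.json file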
+def save_history(timestamp=True):
+ prefix = '' if shared.character == 'None' else f"{shared.character}_"
+ if timestamp:
+ fname = f"{prefix}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
+ else:
+ fname = f"{prefix}persistent.json"
+ if not Path('logs').exists():
+ Path('logs').mkdir()
+ with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
+ f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
+ return Path(f'logs/{fname}')
+
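+# Restores a chat history from an uploaded JSON file; also accepts Pygmalion-style logs and plain text transcripts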
+def load_history(file, name1, name2):
+ file = file.decode('utf-8')
+ try:
+ j = json.loads(file)
+ if 'data' in j:
+ shared.history['internal'] = j['data']
+ if 'data_visible' in j:
+ shared.history['visible'] = j['data_visible']
+ else:
+ shared.history['visible'] = copy.deepcopy(shared.history['internal'])
+ # Compatibility with Pygmalion AI's official web UI
+ elif 'chat' in j:
+ shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
+ if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
+ shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)]
+ shared.history['visible'] = copy.deepcopy(shared.history['internal'])
+ shared.history['visible'][0][0] = ''
+ else:
+ shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)]
+ shared.history['visible'] = copy.deepcopy(shared.history['internal'])
+ except:
+ shared.history['internal'] = tokenize_dialogue(file, name1, name2)
+ shared.history['visible'] = copy.deepcopy(shared.history['internal'])
+
+def load_default_history(name1, name2):
+ if Path('logs/persistent.json').exists():
+ load_history(open(Path('logs/persistent.json'), 'rb').read(), name1, name2)
+ else:
+ shared.history['internal'] = []
+ shared.history['visible'] = []
+
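+# Loads a character definition from characters/<name>.json, builds its context string, and resets the history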
+def load_character(_character, name1, name2):
+ context = ""
+ shared.history['internal'] = []
+ shared.history['visible'] = []
+ if _character != 'None':
+ shared.character = _character
+ data = json.loads(open(Path(f'characters/{_character}.json'), 'r', encoding='utf-8').read())
+ name2 = data['char_name']
+ if 'char_persona' in data and data['char_persona'] != '':
+ context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
+ if 'world_scenario' in data and data['world_scenario'] != '':
+ context += f"Scenario: {data['world_scenario']}\n"
+ context = f"{context.strip()}\n\n"
+ if 'example_dialogue' in data and data['example_dialogue'] != '':
+ data['example_dialogue'] = data['example_dialogue'].replace('{{user}}', name1).replace('{{char}}', name2)
+            data['example_dialogue'] = data['example_dialogue'].replace('<USER>', name1).replace('<BOT>', name2)
+ context += f"{data['example_dialogue'].strip()}\n"
+ if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
+ shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
+ shared.history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]]
+ else:
+ shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
+ shared.history['visible'] += [['', "Hello there!"]]
+ else:
+ shared.character = None
+ context = shared.settings['context_pygmalion']
+ name2 = shared.settings['name2_pygmalion']
+
+ if Path(f'logs/{shared.character}_persistent.json').exists():
+ load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
+
+ if shared.args.cai_chat:
+ return name2, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
+ else:
+ return name2, context, shared.history['visible']
+
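+# Saves an uploaded character JSON (and optional image) under characters/ with a unique file name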
+def upload_character(json_file, img, tavern=False):
+ json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
+ data = json.loads(json_file)
+ outfile_name = data["char_name"]
+ i = 1
+ while Path(f'characters/{outfile_name}.json').exists():
+ outfile_name = f'{data["char_name"]}_{i:03d}'
+ i += 1
+ if tavern:
+ outfile_name = f'TavernAI-{outfile_name}'
+ with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f:
+ f.write(json_file)
+ if img is not None:
+ img = Image.open(io.BytesIO(img))
+ img.save(Path(f'characters/{outfile_name}.png'))
+ print(f'New character saved to "characters/{outfile_name}.json".')
+ return outfile_name
+
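+# Decodes the base64 character data embedded in a TavernAI PNG and imports it as a regular character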
+def upload_tavern_character(img, name1, name2):
+ _img = Image.open(io.BytesIO(img))
+ _img.getexif()
+ decoded_string = base64.b64decode(_img.info['chara'])
+ _json = json.loads(decoded_string)
+ _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
+ return upload_character(json.dumps(_json), img, tavern=True)
+
+def upload_your_profile_picture(img):
+ img = Image.open(io.BytesIO(img))
+ img.save(Path('img_me.png'))
+ print('Profile picture saved to "img_me.png"')
diff --git a/modules/deepspeed_parameters.py b/modules/deepspeed_parameters.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dbed437f5b5196d0b1fcbc582085319fb8d40d1
--- /dev/null
+++ b/modules/deepspeed_parameters.py
@@ -0,0 +1,75 @@
+def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
+
+ '''
+    DeepSpeed configuration
+ https://huggingface.co/docs/transformers/main_classes/deepspeed
+ '''
+
+ if nvme_offload_dir:
+ ds_config = {
+ "fp16": {
+ "enabled": not ds_bf16,
+ },
+ "bf16": {
+ "enabled": ds_bf16,
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_param": {
+ "device": "nvme",
+ "nvme_path": nvme_offload_dir,
+ "pin_memory": True,
+ "buffer_count": 5,
+ "buffer_size": 1e9,
+ "max_in_cpu": 1e9
+ },
+ "overlap_comm": True,
+ "reduce_bucket_size": "auto",
+ "contiguous_gradients": True,
+ "sub_group_size": 1e8,
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": "auto",
+ "stage3_max_reuse_distance": "auto",
+ },
+ "aio": {
+ "block_size": 262144,
+ "queue_depth": 32,
+ "thread_count": 1,
+ "single_submit": False,
+ "overlap_events": True
+ },
+ "steps_per_print": 2000,
+ "train_batch_size": train_batch_size,
+ "train_micro_batch_size_per_gpu": 1,
+ "wall_clock_breakdown": False
+ }
+ else:
+ ds_config = {
+ "fp16": {
+ "enabled": not ds_bf16,
+ },
+ "bf16": {
+ "enabled": ds_bf16,
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_param": {
+ "device": "cpu",
+ "pin_memory": True
+ },
+ "overlap_comm": True,
+ "contiguous_gradients": True,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": "auto",
+ "stage3_max_reuse_distance": "auto",
+ },
+ "steps_per_print": 2000,
+ "train_batch_size": train_batch_size,
+ "train_micro_batch_size_per_gpu": 1,
+ "wall_clock_breakdown": False
+ }
+
+ return ds_config
diff --git a/modules/extensions.py b/modules/extensions.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8de8a7bc9ebd331d65704996a764e7cc279a6e5
--- /dev/null
+++ b/modules/extensions.py
@@ -0,0 +1,45 @@
+import extensions
+import modules.shared as shared
+
+state = {}
+available_extensions = []
+
+def load_extensions():
+ global state
+ for i, name in enumerate(shared.args.extensions):
+ if name in available_extensions:
+ print(f'Loading the extension "{name}"... ', end='')
+ exec(f"import extensions.{name}.script")
+ state[name] = [True, i]
+ print('Ok.')
+
+# This iterator returns the extensions in the order specified in the command-line
+def iterator():
+ for name in sorted(state, key=lambda x : state[x][1]):
+ if state[name][0] == True:
+ yield eval(f"extensions.{name}.script"), name
+
+# Extension functions that map string -> string
+def apply_extensions(text, typ):
+ for extension, _ in iterator():
+ if typ == "input" and hasattr(extension, "input_modifier"):
+ text = extension.input_modifier(text)
+ elif typ == "output" and hasattr(extension, "output_modifier"):
+ text = extension.output_modifier(text)
+ elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"):
+ text = extension.bot_prefix_modifier(text)
+ return text
+
+def create_extensions_block():
+ # Updating the default values
+ for extension, name in iterator():
+ if hasattr(extension, 'params'):
+ for param in extension.params:
+ _id = f"{name}-{param}"
+ if _id in shared.settings:
+ extension.params[param] = shared.settings[_id]
+
+ # Creating the extension ui elements
+ for extension, name in iterator():
+ if hasattr(extension, "ui"):
+ extension.ui()
diff --git a/modules/html_generator.py b/modules/html_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..162040bac68c2e987b33a02ccb12e90b51a63b2d
--- /dev/null
+++ b/modules/html_generator.py
@@ -0,0 +1,357 @@
+'''
+
+This is a library for formatting GPT-4chan and chat outputs as nice HTML.
+
+'''
+
+import os
+import re
+from pathlib import Path
+
+from PIL import Image
+
+# This is to store the paths to the thumbnails of the profile pictures
+image_cache = {}
+
+def generate_basic_html(s):
+ css = """
+ .container {
+ max-width: 600px;
+ margin-left: auto;
+ margin-right: auto;
+ background-color: rgb(31, 41, 55);
+ padding:3em;
+ }
+ .container p {
+ font-size: 16px !important;
+ color: white !important;
+ margin-bottom: 22px;
+ line-height: 1.4 !important;
+ }
+ """
+    s = '\n'.join([f'<p>{line}</p>' for line in s.split('\n')])
+    s = f'<style>{css}</style><div class="container">{s}</div>'
+ return s
+
+def process_post(post, c):
+ t = post.split('\n')
+ number = t[0].split(' ')[1]
+ if len(t) > 1:
+ src = '\n'.join(t[1:])
+ else:
+ src = ''
+    # Escape quotes, wrap them in styled spans, and enclose the post body in a blockquote
+    src = re.sub('>', '&gt;', src)
+    src = re.sub('(&gt;&gt;[0-9]*)', '<span class="quote">\\1</span>', src)
+    src = re.sub('\n', '<br>\n', src)
+    src = f'<blockquote class="message">{src}\n</blockquote>'
+    src = f'<span class="name">Anonymous</span> <span class="number">No.{number}</span>\n{src}'
+ return src
+
+def generate_4chan_html(f):
+ css = """
+
+ #parent #container {
+ background-color: #eef2ff;
+ padding: 17px;
+ }
+ #parent #container .reply {
+ background-color: rgb(214, 218, 240);
+ border-bottom-color: rgb(183, 197, 217);
+ border-bottom-style: solid;
+ border-bottom-width: 1px;
+ border-image-outset: 0;
+ border-image-repeat: stretch;
+ border-image-slice: 100%;
+ border-image-source: none;
+ border-image-width: 1;
+ border-left-color: rgb(0, 0, 0);
+ border-left-style: none;
+ border-left-width: 0px;
+ border-right-color: rgb(183, 197, 217);
+ border-right-style: solid;
+ border-right-width: 1px;
+ border-top-color: rgb(0, 0, 0);
+ border-top-style: none;
+ border-top-width: 0px;
+ color: rgb(0, 0, 0);
+ display: table;
+ font-family: arial, helvetica, sans-serif;
+ font-size: 13.3333px;
+ margin-bottom: 4px;
+ margin-left: 0px;
+ margin-right: 0px;
+ margin-top: 4px;
+ overflow-x: hidden;
+ overflow-y: hidden;
+ padding-bottom: 4px;
+ padding-left: 2px;
+ padding-right: 2px;
+ padding-top: 4px;
+ }
+
+ #parent #container .number {
+ color: rgb(0, 0, 0);
+ font-family: arial, helvetica, sans-serif;
+ font-size: 13.3333px;
+ width: 342.65px;
+ margin-right: 7px;
+ }
+
+ #parent #container .op {
+ color: rgb(0, 0, 0);
+ font-family: arial, helvetica, sans-serif;
+ font-size: 13.3333px;
+ margin-bottom: 8px;
+ margin-left: 0px;
+ margin-right: 0px;
+ margin-top: 4px;
+ overflow-x: hidden;
+ overflow-y: hidden;
+ }
+
+ #parent #container .op blockquote {
+ margin-left: 0px !important;
+ }
+
+ #parent #container .name {
+ color: rgb(17, 119, 67);
+ font-family: arial, helvetica, sans-serif;
+ font-size: 13.3333px;
+ font-weight: 700;
+ margin-left: 7px;
+ }
+
+ #parent #container .quote {
+ color: rgb(221, 0, 0);
+ font-family: arial, helvetica, sans-serif;
+ font-size: 13.3333px;
+ text-decoration-color: rgb(221, 0, 0);
+ text-decoration-line: underline;
+ text-decoration-style: solid;
+ text-decoration-thickness: auto;
+ }
+
+ #parent #container .greentext {
+ color: rgb(120, 153, 34);
+ font-family: arial, helvetica, sans-serif;
+ font-size: 13.3333px;
+ }
+
+ #parent #container blockquote {
+ margin: 0px !important;
+ margin-block-start: 1em;
+ margin-block-end: 1em;
+ margin-inline-start: 40px;
+ margin-inline-end: 40px;
+ margin-top: 13.33px !important;
+ margin-bottom: 13.33px !important;
+ margin-left: 40px !important;
+ margin-right: 40px !important;
+ }
+
+ #parent #container .message {
+ color: black;
+ border: none;
+ }
+ """
+
+ posts = []
+ post = ''
+ c = -2
+ for line in f.splitlines():
+ line += "\n"
+ if line == '-----\n':
+ continue
+ elif line.startswith('--- '):
+ c += 1
+ if post != '':
+ src = process_post(post, c)
+ posts.append(src)
+ post = line
+ else:
+ post += line
+ if post != '':
+ src = process_post(post, c)
+ posts.append(src)
+
+    for i in range(len(posts)):
+        if i == 0:
+            posts[i] = f'<div class="op">{posts[i]}</div>\n'
+        else:
+            posts[i] = f'<div class="reply">{posts[i]}</div>\n'
+
+    output = ''
+    output += f'<style>{css}</style><div id="parent"><div id="container">'
+    for post in posts:
+        output += post
+    output += '</div></div>'
+    output = output.split('\n')
+    for i in range(len(output)):
+        output[i] = re.sub(r'^(&gt;(.*?)(<br>|</div>))', r'<span class="greentext">\1</span>', output[i])
+        output[i] = re.sub(r'^<blockquote class="message">(&gt;(.*?)(<br>|</div>))', r'<blockquote class="message"><span class="greentext">\1</span>', output[i])
+ output = '\n'.join(output)
+
+ return output
+
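+# Returns the path to a 200x200 thumbnail of a profile picture, regenerating it when the source file changes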
+def get_image_cache(path):
+ cache_folder = Path("cache")
+ if not cache_folder.exists():
+ cache_folder.mkdir()
+
+ mtime = os.stat(path).st_mtime
+ if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache):
+ img = Image.open(path)
+ img.thumbnail((200, 200))
+ output_file = Path(f'cache/{path.name}_cache.png')
+ img.convert('RGB').save(output_file, format='PNG')
+ image_cache[path] = [mtime, output_file.as_posix()]
+
+ return image_cache[path][1]
+
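+# Renders the chat history as styled HTML for the Character.AI-like chat mode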
+def generate_chat_html(history, name1, name2, character):
+ css = """
+ .chat {
+ margin-left: auto;
+ margin-right: auto;
+ max-width: 800px;
+ height: 66.67vh;
+ overflow-y: auto;
+ padding-right: 20px;
+ display: flex;
+ flex-direction: column-reverse;
+ }
+
+ .message {
+ display: grid;
+ grid-template-columns: 60px 1fr;
+ padding-bottom: 25px;
+ font-size: 15px;
+ font-family: Helvetica, Arial, sans-serif;
+ line-height: 1.428571429;
+ }
+
+ .circle-you {
+ width: 50px;
+ height: 50px;
+ background-color: rgb(238, 78, 59);
+ border-radius: 50%;
+ }
+
+ .circle-bot {
+ width: 50px;
+ height: 50px;
+ background-color: rgb(59, 78, 244);
+ border-radius: 50%;
+ }
+
+ .circle-bot img, .circle-you img {
+ border-radius: 50%;
+ width: 100%;
+ height: 100%;
+ object-fit: cover;
+ }
+
+ .text {
+ }
+
+ .text p {
+ margin-top: 5px;
+ }
+
+ .username {
+ font-weight: bold;
+ }
+
+ .message-body {
+ }
+
+ .message-body img {
+ max-width: 300px;
+ max-height: 300px;
+ border-radius: 20px;
+ }
+
+ .message-body p {
+ margin-bottom: 0 !important;
+ font-size: 15px !important;
+ line-height: 1.428571429 !important;
+ }
+
+ .dark .message-body p em {
+ color: rgb(138, 138, 138) !important;
+ }
+
+ .message-body p em {
+ color: rgb(110, 110, 110) !important;
+ }
+
+ """
+
+ output = ''
+    output += f'<style>{css}</style><div class="chat">'
+ img = ''
+
+ for i in [
+ f"characters/{character}.png",
+ f"characters/{character}.jpg",
+ f"characters/{character}.jpeg",
+ "img_bot.png",
+ "img_bot.jpg",
+ "img_bot.jpeg"
+ ]:
+
+ path = Path(i)
+ if path.exists():
+            img = f'<img src="file/{get_image_cache(path)}">'
+ break
+
+ img_me = ''
+ for i in ["img_me.png", "img_me.jpg", "img_me.jpeg"]:
+ path = Path(i)
+ if path.exists():
+            img_me = f'<img src="file/{get_image_cache(path)}">'
+ break
+
+ for i,_row in enumerate(history[::-1]):
+ row = _row.copy()
+ row[0] = re.sub(r"(\*\*)([^\*\n]*)(\*\*)", r"
\2 ", row[0])
+ row[1] = re.sub(r"(\*\*)([^\*\n]*)(\*\*)", r"
\2 ", row[1])
+ row[0] = re.sub(r"(\*)([^\*\n]*)(\*)", r"
\2 ", row[0])
+ row[1] = re.sub(r"(\*)([^\*\n]*)(\*)", r"
\2 ", row[1])
+ p = '\n'.join([f"
{x}
" for x in row[1].split('\n')])
+ output += f"""
+
+
+ {img}
+
+
+
+ {name2}
+
+
+ {p}
+
+
+
+ """
+
+ if not (i == len(history)-1 and len(row[0]) == 0):
+ p = '\n'.join([f"
{x}
" for x in row[0].split('\n')])
+ output += f"""
+
+
+ {img_me}
+
+
+
+ {name1}
+
+
+ {p}
+
+
+
+ """
+
+ output += "
"
+ return output
diff --git a/modules/models.py b/modules/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4bb11fd3f7292657b008ab644b5be121d9980e5
--- /dev/null
+++ b/modules/models.py
@@ -0,0 +1,168 @@
+import json
+import os
+import time
+import zipfile
+from pathlib import Path
+
+import numpy as np
+import torch
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import modules.shared as shared
+
+transformers.logging.set_verbosity_error()
+
+local_rank = None
+
+if shared.args.flexgen:
+ from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM,
+ Policy, str2bool)
+
+if shared.args.deepspeed:
+ import deepspeed
+ from transformers.deepspeed import (HfDeepSpeedConfig,
+ is_deepspeed_zero3_enabled)
+
+ from modules.deepspeed_parameters import generate_ds_config
+
+ # Distributed setup
+ local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
+ torch.cuda.set_device(local_rank)
+ deepspeed.init_distributed()
+ ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
+ dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
+
+
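+# Loads the model and tokenizer, picking a backend (8-bit, FlexGen, DeepSpeed, RWKV, GPTQ, or custom settings) based on the command-line flags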
+def load_model(model_name):
+ print(f"Loading {model_name}...")
+ t0 = time.time()
+
+ shared.is_RWKV = model_name.lower().startswith('rwkv-')
+
+ # Default settings
+ if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.gptq_bits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
+ if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
+ model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
+ else:
+ model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16).cuda()
+
+ # FlexGen
+ elif shared.args.flexgen:
+ # Initialize environment
+ env = ExecutionEnv.create(shared.args.disk_cache_dir)
+
+ # Offloading policy
+ policy = Policy(1, 1,
+ shared.args.percent[0], shared.args.percent[1],
+ shared.args.percent[2], shared.args.percent[3],
+ shared.args.percent[4], shared.args.percent[5],
+ overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
+ cpu_cache_compute=False, attn_sparsity=1.0,
+ compress_weight=shared.args.compress_weight,
+ comp_weight_config=CompressionConfig(
+ num_bits=4, group_size=64,
+ group_dim=0, symmetric=False),
+ compress_cache=False,
+ comp_cache_config=CompressionConfig(
+ num_bits=4, group_size=64,
+ group_dim=2, symmetric=False))
+
+ model = OptLM(f"facebook/{shared.model_name}", env, "models", policy)
+
+ # DeepSpeed ZeRO-3
+ elif shared.args.deepspeed:
+ model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
+ model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
+ model.module.eval() # Inference
+ print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
+
+    # RWKV model (not on Hugging Face)
+ elif shared.is_RWKV:
+ from modules.RWKV import RWKVModel, RWKVTokenizer
+
+ model = RWKVModel.from_pretrained(Path(f'models/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+ tokenizer = RWKVTokenizer.from_pretrained(Path('models'))
+
+ return model, tokenizer
+
+ # Quantized model
+ elif shared.args.gptq_bits > 0:
+ from modules.GPTQ_loader import load_quantized
+
+ model = load_quantized(model_name)
+
+ # Custom
+ else:
+ command = "AutoModelForCausalLM.from_pretrained"
+ params = ["low_cpu_mem_usage=True"]
+ if not shared.args.cpu and not torch.cuda.is_available():
+ print("Warning: no GPU has been detected.\nFalling back to CPU mode.\n")
+ shared.args.cpu = True
+
+ if shared.args.cpu:
+ params.append("low_cpu_mem_usage=True")
+ params.append("torch_dtype=torch.float32")
+ else:
+ params.append("device_map='auto'")
+ params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")
+
+ if shared.args.gpu_memory:
+ memory_map = shared.args.gpu_memory
+ max_memory = f"max_memory={{0: '{memory_map[0]}GiB'"
+ for i in range(1, len(memory_map)):
+ max_memory += (f", {i}: '{memory_map[i]}GiB'")
+ max_memory += (f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
+ params.append(max_memory)
+ elif not shared.args.load_in_8bit:
+ total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
+ suggestion = round((total_mem-1000)/1000)*1000
+ if total_mem-suggestion < 800:
+ suggestion -= 1000
+ suggestion = int(round(suggestion/1000))
+ print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
+ params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
+ if shared.args.disk:
+ params.append(f"offload_folder='{shared.args.disk_cache_dir}'")
+
+ command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})"
+ model = eval(command)
+
+ # Loading the tokenizer
+ if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
+ tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/"))
+ tokenizer.truncation_side = 'left'
+
+ print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
+ return model, tokenizer
+
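+# Extracts tensor.npy and meta.json from softprompts/<name>.zip and stores the tensor in the shared state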
+def load_soft_prompt(name):
+ if name == 'None':
+ shared.soft_prompt = False
+ shared.soft_prompt_tensor = None
+ else:
+ with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
+ zf.extract('tensor.npy')
+ zf.extract('meta.json')
+ j = json.loads(open('meta.json', 'r').read())
+ print(f"\nLoading the softprompt \"{name}\".")
+ for field in j:
+ if field != 'name':
+ if type(j[field]) is list:
+ print(f"{field}: {', '.join(j[field])}")
+ else:
+ print(f"{field}: {j[field]}")
+ print()
+ tensor = np.load('tensor.npy')
+ Path('tensor.npy').unlink()
+ Path('meta.json').unlink()
+ tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
+ tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
+
+ shared.soft_prompt = True
+ shared.soft_prompt_tensor = tensor
+
+ return name
diff --git a/modules/shared.py b/modules/shared.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea2eb50b7f586e5c562bf2e7c75429c91f21ec6c
--- /dev/null
+++ b/modules/shared.py
@@ -0,0 +1,103 @@
+import argparse
+
+model = None
+tokenizer = None
+model_name = ""
+soft_prompt_tensor = None
+soft_prompt = False
+is_RWKV = False
+
+# Chat variables
+history = {'internal': [], 'visible': []}
+character = 'None'
+stop_everything = False
+processing_message = '*Is typing...*'
+
+# UI elements (buttons, sliders, HTML, etc)
+gradio = {}
+
+# Generation input parameters
+input_params = []
+
+settings = {
+ 'max_new_tokens': 200,
+ 'max_new_tokens_min': 1,
+ 'max_new_tokens_max': 2000,
+ 'name1': 'Person 1',
+ 'name2': 'Person 2',
+ 'context': 'This is a conversation between two people.',
+ 'stop_at_newline': True,
+ 'chat_prompt_size': 2048,
+ 'chat_prompt_size_min': 0,
+ 'chat_prompt_size_max': 2048,
+ 'chat_generation_attempts': 1,
+ 'chat_generation_attempts_min': 1,
+ 'chat_generation_attempts_max': 5,
+ 'name1_pygmalion': 'You',
+ 'name2_pygmalion': 'Kawaii',
+ 'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n",
+ 'stop_at_newline_pygmalion': False,
+ 'default_extensions': [],
+ 'chat_default_extensions': ["gallery"],
+ 'presets': {
+ 'default': 'NovelAI-Sphinx Moth',
+ 'pygmalion-*': 'Pygmalion',
+ 'RWKV-*': 'Naive',
+ },
+ 'prompts': {
+ 'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
+ '^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
+ '(rosey|chip|joi)_.*_instruct.*': 'User: \n',
+ 'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
+ }
+}
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
+parser.add_argument('--model', type=str, help='Name of the model to load by default.')
+parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
+parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
+parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
+parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
+parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
+parser.add_argument('--load-in-4bit', action='store_true', help='DEPRECATED: use --gptq-bits 4 instead.')
+parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA and OPT.')
+parser.add_argument('--gptq-model-type', type=str, help='Model type of pre-quantized model. Currently only LLaMa and OPT are supported.')
+parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
+parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
+parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
+parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".')
+parser.add_argument('--gpu-memory', type=int, nargs="+", help='Maximum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs.')
+parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
+parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
+parser.add_argument('--percent', type=int, nargs="+", default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).')
+parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.")
+parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, default=True, help="FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%%).")
+parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
+parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
+parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
+parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".')
+parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.')
+parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
+parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
+parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
+parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
+parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
+parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
+parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
+parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
+args = parser.parse_args()
+
+# Provisional, this will be deleted later
+if args.load_in_4bit:
+ print("Warning: --load-in-4bit is deprecated and will be removed. Use --gptq-bits 4 instead.\n")
+ args.gptq_bits = 4
diff --git a/modules/text_generation.py b/modules/text_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..d64481b24ec4542e55de1605a6181f97d9a50de9
--- /dev/null
+++ b/modules/text_generation.py
@@ -0,0 +1,238 @@
+import gc
+import re
+import time
+
+import numpy as np
+import torch
+import transformers
+
+import modules.shared as shared
+from modules.callbacks import (Iteratorize, Stream,
+ _SentinelTokenStoppingCriteria)
+from modules.extensions import apply_extensions
+from modules.html_generator import generate_4chan_html, generate_basic_html
+from modules.models import local_rank
+
+
+def get_max_prompt_length(tokens):
+ max_length = 2048-tokens
+ if shared.soft_prompt:
+ max_length -= shared.soft_prompt_tensor.shape[1]
+ return max_length
+
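+# Tokenizes the prompt, truncating it to fit the context window, and returns the ids in the format expected by the active backend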
+def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
+ if shared.is_RWKV:
+ input_ids = shared.tokenizer.encode(str(prompt))
+ input_ids = np.array(input_ids).reshape(1, len(input_ids))
+ return input_ids
+ else:
+ input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
+ if shared.args.cpu:
+ return input_ids
+ elif shared.args.flexgen:
+ return input_ids.numpy()
+ elif shared.args.deepspeed:
+ return input_ids.to(device=local_rank)
+ else:
+ return input_ids.cuda()
+
+def decode(output_ids):
+ # Open Assistant relies on special tokens like <|endoftext|>
+ if re.match('oasst-*', shared.model_name.lower()):
+ return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
+ else:
+ reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
+ reply = reply.replace(r'<|endoftext|>', '')
+ return reply
+
+def generate_softprompt_input_tensors(input_ids):
+ inputs_embeds = shared.model.transformer.wte(input_ids)
+ inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
+ filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
+ #filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
+ return inputs_embeds, filler_input_ids
+
+# Removes empty replies from gpt4chan outputs
+def fix_gpt4chan(s):
+ for i in range(10):
+ s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
+ s = re.sub("--- [0-9]*\n *\n---", "---", s)
+ s = re.sub("--- [0-9]*\n\n\n---", "---", s)
+ return s
+
+# Fix the LaTeX equations in galactica
+def fix_galactica(s):
+ s = s.replace(r'\[', r'$')
+ s = s.replace(r'\]', r'$')
+ s = s.replace(r'\(', r'$')
+ s = s.replace(r'\)', r'$')
+ s = s.replace(r'$$', r'$')
+ s = re.sub(r'\n', r'\n\n', s)
+ s = re.sub(r"\n{3,}", "\n\n", s)
+ return s
+
+def formatted_outputs(reply, model_name):
+ if not (shared.args.chat or shared.args.cai_chat):
+ if model_name.lower().startswith('galactica'):
+ reply = fix_galactica(reply)
+ return reply, reply, generate_basic_html(reply)
+ elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
+ reply = fix_gpt4chan(reply)
+ return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
+ else:
+ return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
+ else:
+ return reply
+
+def clear_torch_cache():
+ gc.collect()
+ if not shared.args.cpu:
+ torch.cuda.empty_cache()
+
+def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
+ clear_torch_cache()
+ t0 = time.time()
+
+ # These models are not part of Hugging Face, so we handle them
+ # separately and terminate the function call earlier
+ if shared.is_RWKV:
+ try:
+ if shared.args.no_stream:
+ reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
+ yield formatted_outputs(reply, shared.model_name)
+ else:
+ yield formatted_outputs(question, shared.model_name)
+ # RWKV has proper streaming, which is very nice.
+ # No need to generate 8 tokens at a time.
+ for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+ yield formatted_outputs(reply, shared.model_name)
+ finally:
+ t1 = time.time()
+ output = encode(reply)[0]
+ input_ids = encode(question)
+ print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
+ return
+
+ original_question = question
+ if not (shared.args.chat or shared.args.cai_chat):
+ question = apply_extensions(question, "input")
+ if shared.args.verbose:
+ print(f"\n\n{question}\n--------------------\n")
+
+ input_ids = encode(question, max_new_tokens)
+ original_input_ids = input_ids
+ output = input_ids[0]
+ cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
+ eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
+ if eos_token is not None:
+ eos_token_ids.append(int(encode(eos_token)[0][-1]))
+ stopping_criteria_list = transformers.StoppingCriteriaList()
+ if stopping_string is not None:
+ # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
+ t = encode(stopping_string, 0, add_special_tokens=False)
+ stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
+
+ if not shared.args.flexgen:
+ generate_params = [
+ f"max_new_tokens=max_new_tokens",
+ f"eos_token_id={eos_token_ids}",
+ f"stopping_criteria=stopping_criteria_list",
+ f"do_sample={do_sample}",
+ f"temperature={temperature}",
+ f"top_p={top_p}",
+ f"typical_p={typical_p}",
+ f"repetition_penalty={repetition_penalty}",
+ f"top_k={top_k}",
+ f"min_length={min_length if shared.args.no_stream else 0}",
+ f"no_repeat_ngram_size={no_repeat_ngram_size}",
+ f"num_beams={num_beams}",
+ f"penalty_alpha={penalty_alpha}",
+ f"length_penalty={length_penalty}",
+ f"early_stopping={early_stopping}",
+ ]
+ else:
+ generate_params = [
+ f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
+ f"do_sample={do_sample}",
+ f"temperature={temperature}",
+ f"stop={eos_token_ids[-1]}",
+ ]
+ if shared.args.deepspeed:
+ generate_params.append("synced_gpus=True")
+ if shared.soft_prompt:
+ inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+ generate_params.insert(0, "inputs_embeds=inputs_embeds")
+ generate_params.insert(0, "inputs=filler_input_ids")
+ else:
+ generate_params.insert(0, "inputs=input_ids")
+
+ try:
+ # Generate the entire reply at once.
+ if shared.args.no_stream:
+ with torch.no_grad():
+ output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
+ if shared.soft_prompt:
+ output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+ reply = decode(output)
+ if not (shared.args.chat or shared.args.cai_chat):
+ reply = original_question + apply_extensions(reply[len(question):], "output")
+
+ yield formatted_outputs(reply, shared.model_name)
+
+ # Stream the reply 1 token at a time.
+ # This is based on the trick of using 'stopping_criteria' to create an iterator.
+ elif not shared.args.flexgen:
+
+ def generate_with_callback(callback=None, **kwargs):
+ kwargs['stopping_criteria'].append(Stream(callback_func=callback))
+ clear_torch_cache()
+ with torch.no_grad():
+ shared.model.generate(**kwargs)
+
+ def generate_with_streaming(**kwargs):
+ return Iteratorize(generate_with_callback, kwargs, callback=None)
+
+ yield formatted_outputs(original_question, shared.model_name)
+ with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
+ for output in generator:
+ if shared.soft_prompt:
+ output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+ reply = decode(output)
+
+ if not (shared.args.chat or shared.args.cai_chat):
+ reply = original_question + apply_extensions(reply[len(question):], "output")
+
+ if output[-1] in eos_token_ids:
+ break
+ yield formatted_outputs(reply, shared.model_name)
+
+ yield formatted_outputs(reply, shared.model_name)
+
+ # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
+ else:
+ for i in range(max_new_tokens//8+1):
+ clear_torch_cache()
+ with torch.no_grad():
+ output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
+ if shared.soft_prompt:
+ output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+ reply = decode(output)
+
+ if not (shared.args.chat or shared.args.cai_chat):
+ reply = original_question + apply_extensions(reply[len(question):], "output")
+
+ if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
+ break
+ yield formatted_outputs(reply, shared.model_name)
+
+ input_ids = np.reshape(output, (1, output.shape[0]))
+ if shared.soft_prompt:
+ inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+
+ yield formatted_outputs(reply, shared.model_name)
+
+ finally:
+ t1 = time.time()
+ print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
+ return
diff --git a/modules/ui.py b/modules/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb193e35c11b2a3d474ea89e7567206a3343395a
--- /dev/null
+++ b/modules/ui.py
@@ -0,0 +1,92 @@
+import gradio as gr
+
+refresh_symbol = '\U0001f504' # 🔄
+
+css = """
+.tabs.svelte-710i53 {
+ margin-top: 0
+}
+.py-6 {
+ padding-top: 2.5rem
+}
+.dark #refresh-button {
+ background-color: #ffffff1f;
+}
+#refresh-button {
+ flex: none;
+ margin: 0;
+ padding: 0;
+ min-width: 50px;
+ border: none;
+ box-shadow: none;
+ border-radius: 10px;
+ background-color: #0000000d;
+}
+#download-label, #upload-label {
+ min-height: 0
+}
+#accordion {
+}
+.dark svg {
+ fill: white;
+}
+svg {
+ display: unset !important;
+ vertical-align: middle !important;
+ margin: 5px;
+}
+ol li p, ul li p {
+ display: inline-block;
+}
+"""
+
+chat_css = """
+.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx {
+ height: 66.67vh
+}
+.gradio-container {
+ max-width: 800px !important;
+ margin-left: auto !important;
+ margin-right: auto !important;
+}
+.w-screen {
+ width: unset
+}
+div.svelte-362y77>*, div.svelte-362y77>.form>* {
+ flex-wrap: nowrap
+}
+/* fixes the API documentation in chat mode */
+.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h {
+ display: grid;
+}
+.pending.svelte-1ed2p3z {
+ opacity: 1;
+}
+"""
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+ """Small button with single emoji as text, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(variant="tool", **kwargs)
+
+ def get_block_name(self):
+ return "button"
+
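+# Creates a small 🔄 button that runs refresh_method and updates the attributes of the target component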
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+
+ return gr.update(**(args or {}))
+
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
diff --git a/presets/Contrastive Search.txt b/presets/Contrastive Search.txt
new file mode 100644
index 0000000000000000000000000000000000000000..832bc9caf9b744d9d9c728f88d887f012a56ba3e
--- /dev/null
+++ b/presets/Contrastive Search.txt
@@ -0,0 +1,3 @@
+do_sample=False
+penalty_alpha=0.6
+top_k=4
diff --git a/presets/Debug-deterministic.txt b/presets/Debug-deterministic.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6673b71c8164effc401a486055b7f9a021b2acfb
--- /dev/null
+++ b/presets/Debug-deterministic.txt
@@ -0,0 +1 @@
+do_sample=False
diff --git a/presets/Default.txt b/presets/Default.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9f0983ec7f67e44ac6a383bc13636eec8ad01c78
--- /dev/null
+++ b/presets/Default.txt
@@ -0,0 +1,12 @@
+do_sample=True
+temperature=1
+top_p=1
+typical_p=1
+repetition_penalty=1
+top_k=50
+num_beams=1
+penalty_alpha=0
+min_length=0
+length_penalty=1
+no_repeat_ngram_size=0
+early_stopping=False
diff --git a/presets/Individual Today.txt b/presets/Individual Today.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f40b879cefc3d3e7914fc03f0f2322758c51cc05
--- /dev/null
+++ b/presets/Individual Today.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.9
+top_k=50
+temperature=1.39
+repetition_penalty=1.08
+typical_p=0.2
diff --git a/presets/Kobold-Godlike.txt b/presets/Kobold-Godlike.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0ba5b794b6d0130a1fa1d918bda9a276f7d23367
--- /dev/null
+++ b/presets/Kobold-Godlike.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.5
+top_k=0
+temperature=0.7
+repetition_penalty=1.1
+typical_p=0.19
diff --git a/presets/Kobold-Liminal Drift.txt b/presets/Kobold-Liminal Drift.txt
new file mode 100644
index 0000000000000000000000000000000000000000..be4dd3bd7a70af2d4eb6c847bed6bedee5379dce
--- /dev/null
+++ b/presets/Kobold-Liminal Drift.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=1.0
+top_k=0
+temperature=0.66
+repetition_penalty=1.1
+typical_p=0.6
diff --git a/presets/Naive.txt b/presets/Naive.txt
new file mode 100644
index 0000000000000000000000000000000000000000..aa8c058224c533f4084e230f6bbf77b63d5e81ea
--- /dev/null
+++ b/presets/Naive.txt
@@ -0,0 +1,4 @@
+do_sample=True
+temperature=0.7
+top_p=0.85
+top_k=50
diff --git a/presets/NovelAI-Best Guess.txt b/presets/NovelAI-Best Guess.txt
new file mode 100644
index 0000000000000000000000000000000000000000..db3fa75b2a11d7e29b108177f9894e82d1e52126
--- /dev/null
+++ b/presets/NovelAI-Best Guess.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.9
+top_k=100
+temperature=0.8
+repetition_penalty=1.15
+typical_p=1.0
diff --git a/presets/NovelAI-Decadence.txt b/presets/NovelAI-Decadence.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d3109f3e3f3a021810d171a0b98f615766b57e4b
--- /dev/null
+++ b/presets/NovelAI-Decadence.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=1.0
+top_k=100
+temperature=2
+repetition_penalty=1
+typical_p=0.97
diff --git a/presets/NovelAI-Genesis.txt b/presets/NovelAI-Genesis.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cc7376b3b981a260448a65cd3c00c7b3904308e2
--- /dev/null
+++ b/presets/NovelAI-Genesis.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.98
+top_k=0
+temperature=0.63
+repetition_penalty=1.05
+typical_p=1.0
diff --git a/presets/NovelAI-Lycaenidae.txt b/presets/NovelAI-Lycaenidae.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0134569cef76bc0de6b3dc7885d94d9d9afdfd62
--- /dev/null
+++ b/presets/NovelAI-Lycaenidae.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.85
+top_k=12
+temperature=2
+repetition_penalty=1.15
+typical_p=1.0
diff --git a/presets/NovelAI-Ouroboros.txt b/presets/NovelAI-Ouroboros.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1e944b54e78e1f63bd4bb6f56a717e0fec751c6b
--- /dev/null
+++ b/presets/NovelAI-Ouroboros.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=1.0
+top_k=100
+temperature=1.07
+repetition_penalty=1.05
+typical_p=1.0
diff --git a/presets/NovelAI-Pleasing Results.txt b/presets/NovelAI-Pleasing Results.txt
new file mode 100644
index 0000000000000000000000000000000000000000..330114a25db6d194dbc8689bf5476a81f649cf64
--- /dev/null
+++ b/presets/NovelAI-Pleasing Results.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=1.0
+top_k=0
+temperature=0.44
+repetition_penalty=1.15
+typical_p=1.0
diff --git a/presets/NovelAI-Sphinx Moth.txt b/presets/NovelAI-Sphinx Moth.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bace1e24b5dcc64fdde99097930f41a991e91b8e
--- /dev/null
+++ b/presets/NovelAI-Sphinx Moth.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.18
+top_k=30
+temperature=2.0
+repetition_penalty=1.15
+typical_p=1.0
diff --git a/presets/NovelAI-Storywriter.txt b/presets/NovelAI-Storywriter.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2df5f8181458c642ed4691925ade3d542de5391c
--- /dev/null
+++ b/presets/NovelAI-Storywriter.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.73
+top_k=0
+temperature=0.72
+repetition_penalty=1.1
+typical_p=1.0
diff --git a/presets/Pygmalion.txt b/presets/Pygmalion.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f8b2ca55304ce8243e26bd28ebc757e40354a0e9
--- /dev/null
+++ b/presets/Pygmalion.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.9
+top_k=0
+temperature=0.5
+repetition_penalty=1.1
+typical_p=1.0
diff --git a/presets/Verbose (Beam Search).txt b/presets/Verbose (Beam Search).txt
new file mode 100644
index 0000000000000000000000000000000000000000..a3be1b94f27e31e1d0e762a15fd0300abb32e460
--- /dev/null
+++ b/presets/Verbose (Beam Search).txt
@@ -0,0 +1,9 @@
+num_beams=10
+min_length=200
+length_penalty=1.4
+no_repeat_ngram_size=2
+early_stopping=True
+temperature=0.7
+top_k=150
+top_p=0.92
+repetition_penalty=4.5
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..30d93ffdf83ee2dcd2ae673e45dc18f84f7bf142
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,15 @@
+torch
+torchvision
+torchaudio
+transformers
+accelerate==0.17.1
+bitsandbytes==0.37.1
+flexgen==0.1.7
+gradio==3.18.0
+numpy
+requests
+rwkv==0.4.2
+safetensors==0.3.0
+sentencepiece
+tqdm
+git+https://github.com/zphang/transformers.git@68d640f7c368bcaaaecfc678f11908ebbd3d6176
\ No newline at end of file
diff --git a/run.py b/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c966a2f5691c6444c3329365c39e78b74fdbf95
--- /dev/null
+++ b/run.py
@@ -0,0 +1,4 @@
+import os
+os.system('python download-model.py PygmalionAI/pygmalion-350m --branch main')
+# os.system('python download-model.py waifu-workshop/pygmalion-6b --branch original-sharded')
+os.system('python server.py --cpu --chat --model pygmalion-350m --no-stream --auto-devices')
\ No newline at end of file
diff --git a/server.py b/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a17f26287d94e9187a4f315fe9fb7d2dc6ec171
--- /dev/null
+++ b/server.py
@@ -0,0 +1,382 @@
+import gc
+import io
+import json
+import re
+import sys
+import time
+import zipfile
+from pathlib import Path
+
+import gradio as gr
+import torch
+
+import modules.chat as chat
+import modules.extensions as extensions_module
+import modules.shared as shared
+import modules.ui as ui
+from modules.html_generator import generate_chat_html
+from modules.models import load_model, load_soft_prompt
+from modules.text_generation import generate_reply
+
+# Loading custom settings
+settings_file = None
+if shared.args.settings is not None and Path(shared.args.settings).exists():
+ settings_file = Path(shared.args.settings)
+elif Path('settings.json').exists():
+ settings_file = Path('settings.json')
+if settings_file is not None:
+ print(f"Loading settings from {settings_file}...")
+ new_settings = json.loads(open(settings_file, 'r').read())
+ for item in new_settings:
+ shared.settings[item] = new_settings[item]
+
+def get_available_models():
+ if shared.args.flexgen:
+ return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
+ else:
+ return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower)
+
+def get_available_presets():
+ return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
+
+def get_available_characters():
+ return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
+
+def get_available_extensions():
+ return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
+
+def get_available_softprompts():
+ return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
+
+def load_model_wrapper(selected_model):
+ if selected_model != shared.model_name:
+ shared.model_name = selected_model
+ shared.model = shared.tokenizer = None
+ if not shared.args.cpu:
+ gc.collect()
+ torch.cuda.empty_cache()
+ shared.model, shared.tokenizer = load_model(shared.model_name)
+
+ return selected_model
+
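+# Reads a presets/<name>.txt file of "key=value" lines (e.g. "temperature=0.7") and returns the generation parameters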
+def load_preset_values(preset_menu, return_dict=False):
+ generate_params = {
+ 'do_sample': True,
+ 'temperature': 1,
+ 'top_p': 1,
+ 'typical_p': 1,
+ 'repetition_penalty': 1,
+ 'top_k': 50,
+ 'num_beams': 1,
+ 'penalty_alpha': 0,
+ 'min_length': 0,
+ 'length_penalty': 1,
+ 'no_repeat_ngram_size': 0,
+ 'early_stopping': False,
+ }
+ with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile:
+ preset = infile.read()
+ for i in preset.splitlines():
+ i = i.rstrip(',').strip().split('=')
+ if len(i) == 2 and i[0].strip() != 'tokens':
+ generate_params[i[0].strip()] = eval(i[1].strip())
+
+ generate_params['temperature'] = min(1.99, generate_params['temperature'])
+
+ if return_dict:
+ return generate_params
+ else:
+ return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
+
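+# Saves an uploaded soft prompt zip under softprompts/, using the name stored in its meta.json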
+def upload_soft_prompt(file):
+ with zipfile.ZipFile(io.BytesIO(file)) as zf:
+ zf.extract('meta.json')
+ j = json.loads(open('meta.json', 'r').read())
+ name = j['name']
+ Path('meta.json').unlink()
+
+ with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
+ f.write(file)
+
+ return name
+
+def create_settings_menus(default_preset):
+ generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
+ ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
+ ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
+
+ with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'):
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
+ shared.gradio['repetition_penalty'] = gr.Slider(1.0, 2.99, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
+ shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
+ shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
+ with gr.Column():
+ shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
+ shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
+ shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
+ shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
+
+ gr.Markdown('Contrastive search:')
+ shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
+
+ gr.Markdown('Beam search (uses a lot of VRAM):')
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
+ with gr.Column():
+ shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
+ shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
+
+ with gr.Accordion('Soft prompt', open=False, elem_id='accordion'):
+ with gr.Row():
+ shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
+ ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')
+
+ gr.Markdown('Upload a soft prompt (.zip format):')
+ with gr.Row():
+ shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
+
+ shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
+ shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
+ shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
+ shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
+
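+# Collect the lists of available models, presets, characters, and soft prompts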
+available_models = get_available_models()
+available_presets = get_available_presets()
+available_characters = get_available_characters()
+available_softprompts = get_available_softprompts()
+
+# Default extensions
+extensions_module.available_extensions = get_available_extensions()
+if shared.args.chat or shared.args.cai_chat:
+ for extension in shared.settings['chat_default_extensions']:
+ shared.args.extensions = shared.args.extensions or []
+ if extension not in shared.args.extensions:
+ shared.args.extensions.append(extension)
+else:
+ for extension in shared.settings['default_extensions']:
+ shared.args.extensions = shared.args.extensions or []
+ if extension not in shared.args.extensions:
+ shared.args.extensions.append(extension)
+if shared.args.extensions is not None and len(shared.args.extensions) > 0:
+ extensions_module.load_extensions()
+
+# Default model
+if shared.args.model is not None:
+ shared.model_name = shared.args.model
+else:
+ if len(available_models) == 0:
+ print('No models are available! Please download at least one.')
+ sys.exit(0)
+ elif len(available_models) == 1:
+ i = 0
+ else:
+ print('The following models are available:\n')
+ for i, model in enumerate(available_models):
+ print(f'{i+1}. {model}')
+ print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
+ i = int(input())-1
+ print()
+ shared.model_name = available_models[i]
+shared.model, shared.tokenizer = load_model(shared.model_name)
+
+# Default UI settings
+gen_events = []
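+# Pick the preset and default prompt whose regex key matches the model name, falling back to 'default'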
+default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
+default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
+title = 'Text generation web UI'
+description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
+suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
+
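+# Chat interface (shared.args.chat or shared.args.cai_chat)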
+if shared.args.chat or shared.args.cai_chat:
+ with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
+        gr.HTML('''Original GitHub repo
+For faster inference without waiting in the queue, you can duplicate this Space.
+(👇 Scroll down to see the interface 👀)''')
+ if shared.args.cai_chat:
+ shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
+ else:
+ shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
+ shared.gradio['textbox'] = gr.Textbox(label='Input')
+ with gr.Row():
+ shared.gradio['Stop'] = gr.Button('Stop')
+ shared.gradio['Generate'] = gr.Button('Generate')
+ with gr.Row():
+ shared.gradio['Impersonate'] = gr.Button('Impersonate')
+ shared.gradio['Regenerate'] = gr.Button('Regenerate')
+ with gr.Row():
+ shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
+ shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
+ shared.gradio['Remove last'] = gr.Button('Remove last')
+
+ shared.gradio['Clear history'] = gr.Button('Clear history')
+ shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
+ shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
+ with gr.Tab('Chat settings'):
+ shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
+ shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
+ shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
+ with gr.Row():
+ shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
+ ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
+
+ with gr.Row():
+ shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
+ with gr.Row():
+ with gr.Tab('Chat history'):
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown('Upload')
+ shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
+ with gr.Column():
+ gr.Markdown('Download')
+ shared.gradio['download'] = gr.File()
+ shared.gradio['download_button'] = gr.Button(value='Click me')
+ with gr.Tab('Upload character'):
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown('1. Select the JSON file')
+ shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
+ with gr.Column():
+ gr.Markdown('2. Select your character\'s profile picture (optional)')
+ shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
+ shared.gradio['Upload character'] = gr.Button(value='Submit')
+ with gr.Tab('Upload your profile picture'):
+ shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
+ with gr.Tab('Upload TavernAI Character Card'):
+ shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
+
+ with gr.Tab('Generation settings'):
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+ with gr.Column():
+ shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
+ shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
+ create_settings_menus(default_preset)
+
+ shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
+ if shared.args.extensions is not None:
+ with gr.Tab('Extensions'):
+ extensions_module.create_extensions_block()
+
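+        # Wire the generation buttons to the chat wrappers; Stop cancels any in-flight generation events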
+ function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
+
+ gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
+ gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+ gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+ gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
+ shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
+
+ shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
+ shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
+
+ # Clear history with confirmation
+ clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
+ shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
+ shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
+ shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
+ shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
+
+ shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
+ shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
+ shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
+
+ # Clearing stuff and saving the history
+ for i in ['Generate', 'Regenerate', 'Replace last reply']:
+ shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
+ shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+ shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+ shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
+ shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+
+ shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']])
+ shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
+ shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
+ shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
+
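+        # Redraw the chat display after a history upload, a profile picture change, or pressing Stop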
+ reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
+ reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
+ shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
+ shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
+ shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
+
+ shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
+ shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
+
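+# Notebook interface (shared.args.notebook): the prompt and the generated text share a single textbox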
+elif shared.args.notebook:
+ with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
+ gr.Markdown(description)
+ with gr.Tab('Raw'):
+ shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=23)
+ with gr.Tab('Markdown'):
+ shared.gradio['markdown'] = gr.Markdown()
+ with gr.Tab('HTML'):
+ shared.gradio['html'] = gr.HTML()
+
+ shared.gradio['Generate'] = gr.Button('Generate')
+ shared.gradio['Stop'] = gr.Button('Stop')
+ shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+
+ create_settings_menus(default_preset)
+ if shared.args.extensions is not None:
+ extensions_module.create_extensions_block()
+
+ shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+ output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
+ gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+ gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
+ shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+
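+# Default interface: separate input and output textboxes, plus a Continue button that feeds the output back in as the new prompt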
+else:
+ with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
+ gr.Markdown(description)
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
+ shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+ shared.gradio['Generate'] = gr.Button('Generate')
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['Continue'] = gr.Button('Continue')
+ with gr.Column():
+ shared.gradio['Stop'] = gr.Button('Stop')
+
+ create_settings_menus(default_preset)
+ if shared.args.extensions is not None:
+ extensions_module.create_extensions_block()
+
+ with gr.Column():
+ with gr.Tab('Raw'):
+ shared.gradio['output_textbox'] = gr.Textbox(lines=15, label='Output')
+ with gr.Tab('Markdown'):
+ shared.gradio['markdown'] = gr.Markdown()
+ with gr.Tab('HTML'):
+ shared.gradio['html'] = gr.HTML()
+
+ shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+ output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
+ gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+ gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
+ gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
+ shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+
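+# Launch the Gradio app without blocking the main thread; shared.args.listen binds the server to 0.0.0.0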
+shared.gradio['interface'].queue()
+if shared.args.listen:
+ shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
+else:
+ shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
+
+# Keep the main thread alive so the Gradio server (launched with prevent_thread_lock=True) keeps running
+while True:
+ time.sleep(0.5)
diff --git a/settings-template.json b/settings-template.json
new file mode 100644
index 0000000000000000000000000000000000000000..9da4397012ccf6821dffb048d741071cf97fff6f
--- /dev/null
+++ b/settings-template.json
@@ -0,0 +1,35 @@
+{
+ "max_new_tokens": 200,
+ "max_new_tokens_min": 1,
+ "max_new_tokens_max": 2000,
+ "name1": "Person 1",
+ "name2": "Person 2",
+ "context": "This is a conversation between two people.",
+ "stop_at_newline": true,
+ "chat_prompt_size": 2048,
+ "chat_prompt_size_min": 0,
+ "chat_prompt_size_max": 2048,
+ "chat_generation_attempts": 1,
+ "chat_generation_attempts_min": 1,
+ "chat_generation_attempts_max": 5,
+ "name1_pygmalion": "You",
+ "name2_pygmalion": "Kawaii",
+ "context_pygmalion": "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n",
+ "stop_at_newline_pygmalion": false,
+ "default_extensions": [],
+ "chat_default_extensions": [
+ "gallery"
+ ],
+ "presets": {
+ "default": "NovelAI-Sphinx Moth",
+ "pygmalion-*": "Pygmalion",
+ "RWKV-*": "Naive",
+ "(rosey|chip|joi)_.*_instruct.*": "Instruct Joi (Contrastive Search)"
+ },
+ "prompts": {
+ "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
+ "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
+ "(rosey|chip|joi)_.*_instruct.*": "User: \n",
+ "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
+ }
+}
diff --git a/softprompts/place-your-softprompts-here.txt b/softprompts/place-your-softprompts-here.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391