Spaces:
Runtime error
Runtime error
File size: 4,548 Bytes
dda1539 34cbb1e 0633ac3 34cbb1e f9e22c8 5824320 f9e22c8 dda1539 79a5328 dda1539 37a92ca dda1539 79a5328 dda1539 09cab71 dda1539 8f4ffb5 b01dd78 cfcc2e7 dda1539 b01dd78 dda1539 5a6e4f1 dda1539 73e08ae f130f93 392e26d dda1539 b01dd78 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import gradio as gr
import json
import os
import spaces
import torch
from dotenv import load_dotenv
from huggingface_hub import login, snapshot_download
from superposed.llama.superposed_generation import SuperposedLlama
from superposed.llama.tokenizer import Tokenizer
from superposed.ngrams.ngram_models import make_models
# Environment variables read by torch.distributed's default "env://" rendezvous:
# they describe a single-process group (rank 0 of a world of size 1) on localhost.
os.environ.update(
    {
        "RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_PORT": "12193",
        "MASTER_ADDR": "127.0.0.1",
    }
)
def load_models(ckpt_dir=None, *, max_seq_len=100, max_batch_size=32, device="cuda"):
    """
    Build the SuperposedLlama model from a local checkpoint directory.

    Args:
        ckpt_dir (str | None): Directory holding the checkpoint and ``tokenizer.model``.
            Defaults to the module-level ``weight_path``, looked up at call time
            (so it only has to be defined before the function is called).
        max_seq_len (int): Maximum sequence length passed to ``SuperposedLlama.build``.
        max_batch_size (int): Maximum batch size passed to ``SuperposedLlama.build``.
        device (str): Device the model is built on.
    Returns:
        The object returned by ``SuperposedLlama.build``.
    """
    if ckpt_dir is None:
        ckpt_dir = weight_path
    model = SuperposedLlama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=f"{ckpt_dir}/tokenizer.model",
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        device=device,
        # Single process, no tensor parallelism (matches WORLD_SIZE=1 set at import time).
        model_parallel_size=1,
    )
    return model
# Pull variables from a local .env file if one exists (useful for local runs).
# This does nothing when no .env is found, and it does not override variables
# that are already set, e.g. Space secrets.
load_dotenv()

# Authenticate so the gated Llama-2 weights can be downloaded. Only log in when a
# token is actually configured: huggingface_hub.login(None) falls back to an
# interactive prompt, which cannot succeed in a headless Space.
hf_token = os.getenv("HF_ACCESS_TOKEN")
if hf_token:
    login(token=hf_token)

weight_path = "./weights/"
# Download the checkpoint only when the weights directory does not exist yet.
# NOTE(review): a previously interrupted download leaves the directory in place
# and is not resumed here — delete ./weights/ to force a fresh download.
if not os.path.exists(weight_path):
    os.makedirs(weight_path, exist_ok=True)
    snapshot_download(repo_id="meta-llama/Llama-2-7b", local_dir=weight_path)

# Load decoding hyper-parameters
param_file = "params/p15_d3_ngram4_mixed.json"
with open(param_file, "r", encoding="utf-8") as f:
    params = json.load(f)
alpha = params["alpha"]
temp = params["temp"]
n_drafts = params["n_drafts"]
prompt_len = params["prompt_len"]
n_token_sample = params["n_token_sample"]
i_weights = params["i_weights"]
i_length = params["i_length"]

# Load main model
model = load_models()
tokenizer = Tokenizer(f'{weight_path}/tokenizer.model')

# Create n-gram models: bigram through fourgram only (the UI disclaimer states n = {2, 3, 4}).
ngrams = make_models("ckpts-200k", bigram=True, trigram=True, fourgram=True, fivegram=False, sixgram=False, sevengram=False)
def decode(tokenizer, encoding):
    """
    Decode a tensor of token ids to text, truncating at the first EOS token.

    Args:
        tokenizer (Any): Tokenizer exposing ``eos_id`` and ``decode(list[int]) -> str``.
        encoding (torch.Tensor): 1-D tensor of token ids (the slicing below
            requires a single index per EOS match).
    Returns:
        decoding (str): Decoded text of the tokens before the first EOS
            (all tokens if there is no EOS).
    """
    eos_locs = (encoding == tokenizer.eos_id).nonzero()
    # nonzero() yields one row of indices per match; row 0 is the first EOS.
    if len(eos_locs) > 0:
        encoding = encoding[: int(eos_locs[0, 0])]
    return tokenizer.decode(encoding.to(torch.int32).tolist())
@spaces.GPU
def update_options(input, num_tokens):
    """
    Generate three superposed-decoding continuations for the current prompt.

    Args:
        input (str): Current textbox contents, used as the prompt.
        num_tokens (int | float): Generation length from the slider. Gradio's
            Slider delivers a float, so it is cast to int before generation.
    Returns:
        tuple[str, str, str]: Decoded drafts 0, 1 and 2 with the first
            ``len(input)`` characters (the echoed prompt) removed, one per option button.
    """
    tokenized_prompts = tokenizer.encode([input], True, False)
    print("Processed prompt")
    # The GPU is acquired per call under @spaces.GPU, so (re)place the model on it here.
    model.model.to("cuda")
    model.model.device = "cuda"
    alive_gens, _ = model.sup_generate(
        prompt_tokens=tokenized_prompts,
        smoothing="geom",
        max_gen_len=int(num_tokens),
        n_token_sample=n_token_sample,
        alpha=alpha,
        temp=temp,
        n_drafts=n_drafts,
        i_weights=i_weights,
        i_length=i_length,
        ngrams=ngrams,
        get_time=False,
        penalty=200,
    )
    print("Generated")
    # One row per draft.
    gens = alive_gens[0].reshape(n_drafts, -1)
    # Three outputs, one per option button in the UI.
    # NOTE(review): stripping the prompt by character count assumes the decoded
    # draft starts with exactly the prompt text — confirm for inputs with
    # leading/trailing whitespace.
    return tuple(decode(tokenizer, gens[i])[len(input):] for i in range(3))
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header text. The backslashes before the LaTeX braces are doubled so Python
    # does not see "\{" as an (invalid) escape sequence. The "\n" escapes are
    # intentional extra line breaks in the rendered Markdown.
    gr.Markdown(
        """
# Superposed Decoding
Start typing below to see suggestions.\n
Disclaimer: This demo only uses $n=\\{2, 3, 4\\}$ $n$-grams as opposed to $n=\\{2, 3, 4, 5, 6\\}$ in the paper. In addition, there may be significant latency at times because a GPU must be re-acquired after every change.
\n
Paper: [https://arxiv.org/abs/2405.18400](https://arxiv.org/abs/2405.18400)\n
Code: [https://github.com/RAIVNLab/SuperposedDecoding](https://github.com/RAIVNLab/SuperposedDecoding)
"""
    )
    slider = gr.Slider(minimum=1, maximum=10, step=1, label="Generation length", value=10)
    inp = gr.Textbox(placeholder="Type anything!", lines=3)
    option1 = gr.Button(value="Option 1")
    option2 = gr.Button(value="Option 2")
    option3 = gr.Button(value="Option 3")
    # Regenerate the three suggestions (shown as button labels) whenever the prompt changes.
    inp.change(update_options, inputs=[inp, slider], outputs=[option1, option2, option3])

    # Button updates: clicking a suggestion appends its label text to the prompt.
    @option1.click(inputs=[inp, option1], outputs=inp)
    def option1_click(curr, txt):
        return curr + txt

    @option2.click(inputs=[inp, option2], outputs=inp)
    def option2_click(curr, txt):
        return curr + txt

    @option3.click(inputs=[inp, option3], outputs=inp)
    def option3_click(curr, txt):
        return curr + txt
def _serve():
    """Launch the Gradio app built above; ``share=True`` also requests a public share link."""
    demo.launch(share=True)


if __name__ == "__main__":
    _serve()