|
import re |
|
from unittest import result |
|
import string |
|
import streamlit as st |
|
import torch |
|
from torch.nn import functional as F |
|
from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering, |
|
AutoModelForSeq2SeqLM, |
|
AutoModelForSequenceClassification, AutoTokenizer, |
|
GPT2Tokenizer, LogitsProcessor, LogitsProcessorList, |
|
pipeline, top_k_top_p_filtering, PhrasalConstraint, DisjunctiveConstraint) |
|
import ast |
|
|
|
|
|
|
|
|
|
class ModifyLogitsProcessor(LogitsProcessor):
    """Logits processor that bans or re-weights tokens containing given substrings.

    For every key of ``chars_to_modify``, the constructor collects the ids of all
    vocabulary entries whose token text contains that key. On each call it either
    sets those tokens' scores to ``-inf`` (``filter_mode=True``) or adds the key's
    factor to their scores (``filter_mode=False``).

    Args:
        tokenizer: object whose ``get_vocab()`` returns a mapping of
            token string -> token id.
        chars_to_modify: mapping of substring -> additive factor. The factor is
            not used in filter mode. Empty-string keys are ignored.
        filter_mode: if True, ban matching tokens instead of adding the factor.
    """

    def __init__(self, tokenizer, chars_to_modify, filter_mode=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.filter_mode = filter_mode
        self.chars_to_modify = chars_to_modify

        # get_vocab() maps token string -> token id. Take the ids from the mapping
        # itself: enumerate() positions over its keys are not the token ids.
        vocab = self.tokenizer.get_vocab()
        self.tokens_to_modify = {
            char: [token_id for token, token_id in vocab.items() if char in token]
            for char in chars_to_modify
            # "" is a substring of every token; matching it would hit the whole vocab.
            if char
        }

    def __call__(self, input_ids, scores):
        """Return ``scores`` with the configured tokens banned or shifted.

        ``scores`` is modified in place (row-wise via ``scores[:, ids]``) and also
        returned. ``input_ids`` is unused.
        """
        for char, tokens in self.tokens_to_modify.items():
            if not tokens:
                continue
            if self.filter_mode:
                scores[:, tokens] = -float("inf")
            else:
                scores[:, tokens] += self.chars_to_modify[char]
        return scores
|
|
|
|
|
# Page chrome: browser-tab title, main heading, book-cover image and the
# Wikipedia link that explains the app's name. set_page_config stays first.
_GADSBY_COVER_URL = "https://upload.wikimedia.org/wikipedia/commons/1/1d/Gadsby_%28book_cover%29.jpg"
_GADSBY_WIKI_NOTE = "The inspiration for this space: https://en.wikipedia.org/wiki/Gadsby_(novel)"

st.set_page_config(page_title="Gadsby")
st.title("Gadsby - Constrained Text Generation with Transformers")
st.image(_GADSBY_COVER_URL)
st.caption(_GADSBY_WIKI_NOTE)
|
|
|
|
|
|
|
# Sidebar settings form. Widgets placed in a Streamlit form submit their values
# together when the form's submit button is pressed.
form = st.sidebar.form("choose_settings")

# --- Model selection -------------------------------------------------------
form.header("Model Settings")
model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Text Generation", value = "TheBloke/vicuna-7B-1.1-HF")
form.caption("This will download a new model, so it may take a while or even break if the model is too large")
# The variable keeps its historical spelling `percision` because later code reads it.
percision = form.selectbox("What precision are we loading the model with?", ["8bit", "16bit", "32bit"])
form.caption("The lower the precision, the less ram the model takes and the faster it runs, but the quality is reduced")

# --- Token-level (lipogram) constraint ---------------------------------------
form.header("Token Level Constraint Settings")
form.subheader("Lipogram Constraint")
form.caption("Lipograms are compositions where a certain letter or certain letters of the alphabet are omitted or discouraged")
filter_mode = form.checkbox("Filter Mode?", value=False)
form.caption("Enabling filter mode sets all selected tokens' probabilities to negative infinity")
naughty_strings_list = form.text_input('Enter letters or words to filter or modify the probabilities of (comma separated):', value = "that,e")
factor_input = form.text_input('Enter corresponding factors to add to the logits (comma separated, ignored if in filter mode):', value = "5,-99")

# --- Sequence-level constraints ----------------------------------------------
# Sub-sections use subheader so they nest under this group heading, the same
# way "Lipogram Constraint" nests under the token-level heading.
form.header("Sequence Level Constraint Settings")
form.subheader("Phrasal Constraint")
force_word = form.text_input("Enter a word or sentence that is guaranteed to appear in the output", value = "lipogram")
form.subheader("Disjunctive Constraint")
force_flexible_input = form.text_input('Enter a list of words or sentences that the model must include at least one item from (in Python list format)', '["constraint", "banana"]')
|
|
|
# Parse the disjunctive-constraint list typed by the user.
# `force_flexible` is always defined: it is an empty list when the box is empty
# or the input is unusable, so later code can test it directly.
force_flexible = []
if force_flexible_input:
    try:
        # literal_eval only evaluates Python literals, so user input cannot run code.
        parsed_flexible = ast.literal_eval(force_flexible_input)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        st.write('Failed to parse the list. Please check your input.')
        st.write('Error:', e)
    else:
        is_string_list = isinstance(parsed_flexible, (list, tuple)) and all(
            isinstance(item, str) and item for item in parsed_flexible
        )
        if is_string_list:
            force_flexible = list(parsed_flexible)
        else:
            st.write('The disjunctive constraint must be a list of non-empty strings, e.g. ["constraint", "banana"].')
|
|
|
|
|
# Parse the lipogram settings into:
#   chars           -- the comma-separated entries, whitespace-stripped
#   factors         -- one float per entry (zero placeholders in filter mode)
#   chars_to_modify -- {entry: factor}, skipping empty entries
# A malformed factor list leaves `factors` empty, so the later length check
# reports the mismatch instead of the app crashing.
chars = []
factors = []
chars_to_modify = {}
if naughty_strings_list:
    chars = [item.strip() for item in naughty_strings_list.split(',')]
    if filter_mode:
        # The factor box is documented as ignored in filter mode, so do not require it.
        factors = [0.0] * len(chars)
    else:
        try:
            factors = [float(value) for value in factor_input.split(',')]
        except ValueError as e:
            st.write('Failed to parse the factors. Please enter comma separated numbers.')
            st.write('Error:', e)
            factors = []
    # An empty entry (e.g. from a trailing comma) is a substring of every token,
    # so it is left out of the mapping handed to the logits processor.
    chars_to_modify = {char: factor for char, factor in zip(chars, factors) if char}
|
|
|
# Free-form keyword arguments for model.generate(), typed as a Python dict
# literal; the string is evaluated with ast.literal_eval further down.
_DEFAULT_GENERATE_ARGS = '{"max_new_tokens": 50, "min_new_tokens": 50, "temperature": 2.0, "num_return_sequences": 1, "do_sample": False, "num_beams": 2, "repetition_penalty": 3.0}'
generate_args = st.text_input(
    'model.generate() arguments (in python dictionary format) ',
    value=_DEFAULT_GENERATE_ARGS,
)

# Pointers to the upstream documentation for every accepted generate() option.
_GENERATE_DOC_LINKS = [
    "https://huggingface.co/blog/how-to-generate",
    "https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig",
    "https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationMixin.generate",
]
st.caption(
    "For more details on what these settings mean and a complete list of all settings, see here: "
    + " and ".join(_GENERATE_DOC_LINKS)
)
|
|
|
|
|
# Default prompt in a "### Human: ... ### Assistant:" chat layout (presumably the
# format the default Vicuna checkpoint expects -- confirm for other models).
custom_prompt = """

### Human: Write about how much you love constrained text generation techniques



### Assistant:

"""

# Main-page text area with the prompt; its value is what gets tokenized and
# passed to model.generate() further down.
sequence = st.text_area("Enter a custom prompt", value = custom_prompt)

# Submit button for the sidebar settings form. Its return value is not stored,
# so generation later in the script is not gated on it.
# NOTE(review): `sequence` and `generate_args` live outside the form, so editing
# them presumably triggers a rerun (and generation) without pressing this -- confirm.
form.form_submit_button("Generate some Constrained Text!")
|
|
|
def _parse_generate_arg_value(raw):
    """Convert one ``key:value`` value string to an int, float or bool."""
    value = raw.strip()
    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        pass
    if value in ("True", "False"):
        return value == "True"
    raise ValueError(f"cannot parse generate() argument value: {raw!r}")


def parse_generate_args(args_str):
    """Parse generation keyword arguments from a string.

    Two formats are accepted:

    * a Python dict literal, e.g. ``'{"num_beams": 2, "do_sample": False}'``
      (the format the ``model.generate() arguments`` text box asks for);
    * comma-separated ``key:value`` pairs, e.g.
      ``"max_new_tokens: 50, temperature: 2.0"``.

    In the pair format, keys are stripped of surrounding whitespace and quotes.
    Values are parsed as ``int``, then ``float``, then ``True``/``False``.
    Entries that do not split into exactly one ``key:value`` pair are skipped.

    Args:
        args_str: the raw user-supplied string.

    Returns:
        dict mapping argument name to parsed value (empty for an empty string).

    Raises:
        ValueError: if a pair value cannot be parsed, or a ``{...}`` string is
            not a valid dict literal.
    """
    text = args_str.strip()

    if text.startswith("{"):
        # literal_eval only evaluates Python literals, never arbitrary code.
        try:
            parsed = ast.literal_eval(text)
        except (SyntaxError, TypeError) as err:
            raise ValueError(f"invalid dict literal: {args_str!r}") from err
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a dict literal, got: {args_str!r}")
        return parsed

    args_dict = {}
    for arg in text.split(","):
        parts = arg.split(":")
        if len(parts) != 2:
            continue
        key = parts[0].strip().strip("'\"")
        args_dict[key] = _parse_generate_arg_value(parts[1])
    return args_dict
|
|
|
@st.cache_resource
def _cached_tokenizer(name):
    """Load and cache the (slow, non-fast) tokenizer for checkpoint ``name``."""
    return AutoTokenizer.from_pretrained(name, use_fast=False)


def load_the_tokenizer(name=None):
    """Return the tokenizer for ``name``, defaulting to the sidebar's ``model_name``.

    st.cache_resource keys its cache on the call arguments, so the model name is
    passed explicitly to the cached loader. Without that, editing the model name
    in the sidebar would keep returning the first tokenizer that was cached.
    """
    return _cached_tokenizer(model_name if name is None else name)
|
|
|
# max_entries=1: keep at most one (potentially multi-GB) model in the cache, so
# switching model or precision replaces the cached entry instead of piling up.
@st.cache_resource(max_entries=1)
def _cached_model(name, percision):
    """Load causal-LM checkpoint ``name`` with ``device_map='auto'``.

    ``percision`` selects the weights: "32bit" uses the default dtype, "16bit"
    uses torch.float16, and any other value loads in 8-bit.
    """
    # NOTE(review): newer transformers releases prefer
    # quantization_config=BitsAndBytesConfig(load_in_8bit=True); confirm against
    # the installed version.
    if percision == "32bit":
        return AutoModelForCausalLM.from_pretrained(name, device_map='auto', load_in_8bit=False)
    if percision == "16bit":
        return AutoModelForCausalLM.from_pretrained(name, device_map='auto', load_in_8bit=False, torch_dtype=torch.float16)
    return AutoModelForCausalLM.from_pretrained(name, device_map='auto', load_in_8bit=True)


def load_the_model(percision, name=None):
    """Return the model for ``name`` (default: the sidebar's ``model_name``) at ``percision``.

    The model name is part of the cache key. Previously the cache was keyed on
    precision alone, which returned a stale model after the name was edited.
    """
    return _cached_model(model_name if name is None else name, percision)
|
|
|
# Generation driver: runs once the lipogram entries and factors line up.
if len(chars) != len(factors):
    st.write("Please ensure that the number of characters matches the number of factors.")
else:
    # Validate the generate() kwargs before loading the (potentially slow) model.
    try:
        generate_kwargs = ast.literal_eval(generate_args)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        st.write('Failed to parse the model.generate() arguments. Please check your input.')
        st.write('Error:', e)
        st.stop()
    if not isinstance(generate_kwargs, dict):
        st.write('The model.generate() arguments must be a Python dictionary.')
        st.stop()

    model = load_the_model(percision)
    tokenizer = load_the_tokenizer()

    # Sequence-level constraints, passed to generate() via `constraints=`.
    constraints = []
    if force_word:
        constraints.append(PhrasalConstraint(
            tokenizer(force_word, add_special_tokens=False).input_ids
        ))
    # Skip the disjunctive constraint when the list failed to parse or is empty.
    # The short-circuit means `force_flexible` is read only when the box is non-empty.
    if force_flexible_input and force_flexible:
        constraints.append(DisjunctiveConstraint(
            tokenizer(force_flexible, add_special_tokens=False).input_ids
        ))
    if constraints:
        generate_kwargs["constraints"] = constraints

    # Token-level lipogram constraint.
    logits_processor = LogitsProcessorList(
        [ModifyLogitsProcessor(tokenizer, chars_to_modify, filter_mode=filter_mode)]
    )

    # Move the prompt to wherever the model was placed instead of assuming CUDA.
    input_ids = tokenizer.encode(sequence, return_tensors="pt").to(model.device)
    output_ids = model.generate(input_ids, logits_processor=logits_processor, **generate_kwargs)

    st.write("GENERATED SEQUENCE(s): ")
    for output in output_ids:
        st.write(tokenizer.decode(output, skip_special_tokens = True, clean_up_tokenization_spaces = True))
|
|
|
|
|
|
|
|