File size: 3,523 Bytes
7bc1fb2
414d64e
8784af6
 
b78cea5
8784af6
 
 
e234712
fc092d3
 
df899db
3d73dee
fe2f37b
80ab8fe
 
8784af6
414d64e
2473dde
8784af6
 
80ab8fe
8784af6
 
 
 
 
80ab8fe
8784af6
 
 
2966e63
8784af6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e234712
8784af6
 
 
e234712
8784af6
 
 
 
 
 
 
 
 
80ab8fe
8784af6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Default few-shot prompt: one informal -> "Abraham Lincoln style" example,
# then an open "informal english:" slot for the user's sentence.
first = (
    "informal english: corn fields are all across illinois, "
    "visible once you leave chicago.\n"
    "Translated into the Style of Abraham Lincoln: corn fields "
    "( permeate illinois / span the state of illinois / "
    "( occupy / persist in ) all corners of illinois / "
    "line the horizon of illinois / envelop the landscape of illinois ), "
    "manifesting themselves visibly as one ventures beyond chicago.\n\n"
    "informal english:"
)

@st.cache(allow_output_mutation=True)
def get_model():
    """Load the causal language model and its tokenizer from the Hugging Face Hub.

    The ``st.cache(allow_output_mutation=True)`` decorator lets Streamlit
    reruns reuse the loaded objects rather than loading them again.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Earlier revisions loaded other BigSalmon checkpoints here (for example
    # InformalToFormalLincoln21, Points3, GPTNeo1.3BPointsLincolnFormalInformal,
    # InformalToFormalLincolnConciseWordy); change `checkpoint` to try one.
    # NOTE(review): the tokenizer is loaded from a different repo
    # ("BigSalmon/Points2") than the model; presumably they share a
    # vocabulary — confirm.
    checkpoint = "BigSalmon/MediumInformalToFormalLincoln2"
    lm = AutoModelForCausalLM.from_pretrained(checkpoint)
    tok = AutoTokenizer.from_pretrained("BigSalmon/Points2")
    return lm, tok
    
# get_model() is wrapped in st.cache, so script reruns reuse the loaded pair.
model, tokenizer = get_model()

st.text("For Prompt Templates: https://huggingface.co/BigSalmon/InformalToFormalLincoln35")

# Generation settings (widgets are rendered in this order).
temp = st.sidebar.slider("Temperature", min_value=0.7, max_value=1.5)
number_of_outputs = st.sidebar.slider("Number of Outputs", min_value=5, max_value=50)
# Number of tokens to generate beyond the prompt (used as max/min length offset).
lengths = st.sidebar.slider("Length", min_value=3, max_value=10)
# Whitespace-separated words to ban from generation (main area, not sidebar).
bad_words = st.text_input("Words You Do Not Want Generated", value=" core lemon height time ")
# How many top next-token candidates the "Submit Log Probs" button shows.
logs_outputs = st.sidebar.slider("Logit Outputs", min_value=50, max_value=300)

def run_generate(text, bad_words):
    """Sample continuations of ``text`` from the module-level model.

    Reads the module-level ``model`` and ``tokenizer`` plus the sidebar
    settings ``lengths``, ``temp`` and ``number_of_outputs``.

    Args:
        text: Prompt to continue.
        bad_words: Whitespace-separated words the model must not generate.
            An empty/blank string disables the filter.

    Returns:
        A list of ``number_of_outputs`` decoded continuations with the
        prompt removed.
    """
    input_ids = tokenizer.encode(text, return_tensors='pt')
    prompt_len = input_ids.shape[-1]
    # Ban the space-prefixed form, which is how a word is tokenized mid-sentence.
    bad_word_ids = [tokenizer(" " + word).input_ids for word in bad_words.split()]
    sample_outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=prompt_len + lengths,
        min_length=prompt_len + lengths,
        top_k=50,
        temperature=temp,
        num_return_sequences=number_of_outputs,
        # generate() raises on an empty list; None disables the bad-words filter.
        bad_words_ids=bad_word_ids or None,
    )
    # Drop the prompt tokens before decoding. String-replacing the prompt in the
    # decoded text fails when decode() does not reproduce it verbatim, and it
    # would also delete any later repeat of the prompt in the continuation.
    return [tokenizer.decode(seq[prompt_len:]) for seq in sample_outputs]
# Prompt form: "Submit" samples continuations; "Submit Log Probs" lists the
# most likely next tokens after the prompt.
with st.form(key='my_form'):
    text = st.text_area(label='Enter sentence', value=first)
    submit_button = st.form_submit_button(label='Submit')
    submit_button2 = st.form_submit_button(label='Submit Log Probs')
    if submit_button:
        translated_text = run_generate(text, bad_words)
        st.write(translated_text if translated_text else "No translation found")
    if submit_button2:
        prompt_ids = tokenizer.encode(str(text))
        if not prompt_ids:
            # An empty prompt would produce an empty (float) tensor and crash the model.
            st.write("Enter some text to score.")
        else:
            with torch.no_grad():
                outputs = model(torch.tensor([prompt_ids]), return_dict=True)
                # Logits for the token following the last prompt token.
                next_token_logits = outputs.logits[0, -1]
                k = min(logs_outputs, next_token_logits.shape[-1])
                _, best_indices = next_token_logits.topk(k)
            best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
            st.write(best_words)