Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
# Default few-shot prompt shown in the text area: one worked example followed
# by an open "informal english:" slot.
first = (
    "informal english: corn fields are all across illinois, "
    "visible once you leave chicago.\n"
    "Translated into the Style of Abraham Lincoln: corn fields "
    "( permeate illinois / span the state of illinois / "
    "( occupy / persist in ) all corners of illinois / "
    "line the horizon of illinois / envelop the landscape of illinois ), "
    "manifesting themselves visibly as one ventures beyond chicago.\n\n"
    "informal english:"
)
def _cache_resource(func):
    """Wrap *func* in whichever Streamlit resource cache this install offers."""
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    # Older Streamlit releases only expose the legacy st.cache API;
    # allow_output_mutation skips hashing the (unhashable) model object.
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def get_model():
    """Load the causal language model and its tokenizer.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached to avoid reloading the checkpoint on each rerun.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained("BigSalmon/GPT2Neo1.3BPoints3")
    # NOTE(review): the tokenizer is loaded from a different repo than the
    # model; presumably both share the same vocabulary -- confirm.
    tokenizer = AutoTokenizer.from_pretrained("BigSalmon/Points2")
    return model, tokenizer
# Shared model/tokenizer pair used by both submit handlers.
model, tokenizer = get_model()

st.text("For Prompt Templates: https://huggingface.co/BigSalmon/InformalToFormalLincoln35")

# Generation controls. Each slider starts at its minimum because no explicit
# default value is supplied.
temp = st.sidebar.slider("Temperature", min_value=0.7, max_value=1.5)
number_of_outputs = st.sidebar.slider("Number of Outputs", min_value=5, max_value=50)
# Number of tokens generated beyond the prompt.
lengths = st.sidebar.slider("Length", min_value=3, max_value=10)
# Whitespace-separated words that sampling must never produce.
bad_words = st.text_input(
    "Words You Do Not Want Generated",
    value=" core lemon height time ",
)
# How many top next-token candidates the "Submit Log Probs" button lists.
logs_outputs = st.sidebar.slider("Logit Outputs", min_value=50, max_value=300)
def run_generate(text, bad_words):
    """Sample continuations of *text* from the model.

    The number of samples, the continuation length and the temperature are
    read from the module-level sidebar values ``number_of_outputs``,
    ``lengths`` and ``temp``.

    Args:
        text: Prompt to continue.
        bad_words: Whitespace-separated words that must not be generated.

    Returns:
        A list with one decoded continuation (prompt excluded) per sampled
        sequence; an empty list when *text* is empty.
    """
    if not text:
        # An empty prompt gives the model no tokens to continue.
        return []

    input_ids = tokenizer.encode(text, return_tensors='pt')
    prompt_length = input_ids.shape[-1]

    # Each word is prefixed with a space before tokenizing -- presumably so it
    # matches the token form the word takes in the middle of a sentence.
    banned_ids = [tokenizer(" " + word).input_ids for word in bad_words.split()]

    # min_length == max_length forces every sample to the same length.
    target_length = prompt_length + lengths
    sample_outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=target_length,
        min_length=target_length,
        top_k=50,
        temperature=temp,
        num_return_sequences=number_of_outputs,
        # generate() rejects an empty list, so disable the filter instead.
        bad_words_ids=banned_ids or None,
    )

    # Decode only the newly generated tokens. Stripping the prompt with
    # str.replace relies on decode() reproducing the prompt exactly and would
    # also delete any repetition of the prompt inside the continuation.
    return [tokenizer.decode(sequence[prompt_length:]) for sequence in sample_outputs]
# Main form: one text area and two submit buttons. "Submit" samples
# continuations; "Submit Log Probs" lists the most likely next tokens.
with st.form(key='my_form'):
    text = st.text_area(label='Enter sentence', value=first)
    submit_button = st.form_submit_button(label='Submit')
    submit_button2 = st.form_submit_button(label='Submit Log Probs')
    if submit_button:
        translated_text = run_generate(text, bad_words)
        st.write(translated_text if translated_text else "No translation found")
    if submit_button2:
        token_ids = tokenizer.encode(str(text))
        if not token_ids:
            # A zero-length sequence has no last position to read logits from.
            st.write("Enter some text to see next-token candidates")
        else:
            with torch.no_grad():
                # With return_dict=False the model returns a tuple whose first
                # element is the logits; indexing it avoids depending on how
                # many extra entries (e.g. a key/value cache) follow.
                outputs = model(torch.tensor([token_ids]), return_dict=False)
            # Logits for the token that would follow the prompt.
            next_token_logits = outputs[0][0, -1]
            _, best_indices = next_token_logits.topk(logs_outputs)
            best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
            st.write(best_words)