# storybook / app.py
# ussarata's picture
# Create app.py
# e19273a
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn.functional as F
import gradio as gr
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
from tqdm import tqdm, trange
# Load the pretrained tokenizer and causal language model once at import time
# so that every request served by the UI reuses the same objects.
_CHECKPOINT = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
def generate_text(prompt, sentence_length):
    """Continue ``prompt`` with text sampled from the module-level model.

    Args:
        prompt: Seed text to continue. When it is empty, generation is seeded
            with the tokenizer's end-of-text token instead, because an empty
            prompt would encode to zero tokens.
        sentence_length: Requested amount of *newly generated* tokens (the
            value of the UI slider). It is coerced to ``int`` and raised to
            the minimum of 7 new tokens when it is smaller.

    Returns:
        The decoded sequence (prompt followed by its continuation) with
        special tokens removed.
    """
    min_new_tokens = 7

    # The length budget is applied to new tokens only. Using ``max_length``
    # here would count the prompt as well, so a prompt longer than the slider
    # value would leave no room for any generated text. The budget is also
    # kept at or above ``min_new_tokens`` so the two limits never contradict
    # each other (the slider starts at 0).
    max_new_tokens = max(int(sentence_length), min_new_tokens)

    # Tokenize the prompt; calling the tokenizer (rather than ``encode``)
    # also yields the attention mask that ``generate`` needs when the pad
    # token is set to the EOS token.
    encoded = tokenizer(prompt or tokenizer.eos_token, return_tensors='pt')

    # Sample a continuation with nucleus (top-p) and top-k filtering.
    output_ids = model.generate(
        encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode the single returned sequence back into text.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
def split_by_sentences(generated_text):
    """Return the complete sentences of ``generated_text``, one per line.

    The text is segmented with NLTK's ``sent_tokenize``. A sentence counts as
    complete when it ends with ``.``, ``!`` or ``?``, optionally followed by
    closing quotes or brackets. Anything else (typically the truncated tail
    of a generated passage) is dropped.

    Args:
        generated_text: Text to segment and filter.

    Returns:
        The kept sentences joined with newline characters; an empty string
        when no sentence qualifies.
    """
    terminators = ('.', '!', '?')
    # Characters that may legitimately trail the terminator, e.g. in
    # 'She said "stop."' or '(It was late.)'.
    trailing = ' \t\r\n"\'\u201d\u2019)]'

    complete_sentences = [
        sentence
        for sentence in sent_tokenize(generated_text)
        if sentence.rstrip(trailing).endswith(terminators)
    ]
    return '\n'.join(complete_sentences)
# Static text shown on the combined demo page.
title = 'Story Generator'
description = 'This is a text generation model which uses transformers to generate text given an prompt'

# Stage 1: prompt text plus a length slider go in, the raw continuation comes out.
length_slider = gr.inputs.Slider(minimum=0, maximum=100, label='Number of generated words')
text_gen_UI = gr.Interface(
    fn=generate_text,
    inputs=['text', length_slider],
    outputs=[gr.outputs.Textbox()],
)

# Stage 2: the generated text goes in, its sentence-filtered form comes out.
text_clean_UI = gr.Interface(
    fn=split_by_sentences,
    inputs=['text'],
    outputs=[gr.outputs.Textbox()],
)

# Chain the two stages so the output of generation feeds the clean-up step,
# then start the web app.
story_app = gr.Series(
    text_gen_UI,
    text_clean_UI,
    description=description,
    title=title,
)
story_app.launch()