|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import torch |
|
import torch.nn.functional as F |
|
import gradio as gr |
|
import nltk |
|
nltk.download('punkt') |
|
from nltk.tokenize import sent_tokenize |
|
from tqdm import tqdm, trange |
|
|
|
# Load the pretrained GPT-2 tokenizer and causal-LM weights once at import
# time, so every Gradio request reuses the same in-memory model.
# NOTE(review): first run downloads the weights from the Hugging Face hub.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

model = AutoModelForCausalLM.from_pretrained("gpt2")
|
|
|
def generate_text(prompt, sentence_length):
    """Generate a continuation of *prompt* with GPT-2.

    Parameters
    ----------
    prompt : str
        Seed text the model continues from.
    sentence_length : int
        Upper bound on the number of NEW tokens to generate.  The UI
        slider labels this "Number of generated words"; tokens are
        sub-word units, so the word count is approximate.

    Returns
    -------
    str
        The prompt followed by the generated continuation, with
        special tokens stripped.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # Bug fix: the original passed max_length=sentence_length, which counts
    # the PROMPT tokens as well — a long prompt plus a small slider value
    # stopped generation immediately (or warned/errored when max_length was
    # smaller than the prompt itself).  max_new_tokens counts only generated
    # tokens, which matches the slider label.  Floor at 7 so the bound never
    # conflicts with min_new_tokens (the slider allows 0).
    generated_ids = model.generate(
        input_ids,
        max_new_tokens=max(int(sentence_length), 7),
        min_new_tokens=7,
        do_sample=True,  # nucleus + top-k sampling for varied stories
        top_p=0.95,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; silences a warning
    )

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
|
def split_by_sentences(generated_text):
    """Split *generated_text* into sentences and drop incomplete ones.

    Uses NLTK's Punkt sentence tokenizer, keeps only sentences that end
    with a period (the generator often stops mid-sentence, leaving a
    dangling fragment at the end), and returns the survivors joined
    one per line.

    NOTE(review): sentences ending in '!' or '?' are also complete but
    are dropped by the endswith('.') filter — confirm whether that is
    intended before widening the check to endswith(('.', '!', '?')).

    Parameters
    ----------
    generated_text : str
        Raw model output.

    Returns
    -------
    str
        Newline-joined complete sentences.
    """
    sentences = sent_tokenize(generated_text)

    complete = [sentence for sentence in sentences if sentence.endswith('.')]

    # Idiom fix: join the list directly instead of wrapping it in a
    # redundant identity list comprehension.
    return '\n'.join(complete)
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Typo fix in the user-facing description: "an prompt" -> "a prompt".
description = 'This is a text generation model which uses transformers to generate text given a prompt'

title = 'Story Generator'

# Stage 1: generate raw text from the prompt.  The slider value is passed
# to generate_text as its token budget.
text_gen_UI = gr.Interface(
    generate_text,
    inputs=['text', gr.inputs.Slider(0, 100, label='Number of generated words')],
    outputs=[gr.outputs.Textbox()],
)

# Stage 2: keep only complete sentences, one per line.
text_clean_UI = gr.Interface(
    split_by_sentences,
    inputs=['text'],
    outputs=[gr.outputs.Textbox()],
)

# Chain the two interfaces: the generator's output feeds the cleaner.
gr.Series(text_gen_UI, text_clean_UI, description=description, title=title).launch()