liamvbetts's picture
pull random change
5d174b9
raw
history blame
1.39 kB
import gradio as gr
import random
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
# Hugging Face Hub id of the fine-tuned seq2seq summarization checkpoint.
# Kept in one place so the tokenizer and model are always loaded from the
# same repo (a mismatched vocabulary would silently garble output).
MODEL_ID = "liamvbetts/bart-large-cnn-v4"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

# CNN/DailyMail v3.0.0 (all splits). In the code visible here only the
# "validation" split is read, as the source of random sample articles.
dataset = load_dataset("cnn_dailymail", "3.0.0")
def summarize(article):
    """Return an abstractive summary of *article* using the loaded model.

    Args:
        article: News text to summarize. ``None``, empty, or whitespace-only
            input returns ``""`` without running the model.

    Returns:
        The decoded summary with special tokens removed, produced with
        ``do_sample=False`` (deterministic) and at most 128 new tokens.
    """
    # The UI re-runs this as the text changes, including when the box is
    # cleared; skip the expensive model call when there is nothing to do.
    if not article or not article.strip():
        return ""
    # truncation=True caps the input at the tokenizer's model_max_length, so
    # long articles cannot overflow the encoder's position embeddings and
    # crash generate().
    encoded = tokenizer(article, return_tensors="pt", truncation=True)
    # Unpack the whole encoding so attention_mask is passed along with
    # input_ids instead of being inferred by generate().
    outputs = model.generate(**encoded, max_new_tokens=128, do_sample=False)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def get_random_article(max_chars=512):
    """Return the opening of a randomly chosen CNN/DailyMail validation article.

    Args:
        max_chars: Number of leading characters to keep (default 512). The
            cut is by characters, so it may end mid-word or mid-sentence.

    Returns:
        A prefix of the ``article`` field of one uniformly random example
        from ``dataset["validation"]``.
    """
    split = dataset["validation"]
    # Pick one row directly instead of shuffling the whole split: shuffle()
    # builds a full indices mapping (and may write an indices cache file)
    # on every call just to read a single example. The module-level
    # `random` generator is already seeded from OS entropy at import, so
    # no per-call reseed is needed.
    row = split[random.randrange(len(split))]
    return row["article"][:max_chars]
def update_article_input():
    """Fetch a fresh random article to place in the input textbox.

    Serves as the click handler for the "Load Random Article" button: it
    takes no inputs and returns the new textbox value.
    """
    new_text = get_random_article()
    return new_text
# Build the Gradio UI with gr.Blocks so each event is wired explicitly.
# gr.Interface has no add_event() method, and listing a Button among its
# inputs would pass the button's label to summarize() as a second argument.
with gr.Blocks(title="News Summary App") as demo:
    gr.Markdown(
        "# News Summary App\n"
        "Enter a news text and get its summary, or load a random article "
        "from the validation set."
    )
    input_text = gr.Textbox(lines=10, label="Input Text", value="")
    load_article_button = gr.Button("Load Random Article")
    output_text = gr.Textbox(label="Summary")

    # Live summarization: fires on user edits and also when the button
    # below fills the textbox programmatically.
    input_text.change(summarize, inputs=input_text, outputs=output_text)
    load_article_button.click(
        update_article_input, inputs=None, outputs=input_text
    )

if __name__ == "__main__":
    demo.launch()