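"""
app.py - a Gradio demo for long-form text summarization with a LED model
fine-tuned on the BookSum dataset (pszemraj/led-large-book-summary)
"""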
import logging
import os
import re
from pathlib import Path

import gradio as gr
import nltk
import torch
import transformers
from cleantext import clean

from summarize import load_model_and_tokenizer, summarize_via_tokenbatches

_here = Path(__file__).parent

nltk.download("stopwords")  # TODO: find where this requirement originates from

transformers.logging.set_verbosity_error()
logging.basicConfig()
def truncate_word_count(text, max_words=512):
    """
    truncate_word_count - a helper function for the gradio module

    Parameters
    ----------
    text : str, required, the text to be processed
    max_words : int, optional, the maximum number of words, default=512

    Returns
    -------
    dict, the text and whether it was truncated
    """
    # split on whitespace with regex
    words = re.split(r"\s+", text)
    processed = {}
    if len(words) > max_words:
        processed["was_truncated"] = True
        processed["truncated_text"] = " ".join(words[:max_words])
    else:
        processed["was_truncated"] = False
        processed["truncated_text"] = text
    return processed
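# Illustrative usage (hypothetical values, not part of the app flow):
# truncate_word_count("the quick brown fox", max_words=2)
# -> {"was_truncated": True, "truncated_text": "the quick"}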
def proc_submission(
    input_text: str,
    num_beams,
    token_batch_length,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    max_input_length: int = 512,
):
    """
    proc_submission - a helper function for the gradio module

    NOTE: the parameter order must match the order of the Gradio inputs below,
    since Gradio passes input values to the function positionally.

    Parameters
    ----------
    input_text : str, required, the text to be summarized
    num_beams : int, the number of beams to use for beam search
    token_batch_length : int, the number of tokens in each batch fed to the model
    length_penalty : float, the length penalty used during generation
    repetition_penalty : float, the repetition penalty used during generation
    no_repeat_ngram_size : int, the no-repeat n-gram size used during generation
    max_input_length : int, optional, the maximum number of words in the input, default=512

    Returns
    -------
    str of HTML, the interactive HTML form for the model
    """
    settings = {
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": num_beams,
    }
    history = {}
    clean_text = clean(input_text, lower=False)
    processed = truncate_word_count(clean_text, max_input_length)
    if processed["was_truncated"]:
        history["input_text"] = processed["truncated_text"]
        history["was_truncated"] = True
        msg = f"Input text was truncated to {max_input_length} words."
        logging.warning(msg)
        history["WARNING"] = msg
    else:
        # use the cleaned text here too, for consistency with the truncated branch
        history["input_text"] = clean_text
        history["was_truncated"] = False
    _summaries = summarize_via_tokenbatches(
        history["input_text"],
        model,  # model and tokenizer are loaded at module level in __main__
        tokenizer,
        batch_length=int(token_batch_length),  # sliders may return floats
        **settings,
    )
    sum_text = [s["summary"][0] for s in _summaries]
    sum_scores = [f"\n - {round(s['summary_score'], 4)}" for s in _summaries]

    history["Input"] = input_text
    history["Summary Text"] = "\n\t".join(sum_text)
    history["Summary Scores"] = "\n".join(sum_scores)

    # render each history entry as a small HTML section
    html = ""
    for name, item in history.items():
        html += (
            f"<h2>{name}:</h2><hr><b>{item}</b><br><br>"
            if "summary" not in name.lower()
            else f"<h2>{name}:</h2><hr><b>{item}</b>"
        )
    return html
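# Illustrative direct call for local testing (assumes model and tokenizer are
# already loaded, as in the __main__ block below; argument values are hypothetical):
# html = proc_submission("Some very long input text ...", 4, 2048, 0.7, 3.5, 3)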
def load_examples(examples_dir="examples"):
    """load each .txt file in examples_dir and pair it with default parameter values"""
    src = _here / examples_dir
    src.mkdir(exist_ok=True)
    examples = [f for f in src.glob("*.txt")]
    # load the examples into a list
    text_examples = []
    for example in examples:
        with open(example, "r") as f:
            text = f.read()
        # defaults follow the input order: num_beams, token_batch_length,
        # length_penalty, repetition_penalty, no_repeat_ngram_size
        text_examples.append([text, 4, 2048, 0.7, 3.5, 3])
    return text_examples
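# Expected layout (file names are hypothetical): a directory of plain-text
# examples next to this script, used to pre-populate the interface:
#   examples/
#     book_excerpt.txt
#     long_article.txt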
if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")

    title = "Long-form text summarization with LED on the BookSum dataset"
    description = (
        "This is a simple example of using the LED model to summarize a long-form text."
    )

    # note: gr.inputs.* and the enable_queue launch kwarg are deprecated in
    # current Gradio; use top-level components and .queue() instead
    gr.Interface(
        proc_submission,
        inputs=[
            gr.Textbox(lines=10, label="input text"),
            gr.Slider(minimum=4, maximum=10, value=4, step=1, label="num_beams"),
            gr.Slider(
                minimum=512,
                maximum=4096,
                value=2048,
                step=512,
                label="token_batch_length",
            ),
            gr.Slider(
                minimum=0.5, maximum=1.1, value=0.7, step=0.05, label="length_penalty"
            ),
            gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=3.5,
                step=0.1,
                label="repetition_penalty",
            ),
            gr.Slider(minimum=2, maximum=4, value=3, step=1, label="no_repeat_ngram_size"),
        ],
        outputs="html",
        examples_per_page=4,
        title=title,
        description=description,
        examples=load_examples(),
    ).queue().launch(share=True)
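# To run locally (assumes summarize.py from this Space is importable and the
# model weights can be downloaded from the Hugging Face Hub):
#   pip install gradio torch transformers clean-text nltk
#   python app.py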