| """ | |
| grammar_improve.py - this .py script contains functions to improve the grammar of a user's input or the models output. | |
| """ | |
| from datetime import datetime | |
| import os | |
| import pprint as pp | |
| from neuspell import BertChecker, SclstmChecker | |
| import neuspell | |
| import math | |
| from cleantext import clean | |
| import time | |
| import re | |
| import sys | |
| from symspellpy.symspellpy import SymSpell | |
| import transformers | |
| from transformers import pipeline | |
| from utils import suppress_stdout | |
def detect_propers(text: str):
    """
    detect_propers - detect if a string contains proper nouns

    Args:
        text (str): [string to be checked]

    Returns:
        [bool]: [True if string contains proper nouns]
    """
    # simple heuristic: any capitalized word (apostrophes/hyphens allowed) counts as a proper noun.
    # this also flags sentence-initial words, so it errs on the side of returning True
    pat = re.compile(r"\b[A-Z][a-z'’-]+")
    return bool(pat.search(text))

def fix_punct_spaces(string):
    """
    fix_punct_spaces - remove stray spaces before punctuation, e.g. "hello , there" -> "hello, there"

    Parameters
    ----------
    string : str, required, input string to be corrected

    Returns
    -------
    str, corrected string
    """
    fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")
    string = fix_spaces.sub(lambda x: "{} ".format(x.group(1).replace(" ", "")), string)
    return string.strip()

def split_sentences(text: str):
    """
    split_sentences - split a string into a list of sentences that keep their ending punctuation. powered by regex witchcraft

    Args:
        text (str): [string to be split]

    Returns:
        [list]: [list of strings]
    """
    return re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text)

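"""
# rough usage sketch (not part of the original file) for the regex helpers above;
# the expected outputs in the comments are illustrative
print(fix_punct_spaces("hello , there .  i like dogs. Do you?"))
# -> "hello, there. i like dogs. Do you?"
print(split_sentences("I like dogs. Do you? Yes."))
# -> ['I like dogs.', 'Do you?', 'Yes.']
"""
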
def remove_repeated_words(bot_response):
    """
    remove_repeated_words - remove repeated words from a string, returning only the first instance of each word

    Parameters
    ----------
    bot_response : str
        string to remove repeated words from

    Returns
    -------
    str
        string containing the first instance of each word
    """
    words = bot_response.split()
    unique_words = []
    for word in words:
        if word not in unique_words:
            unique_words.append(word)
    return " ".join(unique_words)

def remove_trailing_punctuation(text: str, full_strip=False):
    """
    remove_trailing_punctuation - remove trailing punctuation from a string. Purpose is to seem more natural to end users

    Args:
        text (str): [string to be cleaned]
        full_strip (bool, optional): [also strip trailing ? and !], defaults to False

    Returns:
        [str]: [cleaned string]
    """
    if full_strip:
        return text.rstrip("?!.,;:")
    else:
        return text.rstrip(".,;:")

def fix_punct_spacing(text: str):
    """fix_punct_spacing - remove spaces before punctuation and collapse immediately repeated non-word characters."""
    fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")
    spc_text = fix_spaces.sub(lambda x: "{} ".format(x.group(1).replace(" ", "")), text)
    # drop any non-word character that is immediately followed by an identical one (e.g. "!!" -> "!")
    cln_text = re.sub(r"(\W)(?=\1)", "", spc_text)
    return cln_text

def synthesize_grammar(
    corrector: transformers.pipeline,
    message: str,
    num_beams=4,
    length_penalty=0.9,
    repetition_penalty=1.5,
    no_repeat_ngram_size=4,
    verbose=False,
):
    """
    synthesize_grammar - use a SyntaxSynthesizer (grammar-synthesis) pipeline to rewrite a message with corrected grammar

    Parameters
    ----------
    corrector : transformers.pipeline, required, the SyntaxSynthesizer pipeline, already loaded
    message : str, required, the message to be corrected
    num_beams : int, optional, by default 4, number of beams to use for generation
    length_penalty : float, optional, by default 0.9, length penalty to use for generation
    repetition_penalty : float, optional, by default 1.5, repetition penalty to use for generation
    no_repeat_ngram_size : int, optional, by default 4, size of n-grams that may not repeat in the output
    verbose : bool, optional, by default False, whether to print the runtime

    Returns
    -------
    str, the corrected text
    """
    st = time.perf_counter()
    input_text = clean(message, lower=False)
    results = corrector(
        input_text,
        max_length=int(1.1 * len(input_text)),
        min_length=2 if len(input_text) < 64 else int(0.2 * len(input_text)),
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=True,
        do_sample=False,
        clean_up_tokenization_spaces=True,
    )
    corrected_text = results[0]["generated_text"]
    if verbose:
        rt = round(time.perf_counter() - st, 2)
        print(f"synthesizing took {rt} seconds")
    return corrected_text.strip()

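"""
# rough sketch (not from the original file) of how a corrector pipeline for synthesize_grammar
# could be created; the model id below is a placeholder - substitute the grammar-synthesis
# checkpoint you actually use
from transformers import pipeline

grammar_pipe = pipeline(
    "text2text-generation",
    model="<your-grammar-synthesis-model>",  # placeholder model id
)
fixed = synthesize_grammar(grammar_pipe, "i can has cheezburger", verbose=True)
print(fixed)
"""
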
| """ | |
| start of SymSpell code | |
| """ | |
def symspeller(
    my_string: str,
    sym_checker=None,
    max_dist: int = 2,
    prefix_length: int = 7,
    ignore_non_words=True,
    dictionary_path: str = None,
    bigram_path: str = None,
    verbose=False,
):
    """
    symspeller - a wrapper for the SymSpell class from symspellpy

    Parameters
    ----------
    my_string : str, required, the string to be checked
    sym_checker : SymSpell, optional, default=None, the SymSpell object to use
    max_dist : int, optional, default=2, the maximum edit distance to look for replacements
    prefix_length : int, optional, default=7, the length of the prefixes to use
    ignore_non_words : bool, optional, default=True, whether to ignore non-words
    dictionary_path : str, optional, default=None, the path to the dictionary file
    bigram_path : str, optional, default=None, the path to the bigram dictionary file
    verbose : bool, optional, default=False, whether to print the results

    Returns
    -------
    str, the corrected string
    """
    assert len(my_string) > 0, "entered string for correction is empty"
    if sym_checker is None:
        # need to create a new SymSpell object. user can specify their own dictionary and bigram files
        if verbose:
            print("creating new SymSpell object")
        sym_checker = build_symspell_obj(
            edit_dist=max_dist,
            prefix_length=prefix_length,
            dictionary_path=dictionary_path,
            bigram_path=bigram_path,
        )
    else:
        if verbose:
            print("using existing SymSpell object")
    # max edit distance per lookup (per single word, not per whole input string)
    suggestions = sym_checker.lookup_compound(
        my_string,
        max_edit_distance=max_dist,
        ignore_non_words=ignore_non_words,
        ignore_term_with_digits=True,
        transfer_casing=True,
    )
    if verbose:
        print(f"{len(suggestions)} suggestions found")
        print(f"the original string is:\n\t{my_string}")
        sug_list = [sug.term for sug in suggestions]
        print(f"suggestions:\n\t{sug_list}\n")
    if len(suggestions) < 1:
        return clean(my_string)  # no correction because no suggestions
    else:
        first_result = suggestions[0]  # first result is the most likely
        return first_result.term

def build_symspell_obj(
    edit_dist=2,
    prefix_length=7,
    dictionary_path=None,
    bigram_path=None,
):
    """
    build_symspell_obj - build a SymSpell object with unigram and bigram frequency dictionaries loaded

    Args:
        edit_dist (int, optional): max edit distance for lookups. Defaults to 2.
        prefix_length (int, optional): prefix length used by SymSpell. Defaults to 7.
        dictionary_path (str, optional): path to a unigram frequency dictionary. Defaults to the bundled English dictionary.
        bigram_path (str, optional): path to a bigram frequency dictionary. Defaults to the bundled English bigram dictionary.

    Returns:
        SymSpell: a SymSpell object
    """
    dictionary_path = (
        r"symspell_rsc/frequency_dictionary_en_82_765.txt"
        if dictionary_path is None
        else dictionary_path
    )
    bigram_path = (
        r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt"
        if bigram_path is None
        else bigram_path
    )
    sym_checker = SymSpell(
        max_dictionary_edit_distance=edit_dist + 2, prefix_length=prefix_length
    )
    # term_index is the column of the term and count_index is the
    # column of the term frequency
    sym_checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
    sym_checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)
    return sym_checker

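"""
# rough usage sketch (not part of the original file): build a SymSpell object once and reuse it
# for several corrections; dictionary paths fall back to the bundled symspell_rsc files
sym_checker = build_symspell_obj(edit_dist=2, prefix_length=7)
print(symspeller("whereis th elove hehad dated", sym_checker=sym_checker, verbose=True))
"""
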
| """ | |
| # if using t5b_correction to check for spelling errors, use this code to initialize the objects | |
| import torch | |
| from transformers import T5Tokenizer, T5ForConditionalGeneration | |
| model_name = 'deep-learning-analytics/GrammarCorrector' | |
| # torch_device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| torch_device = 'cpu' | |
| gc_tokenizer = T5Tokenizer.from_pretrained(model_name) | |
| gc_model = T5ForConditionalGeneration.from_pretrained(model_name).to(torch_device) | |
| """ | |
def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
    """
    t5b_correction - correct a string using a text2textgen pipeline model from transformers

    Parameters
    ----------
    prompt : str, required, input prompt to be corrected
    korrektor : transformers.pipeline, required, pipeline object
    verbose : bool, optional, whether to print the corrected prompt. Defaults to False.
    beams : int, optional, number of beams to use for the correction. Defaults to 4.

    Returns
    -------
    str, corrected prompt
    """
    p_min_len = int(math.ceil(0.9 * len(prompt)))
    p_max_len = int(math.ceil(1.1 * len(prompt)))
    if verbose:
        print(f"setting min to {p_min_len} and max to {p_max_len}\n")
    gcorr_result = korrektor(
        f"grammar: {prompt}",
        return_text=True,
        clean_up_tokenization_spaces=True,
        num_beams=beams,
        max_length=p_max_len,
        repetition_penalty=1.3,
        length_penalty=0.2,
        no_repeat_ngram_size=2,
    )
    if verbose:
        print(f"grammar correction result: \n\t{gcorr_result}\n")
    return gcorr_result

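"""
# rough sketch (not from the original file): wrap the gc_model / gc_tokenizer objects initialized
# in the commented block above into a text2text-generation pipeline and pass it to t5b_correction
from transformers import pipeline

korrektor = pipeline("text2text-generation", model=gc_model, tokenizer=gc_tokenizer)
print(t5b_correction("she no went to the market", korrektor, verbose=True))
"""
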
def all_neuspell_chkrs():
    """
    all_neuspell_chkrs - display the neuspell checkers available

    Parameters
    ----------
    None

    Returns
    -------
    checker_opts - list of checkers available
    """
    checker_opts = dir(neuspell)
    print("\navailable checkers:")
    pp.pprint(checker_opts, indent=4, compact=True)
    return checker_opts

def load_ns_checker(customckr=None, fast=False):
    """
    load_ns_checker - helper function, load / "set up" a neuspell checker from huggingface transformers

    Args:
        customckr (neuspell.NeuSpell): [neuspell checker class], optional, if not provided, will load the default checker
        fast (bool): [load the faster but less accurate SclstmChecker when no customckr is given], optional, defaults to False

    Returns:
        [neuspell.NeuSpell]: [neuspell checker object]
    """
    st = time.perf_counter()
    # stop all printing to the console
    with suppress_stdout():
        if customckr is None and not fast:
            checker = BertChecker(
                pretrained=True
            )  # load the default checker, has the best balance
        elif customckr is None and fast:
            checker = SclstmChecker(
                pretrained=True
            )  # this one is faster but not as accurate
        else:
            checker = customckr(pretrained=True)
    rt_min = (time.perf_counter() - st) / 60
    # return to standard logging level
    print(f"\n\nloaded checker in {round(rt_min, 2)} minutes")
    return checker

def neuspell_correct(input_text: str, checker=None, verbose=False):
    """
    neuspell_correct - correct a string using neuspell.
    note that modifications to the checker are needed if doing list-based corrections

    Parameters
    ----------
    input_text : str, required, input string to be corrected
    checker : neuspell.NeuSpell, optional, neuspell checker object. Defaults to None.
    verbose : bool, optional, whether to print the corrected string. Defaults to False.

    Returns
    -------
    str, corrected string
    """
    if isinstance(input_text, str) and len(input_text) < 4:
        print(f"input text of {input_text} is too short to be corrected")
        return input_text
    if checker is None:
        print("NOTE - no checker provided, loading default checker")
        checker = SclstmChecker(pretrained=True)
    corrected = checker.correct(input_text)
    cleaned_txt = fix_punct_spaces(corrected)
    if verbose:
        print(f"neuspell correction result: \n\t{cleaned_txt}\n")
    return cleaned_txt

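"""
# rough usage sketch (not part of the original file): load a checker once with load_ns_checker
# and reuse it for multiple corrections (fast=True picks the lighter SclstmChecker)
ns_checker = load_ns_checker(fast=True)
print(neuspell_correct("I luk foward to receving your reply", checker=ns_checker, verbose=True))
"""
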
def grammarpipe(corrector, qphrase: str):
    """
    grammarpipe (formerly gramformer_correct) - THE ORIGINAL ONE USED IN PROJECT AND NEEDS TO BE CHANGED.
    Idea is to correct a string using a text2textgen pipeline model from transformers

    Args:
        corrector (transformers.pipeline): [transformers pipeline object, already created w/ relevant model]
        qphrase (str): [text to be corrected]

    Returns:
        [str]: [corrected text]
    """
    if isinstance(qphrase, str) and len(qphrase) < 4:
        print(f"input text of {qphrase} is too short to be corrected")
        return qphrase
    try:
        corrected = corrector(
            clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
        )
        return corrected[0]["generated_text"]
    except Exception as e:
        print(f"NOTE - failed to correct with grammarpipe:\n {e}")
        return clean(qphrase)

def DLA_correct(qphrase: str, tokenizer, model):
    """
    DLA_correct - an "overhead" function to call correct_grammar() on a string, allowing each sentence to be corrected individually

    Args:
        qphrase (str): [string to be corrected]
        tokenizer: [tokenizer for the grammar-correction model, passed through to correct_grammar()]
        model: [grammar-correction model, passed through to correct_grammar()]

    Returns:
        str, the corrected sentences joined with " "
    """
    if isinstance(qphrase, str) and len(qphrase) < 4:
        print(f"input text of {qphrase} is too short to be corrected")
        return qphrase
    sentences = split_sentences(qphrase)
    if len(sentences) == 1:
        corrected = correct_grammar(sentences[0], tokenizer, model)
        return corrected
    else:
        full_cor = []
        for sen in sentences:
            corr_sen = correct_grammar(clean(sen), tokenizer, model)
            full_cor.append(corr_sen)
        return " ".join(full_cor)

def correct_grammar(
    input_text: str,
    tokenizer,
    model,
    n_results: int = 1,
    beams: int = 8,
    temp=1,
    uniq_ngrams=2,
    rep_penalty=1.5,
    device="cpu",
):
    """
    correct_grammar - correct a string using a text2textgen pipeline model from transformers.
    This function is an alternative to the t5b_correction function.

    Parameters
    ----------
    input_text : str, required, input string to be corrected
    tokenizer : transformers.T5Tokenizer, required, tokenizer object, already created w/ relevant model
    model : transformers.T5ForConditionalGeneration, required, model object, already created w/ relevant model
    n_results : int, optional, number of results to return. Defaults to 1.
    beams : int, optional, number of beams to use for the correction. Defaults to 8.
    temp : int, optional, temperature to use for the correction. Defaults to 1.
    uniq_ngrams : int, optional, size of n-grams that may not repeat in the output. Defaults to 2.
    rep_penalty : float, optional, repetition penalty to use for the correction. Defaults to 1.5.
    device : str, optional, device to use for the correction. Defaults to 'cpu'.

    Returns
    -------
    str, corrected string (or list of strings if n_results > 1)
    """
    st = time.perf_counter()
    if len(input_text) < 5:
        return input_text
    max_length = min(int(math.ceil(len(input_text) * 1.2)), 128)
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    ).to(device)
    translated = model.generate(
        **batch,
        max_length=max_length,
        min_length=min(10, len(input_text)),
        no_repeat_ngram_size=uniq_ngrams,
        repetition_penalty=rep_penalty,
        num_beams=beams,
        num_return_sequences=n_results,
        temperature=temp,
    )
    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    rt_min = (time.perf_counter() - st) / 60
    print(f"\n\ncorrected in {round(rt_min, 2)} minutes")
    # return a single string for the default case, the full list if multiple results were requested
    if n_results == 1 and isinstance(tgt_text, list):
        return tgt_text[0]
    else:
        return tgt_text

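"""
# rough end-to-end sketch (not part of the original file), assuming the
# deep-learning-analytics/GrammarCorrector checkpoint referenced earlier in this file
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_name = "deep-learning-analytics/GrammarCorrector"
gc_tokenizer = T5Tokenizer.from_pretrained(model_name)
gc_model = T5ForConditionalGeneration.from_pretrained(model_name)

print(correct_grammar("he go to school yesterday", gc_tokenizer, gc_model))
print(DLA_correct("he go to school yesterday. she dont like it.", gc_tokenizer, gc_model))
"""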