"""
This file contains the code to generate paraphrases of sentences.
"""
import logging
import os
import sys
from typing import List, Optional

import torch
from tqdm import tqdm  # for progress bars
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Make the project root importable so the local `utils` package resolves.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils.config import load_config
# Logging: keep the terminal quiet below WARNING level.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class Paraphraser:
    """Generate paraphrases of sentences with a seq2seq HuggingFace model.

    The tokenizer, model and all generation hyper-parameters are read from
    the configuration mapping passed to the constructor.
    """

    def __init__(self, config: dict):
        """Load the tokenizer/model and cache generation settings.

        Args:
            config: Mapping with keys 'tokenizer', 'model', 'num_beams',
                'num_beam_groups', 'num_return_sequences',
                'repetition_penalty', 'diversity_penalty',
                'no_repeat_ngram_size', 'temperature' and 'max_length'.
        """
        self.config = config
        # Prefer GPU when available; the model and all inputs live on this device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tqdm.write(f"[Paraphraser] Initializing on device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model']).to(self.device)
        # Generation hyper-parameters (passed to `model.generate` below).
        self.num_beams = config['num_beams']
        self.num_beam_groups = config['num_beam_groups']
        self.num_return_sequences = config['num_return_sequences']
        self.repetition_penalty = config['repetition_penalty']
        self.diversity_penalty = config['diversity_penalty']
        self.no_repeat_ngram_size = config['no_repeat_ngram_size']
        self.temperature = config['temperature']
        self.max_length = config['max_length']

    def paraphrase(self, sentence: str,
                   num_return_sequences: Optional[int] = None,
                   num_beams: Optional[int] = None,
                   num_beam_groups: Optional[int] = None) -> List[str]:
        """Return paraphrases of *sentence* via diverse beam search.

        Args:
            sentence: The sentence to paraphrase.
            num_return_sequences: Override for the configured value, or None.
            num_beams: Override for the configured value, or None.
            num_beam_groups: Override for the configured value, or None.

        Returns:
            A list of decoded paraphrase strings.
        """
        tqdm.write(f"[Paraphraser] Starting paraphrase for sentence: {sentence}")
        # Fall back to the configured defaults when no override is supplied.
        if num_return_sequences is None:
            num_return_sequences = self.num_return_sequences
        if num_beams is None:
            num_beams = self.num_beams
        if num_beam_groups is None:
            num_beam_groups = self.num_beam_groups
        # T5-style paraphrase models expect a task prefix on the input.
        inputs = self.tokenizer.encode("paraphrase: " + sentence,
                                       return_tensors="pt",
                                       max_length=self.max_length,
                                       truncation=True).to(self.device)
        outputs = self.model.generate(
            inputs,
            max_length=self.max_length,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            num_return_sequences=num_return_sequences,
            repetition_penalty=self.repetition_penalty,
            diversity_penalty=self.diversity_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature
        )
        paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True)
                       for output in tqdm(outputs, desc="Decoding Paraphrases")]
        tqdm.write(f"[Paraphraser] Paraphrase completed. {len(paraphrases)} outputs generated.")
        return paraphrases
if __name__ == "__main__":
    # Allow the config path to be supplied as the first CLI argument;
    # fall back to the original hard-coded location for backward compatibility.
    config_path = sys.argv[1] if len(sys.argv) > 1 else '/home/jigyasu/PECCAVI-Text/utils/config.yaml'
    config = load_config(config_path)
    paraphraser = Paraphraser(config['PECCAVI_TEXT']['Paraphrase'])
    sentence = "The quick brown fox jumps over the lazy dog."
    # Print each generated paraphrase on its own line.
    for paraphrase in paraphraser.paraphrase(sentence):
        print(paraphrase)