import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def remove_stopwords(self, words):
        """
        Remove stopwords from the given list of words.

        Args:
            words (list): List of words.

        Returns:
            list: List of non-stopwords.
        """
        return [word for word in words if word.lower() not in self.stop_words]
    def adjust_ngram_indices(self, original_words, common_ngrams):
        """
        Adjust indices of common n-grams after removing stopwords.

        Args:
            original_words (list): Original list of words.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            dict: Adjusted common n-grams with updated indices.
        """
        non_stop_words = self.remove_stopwords(original_words)
        original_to_non_stop = []
        non_stop_idx = 0
        for original_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                original_to_non_stop.append((original_idx, non_stop_idx))
                non_stop_idx += 1

        adjusted_ngrams = {}
        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
                    new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
                    adjusted_positions.append((new_start, new_end))
                except StopIteration:
                    continue  # Skip if indices cannot be mapped
            adjusted_ngrams[ngram] = adjusted_positions
        return adjusted_ngrams
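
    # A minimal worked example of the adjustment (assuming NLTK's English stopword
    # list, in which "the" and "over" are stopwords):
    #   original_words = ['The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']
    #   common_ngrams  = {'brown fox': [(2, 3)]}
    #   non-stopword sequence -> ['quick', 'brown', 'fox', 'jumps', 'lazy', 'dog']
    #   adjusted result        -> {'brown fox': [(1, 2)]}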

    def mask_sentence_random(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on the specified rules after removing stopwords.
        """
        original_words = sentence.split()
        print(f' ---- original_words : {original_words} ----- ')
        non_stop_words = self.remove_stopwords(original_words)
        print(f' ---- non_stop_words : {non_stop_words} ----- ')
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
        print(f' ---- common_ngrams : {common_ngrams} ----- ')
        print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')

        mask_indices = []
        # Extract n-gram positions in non-stop words
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        # Mask a word before the first common n-gram
        if ngram_positions:
            print(f' ---- ngram_positions : {ngram_positions} ----- ')
            first_ngram_start = ngram_positions[0][0]
            print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start - 1)
                print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
                mask_indices.append(mask_index_before_ngram)

            # Mask words between common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                print(f' ---- end_prev : {end_prev} ----- ')  # end index of the previous n-gram
                start_next = ngram_positions[i + 1][0]
                print(f' ---- start_next : {start_next} ----- ')
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
                    mask_indices.append(mask_index_between_ngrams)

            # Mask a word after the last common n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
                mask_indices.append(mask_index_after_ngram)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        # Map mask indices from non-stop word positions to original positions
        print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        print(f' ---- original_mask_indices : {original_mask_indices} ----- ')

        # Apply masks to the original sentence
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token
        return " ".join(masked_words)

    def mask_sentence_pseudorandom(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on the specified rules after removing stopwords.
        """
        random.seed(42)
        original_words = sentence.split()
        print(f' ---- original_words : {original_words} ----- ')
        non_stop_words = self.remove_stopwords(original_words)
        print(f' ---- non_stop_words : {non_stop_words} ----- ')
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
        print(f' ---- common_ngrams : {common_ngrams} ----- ')
        print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')

        mask_indices = []
        # Extract n-gram positions in non-stop words
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        # Mask a word before the first common n-gram
        if ngram_positions:
            print(f' ---- ngram_positions : {ngram_positions} ----- ')
            first_ngram_start = ngram_positions[0][0]
            print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start - 1)
                print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
                mask_indices.append(mask_index_before_ngram)

            # Mask words between common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                print(f' ---- end_prev : {end_prev} ----- ')
                start_next = ngram_positions[i + 1][0]
                print(f' ---- start_next : {start_next} ----- ')
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
                    mask_indices.append(mask_index_between_ngrams)

            # Mask a word after the last common n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
                mask_indices.append(mask_index_after_ngram)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        # Map mask indices from non-stop word positions to original positions
        print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        print(f' ---- original_mask_indices : {original_mask_indices} ----- ')

        # Apply masks to the original sentence
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token
        return " ".join(masked_words)

    def calculate_word_entropy(self, sentence, word_position):
        """
        Calculate entropy for a specific word position in the sentence.

        Args:
            sentence (str): The input sentence.
            word_position (int): Position of the word to calculate entropy for.

        Returns:
            float: Entropy value for the word.
        """
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)

        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        # Get probabilities for the masked position
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        # Calculate entropy: -sum(p * log(p))
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))
        return entropy.item()
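
    # Rough intuition for the score (a sketch, not tied to any particular output):
    # if the model spread probability uniformly over its ~30k-token vocabulary the
    # entropy would be log(30522) ~= 10.3 nats, while near-certainty about a single
    # token drives it toward 0. Higher entropy therefore means a harder-to-predict,
    # more informative position to mask.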

    def mask_sentence_entropy(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on entropy, following n-gram positioning rules.

        Args:
            sentence (str): Original sentence.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            str: Masked sentence.
        """
        original_words = sentence.split()
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        original_to_non_stop = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                original_to_non_stop[orig_idx] = non_stop_idx
                non_stop_idx += 1

        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
        mask_indices = []

        if ngram_positions:
            # Handle words before the first n-gram
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                # Calculate entropy for all candidate positions
                candidate_positions = range(0, first_ngram_start)
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                # Select the position with the highest entropy
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Handle words between n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    candidate_positions = range(end_prev + 1, start_next)
                    entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                                 for pos in candidate_positions]
                    mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Handle words after the last n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

        # Map mask indices to original sentence positions and apply masks
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token
        return " ".join(masked_words)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.

        Returns:
            dict: Masked token indices and their logits.
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits
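
    # Note: the keys of the returned dict are positions in BERT's tokenized input
    # (which starts with [CLS] and may split words into word pieces), not word
    # positions in the original sentence; each value is a vocabulary-sized list of
    # logits for that [MASK].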

    def process_sentences(self, sentences, result_dict, method="random"):
        """
        Process sentences and calculate logits for masked tokens.

        Args:
            sentences (list): List of sentences.
            result_dict (dict): Dictionary of common n-grams.
            method (str): Masking method ("random", "pseudorandom", or "entropy").

        Returns:
            dict: Masked sentences and logits for each sentence.
        """
        results = {}
        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams)
            elif method == "pseudorandom":
                masked_sentence = self.mask_sentence_pseudorandom(sentence, ngrams)
            else:  # entropy
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams)
            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }
        return results

if __name__ == "__main__":
    # Works in both cases, whether or not stopwords have been removed.
    sentences = [
        "The quick brown fox jumps over the lazy dog everyday.",
        # "A speedy brown fox jumps over a lazy dog.",
        # "A swift brown fox leaps over the lethargic dog."
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        # 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        # 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    }

    processor = MaskingProcessor()
    # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy")
    results_random = processor.process_sentences(sentences, result_dict, method="random")

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")
        print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
        print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
        print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
        print('--------------------------------')
        for mask_idx, logits in output["mask_logits"].items():
            print(f"Logits for [MASK] at position {mask_idx}:")
            print(f' logits : {logits[:5]}')  # First 5 of the vocabulary-sized logit list
            print(f' len(logits) : {len(logits)}')
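
    # A short decoding demo (a sketch, not required by the pipeline): convert the
    # stored logits back into BERT's top-5 candidate tokens for each [MASK].
    for sentence, output in results_random.items():
        for mask_idx, logits in output["mask_logits"].items():
            top_ids = torch.topk(torch.tensor(logits), k=5).indices.tolist()
            top_tokens = processor.tokenizer.convert_ids_to_tokens(top_ids)
            print(f"Top-5 candidates for [MASK] at token position {mask_idx}: {top_tokens}")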