import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk
# NOTE: the n-gram indices in result_dict are word positions in the full,
# whitespace-split sentences (i.e. computed WITHOUT removing stopwords).
# Masking with remove_stopwords=True only works if the indices were recomputed
# on the stopword-filtered word list.
# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords=False):
        """
        Mask one word before the first common n-gram, one word between each pair of
        consecutive n-grams, and one word after the last common n-gram (random selection).
        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their (start, end) word indices
            remove_stopwords (bool): If True, drop stopwords before masking
                (the n-gram indices must then refer to the filtered word list)
        Returns:
            str: Masked sentence
        """
        words = original_sentence.split()
        if remove_stopwords:
            # Compare case-insensitively; the NLTK stopword list is lowercase
            words = [word for word in words if word.lower() not in self.stop_words]
        mask_indices = []
        # Word indices covered by common n-grams; these must never be masked
        ngram_indices = {
            index
            for positions in common_ngrams.values()
            for start, end in positions
            for index in range(start, end + 1)
        }
        if common_ngrams:
            ngram_positions = list(common_ngrams.values())
            # One word before the first common n-gram
            first_ngram_start = ngram_positions[0][0][0]
            if first_ngram_start > 0:
                mask_indices.append(random.randint(0, first_ngram_start - 1))
            # One word between each pair of consecutive common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                if start_next > end_prev + 1:
                    mask_indices.append(random.randint(end_prev + 1, start_next - 1))
            # One word after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(words) - 1:
                mask_indices.append(random.randint(last_ngram_end + 1, len(words) - 1))
        # Mask the chosen indices, skipping anything inside a common n-gram
        for idx in mask_indices:
            if idx not in ngram_indices:
                words[idx] = self.tokenizer.mask_token
        return " ".join(words)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords=False):
        """
        Mask one word before the first common n-gram, one word between each pair of
        consecutive n-grams, and one word after the last common n-gram, choosing the
        candidate with the highest predictive entropy under BERT.
        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their (start, end) word indices
            remove_stopwords (bool): If True, drop stopwords before masking
                (the n-gram indices must then refer to the filtered word list)
        Returns:
            str: Masked sentence
        """
        words = original_sentence.split()
        if remove_stopwords:
            # Compare case-insensitively; the NLTK stopword list is lowercase
            words = [word for word in words if word.lower() not in self.stop_words]
        # Word indices covered by common n-grams; these must never be masked
        ngram_indices = {
            index
            for positions in common_ngrams.values()
            for start, end in positions
            for index in range(start, end + 1)
        }
        # Score every maskable word by the entropy of BERT's distribution at its position
        entropy_scores = {}
        for idx, word in enumerate(words):
            if idx in ngram_indices:
                continue  # Skip words in common n-grams
            masked_sentence = " ".join(words[:idx] + [self.tokenizer.mask_token] + words[idx + 1:])
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits
            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            # Shannon entropy of the predicted distribution; epsilon prevents log(0)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()
            entropy_scores[idx] = entropy
        mask_indices = []
        if common_ngrams:
            ngram_positions = list(common_ngrams.values())
            # Highest-entropy word before the first common n-gram
            first_ngram_start = ngram_positions[0][0][0]
            candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
            # Highest-entropy word between each pair of consecutive common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                candidates = [j for j in range(end_prev + 1, start_next) if j in entropy_scores]
                if candidates:
                    mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
            # Highest-entropy word after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
        # Mask the chosen indices
        for idx in mask_indices:
            words[idx] = self.tokenizer.mask_token
        return " ".join(words)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.
        Args:
            masked_sentence (str): Sentence containing [MASK] tokens
        Returns:
            dict: Maps each mask's position in the BERT input sequence (wordpiece index,
                  including special tokens) to the vocabulary logits at that position
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits
        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, original_sentences, result_dict, remove_stopwords=False, method="random"):
        """
        Process sentences and calculate logits for masked tokens using the specified method.
        Args:
            original_sentences (list): List of original sentences (kept for API symmetry;
                processing is driven by the keys of result_dict)
            result_dict (dict): Common n-grams and their indices for each sentence
            remove_stopwords (bool): Forwarded to the masking functions
            method (str): Masking method ("random" or "entropy")
        Returns:
            dict: Masked sentence and mask logits for each input sentence
        """
        results = {}
        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
            elif method == "entropy":
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
            else:
                raise ValueError("Invalid method. Choose 'random' or 'entropy'.")
            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }
        return results
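

# Illustrative helper (a sketch, not part of the original MaskingProcessor class): turns the
# raw logits returned by calculate_mask_logits into readable candidate tokens. The function
# name and the top_k parameter are assumptions introduced here for demonstration only.
def decode_top_candidates(tokenizer, mask_logits, top_k=5):
    """Return the top_k most likely filler tokens for each masked position."""
    candidates = {}
    for token_pos, logit_list in mask_logits.items():
        logits = torch.tensor(logit_list)
        top_ids = torch.topk(logits, top_k).indices.tolist()
        candidates[token_pos] = tokenizer.convert_ids_to_tokens(top_ids)
    return candidates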


# Example usage
if __name__ == "__main__":
    # NOTE: the indices below are word positions in the full sentences (stopwords kept),
    # so remove_stopwords must stay False for these coordinates to line up.
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown animals leap over lazy obstacles."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
        "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
        "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
    }
    # Alternative coordinates computed after stopword removal (use with remove_stopwords=True):
    # result_dict = {
    #     "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
    # }
    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, remove_stopwords=False, method="random")
    results_entropy = processor.process_sentences(sentences, result_dict, remove_stopwords=False, method="entropy")
    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")
    for sentence, output in results_entropy.items():
        print(f"Original Sentence (Entropy): {sentence}")
        print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
        # print(f"Mask Logits (Entropy): {output['mask_logits']}")