import random
import torch
import logging
import string
import traceback
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk
from tqdm import tqdm

# Set logging to WARNING for a cleaner terminal.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Ensure the NLTK stopwords corpus is available
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


def clean_word(word):
    """Normalize a word for matching: strip possessives, lowercase, and drop punctuation."""
    # Remove the possessive 's before stripping other punctuation
    if word.lower().endswith("'s"):
        word = word[:-2]
    return word.lower().strip().translate(str.maketrans('', '', string.punctuation))
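
# Illustrative behavior (a hedged sketch; exact results follow string.punctuation):
#   clean_word("City's")  -> "city"   (possessive stripped, then lowercased)
#   clean_word("title.")  -> "title"  (trailing punctuation removed)
#   clean_word("Hello,")  -> "hello"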


class MaskingProcessor:
    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        # Resolve the device once, then move the model onto it
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.stop_words = set(stopwords.words('english'))
        tqdm.write(f"[MaskingProcessor] Initialized on device: {self.device}")

    def is_stopword(self, word):
        """Check if a word is a stopword, handling punctuation and case."""
        return clean_word(word) in self.stop_words

    def verify_and_correct_ngram_positions(self, sentence, common_ngrams):
        """Verify that n-gram positions match actual words in the sentence and correct them if needed."""
        words = sentence.split()
        corrected_ngrams = {}
        for ngram, positions in common_ngrams.items():
            corrected_positions = []
            ngram_words = ngram.split()
            # Convert ngram words to clean form for matching
            clean_ngram_words = [clean_word(word) for word in ngram_words]
            # Scan the sentence to find actual occurrences of the ngram
            for i in range(len(words) - len(ngram_words) + 1):
                is_match = True
                for j, ngram_word in enumerate(clean_ngram_words):
                    if clean_word(words[i + j]) != ngram_word:
                        is_match = False
                        break
                if is_match:
                    # Found a matching position, add it
                    corrected_positions.append((i, i + len(ngram_words) - 1))
            if corrected_positions:
                corrected_ngrams[ngram] = corrected_positions
            else:
                # Log the issue and fall back to a more flexible search
                print(f"Warning: could not find an exact match for '{ngram}' in the sentence.")
                print("Attempting flexible matching...")
                # Allow prefix matches (e.g. inflected or possessive forms), anchored at the first word
                for i in range(len(words)):
                    if clean_word(words[i]) == clean_ngram_words[0]:
                        # Found the first word of the ngram
                        if len(ngram_words) == 1 or (
                            i + len(ngram_words) <= len(words) and
                            all(clean_word(words[i + j]).startswith(clean_ngram_words[j]) for j in range(len(ngram_words)))
                        ):
                            corrected_positions.append((i, i + len(ngram_words) - 1))
                if corrected_positions:
                    print(f"Found flexible matches for '{ngram}': {corrected_positions}")
                    corrected_ngrams[ngram] = corrected_positions
                else:
                    # If there is still no match, keep the original positions as a fallback
                    print(f"No matches found for '{ngram}'. Keeping original positions.")
                    corrected_ngrams[ngram] = positions
        # Log any corrections
        if corrected_ngrams != common_ngrams:
            print(f"Original ngram positions: {common_ngrams}")
            print(f"Corrected ngram positions: {corrected_ngrams}")
        return corrected_ngrams

    def in_any_ngram(self, idx, ngram_positions):
        """Check whether an original-sentence index falls inside any n-gram span."""
        return any(start <= idx <= end for start, end in ngram_positions)
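    # Example with hypothetical spans: in_any_ngram(3, [(2, 4), (7, 8)]) -> True,
    # while in_any_ngram(5, [(2, 4), (7, 8)]) -> False.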

    def create_fallback_mask(self, sentence, ngrams):
        """Create a fallback mask when the normal strategies fail."""
        try:
            words = sentence.split()
            if not words:
                return None
            # Flatten the n-gram spans
            ngram_positions = []
            for positions in ngrams.values():
                for start, end in positions:
                    ngram_positions.append((start, end))
            ngram_positions.sort()
            # Prefer the first non-stopword that is not inside an n-gram
            for idx, word in enumerate(words):
                if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
                    masked_words = words.copy()
                    masked_words[idx] = self.tokenizer.mask_token
                    tqdm.write(f"[INFO] Fallback mask created at position {idx}: '{word}'")
                    return " ".join(masked_words), [idx]
            # If no eligible word was found, mask the first non-stopword
            for idx, word in enumerate(words):
                if not self.is_stopword(word):
                    masked_words = words.copy()
                    masked_words[idx] = self.tokenizer.mask_token
                    tqdm.write(f"[INFO] Last-resort fallback mask created at position {idx}: '{word}'")
                    return " ".join(masked_words), [idx]
            # If still nothing, mask the first word
            masked_words = words.copy()
            masked_words[0] = self.tokenizer.mask_token
            return " ".join(masked_words), [0]
        except Exception as e:
            tqdm.write(f"[ERROR] Error creating fallback mask: {e}")
            return None

    def mask_sentence_random(self, sentence, common_ngrams):
        """Mask random non-stopwords outside common n-grams, with position-aware selection."""
        common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
        tqdm.write(f"[MaskingProcessor] Masking (random) sentence: {sentence}")
        original_words = sentence.split()
        # Detach trailing punctuation so it is neither masked nor lost
        has_punctuation = False
        punctuation = ''
        if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words[-1] = original_words[-1][:-1]
            if not original_words[-1]:  # The word was punctuation only
                original_words.pop()
        # Flatten the n-gram spans
        ngram_positions = []
        for positions in common_ngrams.values():
            for start, end in positions:
                ngram_positions.append((start, end))
        ngram_positions.sort()
        # Collect candidate indices (non-stopwords outside all n-grams)
        candidate_indices = []
        for idx, word in enumerate(original_words):
            if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
                candidate_indices.append(idx)
        # Debug: print candidate words
        print("Candidate words for masking:")
        for idx in candidate_indices:
            print(f"  Position {idx}: '{original_words[idx]}'")
        selected_indices = []
        if ngram_positions:
            # Before the first n-gram: select 1 word
            before_first = [idx for idx in candidate_indices if idx < ngram_positions[0][0]]
            if before_first:
                selected_indices.extend(random.sample(before_first, min(1, len(before_first))))
            # Between consecutive n-grams: select up to 2 words per gap
            for i in range(len(ngram_positions) - 1):
                between = [idx for idx in candidate_indices
                           if ngram_positions[i][1] < idx < ngram_positions[i + 1][0]]
                if between:
                    selected_indices.extend(random.sample(between, min(2, len(between))))
            # After the last n-gram: select 1 word
            after_last = [idx for idx in candidate_indices if idx > ngram_positions[-1][1]]
            if after_last:
                selected_indices.extend(random.sample(after_last, min(1, len(after_last))))
        else:
            # With no n-grams, pick up to 6 random candidates
            if candidate_indices:
                selected_indices = random.sample(candidate_indices,
                                                 min(6, len(candidate_indices)))
        masked_words = original_words.copy()
        for idx in selected_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if has_punctuation:
            masked_words.append(punctuation)
        # Debug prints
        print("Original sentence:", sentence)
        print("Common ngrams:", common_ngrams)
        print("Common ngram positions:", ngram_positions)
        print("Candidate indices for masking:", candidate_indices)
        print("Selected for masking:", selected_indices)
        print("Masked sentence:", " ".join(masked_words))
        return " ".join(masked_words), selected_indices

    def mask_sentence_pseudorandom(self, sentence, common_ngrams):
        """Mask non-stopwords outside common n-grams, with a fixed seed for repeatable selection."""
        common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
        tqdm.write(f"[MaskingProcessor] Masking (pseudorandom) sentence: {sentence}")
        # Seed a local RNG so selection is repeatable without disturbing the
        # global random state used by mask_sentence_random
        rng = random.Random(3)
        original_words = sentence.split()
        # Detach trailing punctuation so it is neither masked nor lost
        has_punctuation = False
        punctuation = ''
        if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words[-1] = original_words[-1][:-1]
            if not original_words[-1]:  # The word was punctuation only
                original_words.pop()
        # Flatten the n-gram spans
        ngram_positions = []
        for positions in common_ngrams.values():
            for start, end in positions:
                ngram_positions.append((start, end))
        ngram_positions.sort()
        # Collect candidate indices (non-stopwords outside all n-grams)
        candidate_indices = []
        for idx, word in enumerate(original_words):
            if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
                candidate_indices.append(idx)
        # Debug: print candidate words
        print("Candidate words for masking:")
        for idx in candidate_indices:
            print(f"  Position {idx}: '{original_words[idx]}'")
        # Pseudorandom-specific logic: same positional policy, seeded sampler
        selected_indices = []
        if ngram_positions:
            # Before the first n-gram: select 1 word
            before_first = [idx for idx in candidate_indices if idx < ngram_positions[0][0]]
            if before_first:
                selected_indices.extend(rng.sample(before_first, min(1, len(before_first))))
            # Between consecutive n-grams: select up to 2 words per gap
            for i in range(len(ngram_positions) - 1):
                between = [idx for idx in candidate_indices
                           if ngram_positions[i][1] < idx < ngram_positions[i + 1][0]]
                if between:
                    selected_indices.extend(rng.sample(between, min(2, len(between))))
            # After the last n-gram: select 1 word
            after_last = [idx for idx in candidate_indices if idx > ngram_positions[-1][1]]
            if after_last:
                selected_indices.extend(rng.sample(after_last, min(1, len(after_last))))
        else:
            # With no n-grams, pick up to 6 candidates
            if candidate_indices:
                selected_indices = rng.sample(candidate_indices,
                                              min(6, len(candidate_indices)))
        masked_words = original_words.copy()
        for idx in selected_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if has_punctuation:
            masked_words.append(punctuation)
        # Debug prints
        print("Original sentence:", sentence)
        print("Common ngrams:", common_ngrams)
        print("Common ngram positions:", ngram_positions)
        print("Candidate indices for masking:", candidate_indices)
        print("Selected for masking:", selected_indices)
        print("Masked sentence:", " ".join(masked_words))
        return " ".join(masked_words), selected_indices

    def mask_sentence_entropy(self, sentence, common_ngrams):
        """Mask the highest-entropy words that are not part of common n-grams."""
        common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
        tqdm.write(f"[MaskingProcessor] Masking (entropy) sentence: {sentence}")
        original_words = sentence.split()
        # Detach trailing punctuation so it is neither masked nor lost
        has_punctuation = False
        punctuation = ''
        if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words[-1] = original_words[-1][:-1]
            if not original_words[-1]:  # The word was punctuation only
                original_words.pop()
        # Flatten the n-gram spans
        ngram_positions = []
        for positions in common_ngrams.values():
            for start, end in positions:
                ngram_positions.append((start, end))
        ngram_positions.sort()
        # Collect candidate indices (non-stopwords outside all n-grams)
        candidate_indices = []
        for idx, word in enumerate(original_words):
            if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
                candidate_indices.append(idx)
        # Debug: print candidate words
        print("Candidate words for masking:")
        for idx in candidate_indices:
            print(f"  Position {idx}: '{original_words[idx]}'")
        # Entropy-specific logic: rank candidates by entropy within each region
        selected_indices = []
        if candidate_indices:
            if ngram_positions:
                # Group candidates by position relative to the n-grams
                before_first = []
                between_ngrams = {}
                after_last = []
                for idx in candidate_indices:
                    if idx < ngram_positions[0][0]:
                        before_first.append(idx)
                    elif idx > ngram_positions[-1][1]:
                        after_last.append(idx)
                    else:
                        # Find which inter-ngram gap this index belongs to
                        for i in range(len(ngram_positions) - 1):
                            if ngram_positions[i][1] < idx < ngram_positions[i + 1][0]:
                                if i not in between_ngrams:
                                    between_ngrams[i] = []
                                between_ngrams[i].append(idx)
                # Before the first n-gram: select the single highest-entropy word
                if before_first:
                    entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in before_first]
                    entropies.sort(key=lambda x: x[1], reverse=True)  # Highest entropy first
                    selected_indices.extend([idx for idx, _ in entropies[:1]])
                # In each gap between n-grams: select up to the 2 highest-entropy words
                for group, indices in between_ngrams.items():
                    if indices:
                        entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in indices]
                        entropies.sort(key=lambda x: x[1], reverse=True)  # Highest entropy first
                        selected_indices.extend([idx for idx, _ in entropies[:2]])
                # After the last n-gram: select the single highest-entropy word
                if after_last:
                    entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in after_last]
                    entropies.sort(key=lambda x: x[1], reverse=True)  # Highest entropy first
                    selected_indices.extend([idx for idx, _ in entropies[:1]])
            else:
                # With no n-grams, rank all candidates by entropy
                entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in candidate_indices]
                entropies.sort(key=lambda x: x[1], reverse=True)
                # Take the top 6, or all if there are fewer
                selected_indices = [idx for idx, _ in entropies[:6]]
        masked_words = original_words.copy()
        for idx in selected_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if has_punctuation:
            masked_words.append(punctuation)
        # Debug prints
        print("Original sentence:", sentence)
        print("Common ngrams:", common_ngrams)
        print("Common ngram positions:", ngram_positions)
        print("Candidate indices for masking:", candidate_indices)
        print("Selected for masking:", selected_indices)
        print("Masked sentence:", " ".join(masked_words))
        return " ".join(masked_words), selected_indices

    def calculate_mask_logits(self, original_sentence, original_mask_indices):
        """Calculate logits for each masked position, one mask at a time."""
        logger.info(f"Calculating mask logits for sentence: {original_sentence}")
        words = original_sentence.split()
        mask_logits = {}
        for idx in original_mask_indices:
            masked_words = words.copy()
            masked_words[idx] = self.tokenizer.mask_token
            masked_sentence = " ".join(masked_words)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits
            mask_logits_tensor = logits[0, mask_token_index, :]
            top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)
            top_tokens = []
            top_logits = []
            seen_words = set()
            for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
                token = self.tokenizer.convert_ids_to_tokens(token_id.item())
                # Skip WordPiece continuation tokens; only whole words are kept
                if token.startswith('##'):
                    continue
                word = self.tokenizer.convert_tokens_to_string([token]).strip()
                if word and word not in seen_words:
                    seen_words.add(word)
                    top_tokens.append(word)
                    top_logits.append(logit.item())
                if len(top_tokens) == 50:
                    break
            mask_logits[idx] = {
                "tokens": top_tokens,
                "logits": top_logits
            }
        logger.info("Completed calculating mask logits.")
        return mask_logits
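    # Shape of the returned dict (token strings and values are hypothetical):
    #   {4: {"tokens": ["scored", "played", ...],   # up to 50 whole-word candidates
    #        "logits": [11.2, 10.7, ...]}}          # matching raw logits, highest first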

    def calculate_word_entropy(self, sentence, word_position):
        """Calculate the entropy of the model's distribution at a masked word position."""
        logger.info(f"Calculating word entropy for position {word_position} in sentence: {sentence}")
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))
        logger.info(f"Computed entropy: {entropy.item()}")
        return entropy.item()
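    # The value above is the Shannon entropy (natural log) of the model's
    # distribution over the vocabulary at the masked slot,
    #   H = -sum_v p(v) * log(p(v) + 1e-9),
    # where the epsilon guards against log(0). Higher entropy means the model
    # is less certain about the word, making it a stronger masking candidate.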

    def process_sentences(self, sentences_list, common_grams, method="random"):
        """Process multiple sentences with the specified masking method."""
        tqdm.write(f"[MaskingProcessor] Processing sentences using method: {method}")
        results = {}
        for sentence in tqdm(sentences_list, desc="Masking Sentences"):
            try:
                ngrams = common_grams.get(sentence, {})
                if method == "random":
                    masked_sentence, original_mask_indices = self.mask_sentence_random(sentence, ngrams)
                elif method == "pseudorandom":
                    masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(sentence, ngrams)
                else:  # entropy
                    masked_sentence, original_mask_indices = self.mask_sentence_entropy(sentence, ngrams)
                # Fall back if no masks were applied
                if not original_mask_indices:
                    tqdm.write(f"[WARNING] No mask indices found for sentence with method {method}: {sentence[:50]}...")
                    # Create a fallback masked sentence with at least one mask
                    fallback_result = self.create_fallback_mask(sentence, ngrams)
                    if fallback_result:
                        masked_sentence, original_mask_indices = fallback_result
                        tqdm.write("[INFO] Created fallback mask for sentence")
                    else:
                        tqdm.write("[WARNING] Could not create fallback mask, skipping sentence")
                        continue
                logits = self.calculate_mask_logits(sentence, original_mask_indices)
                results[sentence] = {
                    "masked_sentence": masked_sentence,
                    "mask_logits": logits
                }
                logger.info(f"Processed sentence: {sentence}")
            except Exception as e:
                tqdm.write(f"[ERROR] Failed to process sentence with method {method}: {e}")
                tqdm.write(f"Sentence: {sentence[:100]}...")
                tqdm.write(traceback.format_exc())
        tqdm.write("[MaskingProcessor] Completed processing sentences.")
        return results
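    # Shape of the returned mapping (values are hypothetical):
    #   {"<original sentence>": {
    #       "masked_sentence": "... [MASK] ...",
    #       "mask_logits": {<word index>: {"tokens": [...], "logits": [...]}}}}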

    @staticmethod
    def identify_common_ngrams(sentences, entities):
        """Find entity n-grams in each sentence, handling possessive forms."""
        common_grams = {}
        # Pre-process entities to cover simple variations
        processed_entities = []
        for entity in entities:
            processed_entities.append(entity)
            # Also try the possessive form if the entity does not already end in 's' or "'s"
            if not entity.endswith("'s") and not entity.endswith("s"):
                processed_entities.append(f"{entity}'s")
        for sentence in sentences:
            words = sentence.split()
            common_grams[sentence] = {}
            # Look for each entity in the sentence
            for entity in processed_entities:
                entity_words = entity.split()
                entity_len = len(entity_words)
                # Clean entity words for matching
                clean_entity_words = [clean_word(word) for word in entity_words]
                # Find all occurrences
                for i in range(len(words) - entity_len + 1):
                    is_match = True
                    for j, entity_word in enumerate(clean_entity_words):
                        if clean_word(words[i + j]) != entity_word:
                            is_match = False
                            break
                    if is_match:
                        # Use the canonical (non-possessive) form from the entity list for consistency
                        base_entity = entity
                        if entity.endswith("'s") and any(e == entity[:-2] for e in processed_entities):
                            base_entity = entity[:-2]
                        if base_entity not in common_grams[sentence]:
                            common_grams[sentence][base_entity] = []
                        common_grams[sentence][base_entity].append((i, i + entity_len - 1))
        return common_grams
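
# Example output of MaskingProcessor.identify_common_ngrams (word indices, inclusive):
#   {"Kevin De Bruyne scored for Manchester City.":
#       {"Kevin De Bruyne": [(0, 2)], "Manchester City": [(5, 6)]}}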


if __name__ == "__main__":
    # Example test
    test_sentence = "Kevin De Bruyne scored for Manchester City as they won the 2019-20 Premier League title."
    # Entities to preserve
    entities = ["Kevin De Bruyne", "Manchester City", "Premier League"]
    # Identify common n-grams
    common_grams = MaskingProcessor.identify_common_ngrams([test_sentence], entities)
    # Print detected n-grams
    print(f"Detected common n-grams: {common_grams[test_sentence]}")
    # Initialize the processor
    processor = MaskingProcessor(
        BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking"),
        BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
    )
    # Test all three masking methods
    print("\nTesting Random Masking:")
    masked_random, indices_random = processor.mask_sentence_random(test_sentence, common_grams[test_sentence])
    print("\nTesting Pseudorandom Masking:")
    masked_pseudorandom, indices_pseudorandom = processor.mask_sentence_pseudorandom(test_sentence, common_grams[test_sentence])
    print("\nTesting Entropy Masking:")
    masked_entropy, indices_entropy = processor.mask_sentence_entropy(test_sentence, common_grams[test_sentence])
    # Print results
    print("\nResults:")
    print(f"Original: {test_sentence}")
    print(f"Random Masked: {masked_random}")
    print(f"Pseudorandom Masked: {masked_pseudorandom}")
    print(f"Entropy Masked: {masked_entropy}")