# peccavi/utils/masking_methods.py
import random
import torch
import logging
import string
import traceback
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk
from tqdm import tqdm
# Set logging to WARNING for a cleaner terminal.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Ensure stopwords are downloaded
try:
nltk.data.find('corpora/stopwords')
except LookupError:
nltk.download('stopwords')
def clean_word(word):
"""More robust cleaning for consistent matching"""
# Remove possessive 's before other punctuation
if word.lower().endswith("'s"):
word = word[:-2]
return word.lower().strip().translate(str.maketrans('', '', string.punctuation))
class MaskingProcessor:
def __init__(self, tokenizer, model):
self.tokenizer = tokenizer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
self.stop_words = set(stopwords.words('english'))
tqdm.write(f"[MaskingProcessor] Initialized on device: {self.device}")
def is_stopword(self, word):
"""Check if a word is a stopword, handling punctuation and case"""
return clean_word(word) in self.stop_words
def verify_and_correct_ngram_positions(self, sentence, common_ngrams):
"""Verify ngram positions match actual words in sentence and correct if needed."""
words = sentence.split()
corrected_ngrams = {}
for ngram, positions in common_ngrams.items():
corrected_positions = []
ngram_words = ngram.split()
# Convert ngram words to clean format for matching
clean_ngram_words = [clean_word(word) for word in ngram_words]
# Scan the sentence to find actual occurrences of the ngram
for i in range(len(words) - len(ngram_words) + 1):
is_match = True
for j, ngram_word in enumerate(clean_ngram_words):
if clean_word(words[i + j]) != ngram_word:
is_match = False
break
if is_match:
# Found a matching position, add it
corrected_positions.append((i, i + len(ngram_words) - 1))
if corrected_positions:
corrected_ngrams[ngram] = corrected_positions
else:
# Log the issue and perform a more flexible search
print(f"Warning: Could not find exact match for '{ngram}' in the sentence.")
print(f"Attempting flexible matching...")
# Try a more flexible approach by looking for individual words
for i in range(len(words)):
if clean_word(words[i]) == clean_ngram_words[0]:
# We found the first word of the ngram
if len(ngram_words) == 1 or (
i + len(ngram_words) <= len(words) and
all(clean_word(words[i+j]).startswith(clean_ngram_words[j]) for j in range(len(ngram_words)))
):
corrected_positions.append((i, i + len(ngram_words) - 1))
if corrected_positions:
print(f"Found flexible matches for '{ngram}': {corrected_positions}")
corrected_ngrams[ngram] = corrected_positions
else:
# If still no match, keep original positions as fallback
print(f"No matches found for '{ngram}'. Keeping original positions.")
corrected_ngrams[ngram] = positions
# Log changes
if corrected_ngrams != common_ngrams:
print(f"Original ngram positions: {common_ngrams}")
print(f"Corrected ngram positions: {corrected_ngrams}")
return corrected_ngrams
def in_any_ngram(self, idx, ngram_positions):
"""Check if an original sentence index is part of any n-gram span"""
return any(start <= idx <= end for start, end in ngram_positions)
def create_fallback_mask(self, sentence, ngrams):
"""Create a fallback mask when normal strategies fail."""
try:
words = sentence.split()
if not words:
return None
# Find any non-stopword that isn't in an ngram
ngram_positions = []
for positions in ngrams.values():
for start, end in positions:
ngram_positions.append((start, end))
ngram_positions.sort()
# Find first eligible word
for idx, word in enumerate(words):
if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
masked_words = words.copy()
masked_words[idx] = self.tokenizer.mask_token
tqdm.write(f"[INFO] Fallback mask created at position {idx}: '{word}'")
return " ".join(masked_words), [idx]
# If no eligible word found, just mask the first non-stop word
for idx, word in enumerate(words):
if not self.is_stopword(word):
masked_words = words.copy()
masked_words[idx] = self.tokenizer.mask_token
tqdm.write(f"[INFO] Last resort fallback mask created at position {idx}: '{word}'")
return " ".join(masked_words), [idx]
# If still nothing, mask the first word
if words:
masked_words = words.copy()
masked_words[0] = self.tokenizer.mask_token
return " ".join(masked_words), [0]
return None
except Exception as e:
tqdm.write(f"[ERROR] Error creating fallback mask: {e}")
return None
def mask_sentence_random(self, sentence, common_ngrams):
"""Mask random non-stopwords that are not part of common ngrams with controlled positioning."""
common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
tqdm.write(f"[MaskingProcessor] Masking (random) sentence: {sentence}")
original_words = sentence.split()
# Handle punctuation
has_punctuation = False
punctuation = ''
if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
has_punctuation = True
punctuation = original_words[-1][-1]
original_words[-1] = original_words[-1][:-1]
if not original_words[-1]: # If the word was just punctuation
original_words.pop()
# Get flattened ngram positions
ngram_positions = []
for positions in common_ngrams.values():
for start, end in positions:
ngram_positions.append((start, end))
ngram_positions.sort()
# Find all candidate indices (non-stopwords not in ngrams)
candidate_indices = []
for idx, word in enumerate(original_words):
if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
candidate_indices.append(idx)
# Debug print candidate words
print("Candidate words for masking:")
for idx in candidate_indices:
print(f" Position {idx}: '{original_words[idx]}'")
selected_indices = []
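        # Selection strategy: mask 1 candidate before the first ngram, up to 2 in each gap
        # between consecutive ngrams, and 1 after the last ngram; with no ngrams present,
        # mask up to 6 random candidates. Ngram spans themselves are never masked.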
if ngram_positions:
# Before first ngram
before_first = [idx for idx in candidate_indices if idx < ngram_positions[0][0]]
if before_first:
num_to_select = min(1, len(before_first)) # Select 1 word
if num_to_select > 0:
selected = random.sample(before_first, num_to_select)
selected_indices.extend(selected)
# Between ngrams
for i in range(len(ngram_positions) - 1):
between = [idx for idx in candidate_indices
if ngram_positions[i][1] < idx < ngram_positions[i+1][0]]
if between:
num_to_select = min(2, len(between)) # Select between 1-2 words
if num_to_select > 0:
selected = random.sample(between, num_to_select)
selected_indices.extend(selected)
# After last ngram
after_last = [idx for idx in candidate_indices if idx > ngram_positions[-1][1]]
if after_last:
num_to_select = min(1, len(after_last)) # Select 1 word
if num_to_select > 0:
selected = random.sample(after_last, num_to_select)
selected_indices.extend(selected)
else:
# If no ngrams, pick up to 6 random candidates
if candidate_indices:
selected_indices = random.sample(candidate_indices,
min(6, len(candidate_indices)))
masked_words = original_words.copy()
for idx in selected_indices:
masked_words[idx] = self.tokenizer.mask_token
if has_punctuation:
masked_words.append(punctuation)
# Debug prints
print("Original sentence:", sentence)
print("Common ngrams:", common_ngrams)
print("Common ngram positions:", ngram_positions)
print("Candidate indices for masking:", candidate_indices)
print("Selected for masking:", selected_indices)
print("Masked sentence:", " ".join(masked_words))
return " ".join(masked_words), selected_indices
def mask_sentence_pseudorandom(self, sentence, common_ngrams):
"""Mask specific non-stopwords based on their position relative to ngrams."""
common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
tqdm.write(f"[MaskingProcessor] Masking (pseudorandom) sentence: {sentence}")
random.seed(3) # Fixed seed for pseudorandom behavior
original_words = sentence.split()
# Handle punctuation
has_punctuation = False
punctuation = ''
if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
has_punctuation = True
punctuation = original_words[-1][-1]
original_words[-1] = original_words[-1][:-1]
if not original_words[-1]: # If the word was just punctuation
original_words.pop()
# Get flattened ngram positions
ngram_positions = []
for positions in common_ngrams.values():
for start, end in positions:
ngram_positions.append((start, end))
ngram_positions.sort()
# Find all candidate indices (non-stopwords not in ngrams)
candidate_indices = []
for idx, word in enumerate(original_words):
if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
candidate_indices.append(idx)
# Debug print candidate words
print("Candidate words for masking:")
for idx in candidate_indices:
print(f" Position {idx}: '{original_words[idx]}'")
# PSEUDORANDOM SPECIFIC LOGIC:
selected_indices = []
if ngram_positions:
# Before first ngram
before_first = [idx for idx in candidate_indices if idx < ngram_positions[0][0]]
if before_first:
num_to_select = min(1, len(before_first)) # Select 1 word
if num_to_select > 0:
selected = random.sample(before_first, num_to_select)
selected_indices.extend(selected)
# Between ngrams
for i in range(len(ngram_positions) - 1):
between = [idx for idx in candidate_indices
if ngram_positions[i][1] < idx < ngram_positions[i+1][0]]
if between:
num_to_select = min(2, len(between)) # Select between 1-2 words
if num_to_select > 0:
selected = random.sample(between, num_to_select)
selected_indices.extend(selected)
# After last ngram
after_last = [idx for idx in candidate_indices if idx > ngram_positions[-1][1]]
if after_last:
num_to_select = min(1, len(after_last)) # Select 1 word
if num_to_select > 0:
selected = random.sample(after_last, num_to_select)
selected_indices.extend(selected)
else:
# If no ngrams, pick up to 6 random candidates
if candidate_indices:
selected_indices = random.sample(candidate_indices,
min(6, len(candidate_indices)))
masked_words = original_words.copy()
for idx in selected_indices:
masked_words[idx] = self.tokenizer.mask_token
if has_punctuation:
masked_words.append(punctuation)
# Debug prints
print("Original sentence:", sentence)
print("Common ngrams:", common_ngrams)
print("Common ngram positions:", ngram_positions)
print("Candidate indices for masking:", candidate_indices)
print("Selected for masking:", selected_indices)
print("Masked sentence:", " ".join(masked_words))
return " ".join(masked_words), selected_indices
def mask_sentence_entropy(self, sentence, common_ngrams):
"""Mask words with highest entropy that are not part of common ngrams."""
common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
tqdm.write(f"[MaskingProcessor] Masking (entropy) sentence: {sentence}")
original_words = sentence.split()
# Handle punctuation
has_punctuation = False
punctuation = ''
if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
has_punctuation = True
punctuation = original_words[-1][-1]
original_words[-1] = original_words[-1][:-1]
if not original_words[-1]: # If the word was just punctuation
original_words.pop()
# Get flattened ngram positions
ngram_positions = []
for positions in common_ngrams.values():
for start, end in positions:
ngram_positions.append((start, end))
ngram_positions.sort()
# Find all candidate indices (non-stopwords not in ngrams)
candidate_indices = []
for idx, word in enumerate(original_words):
if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
candidate_indices.append(idx)
# Debug print candidate words
print("Candidate words for masking:")
for idx in candidate_indices:
print(f" Position {idx}: '{original_words[idx]}'")
# ENTROPY SPECIFIC LOGIC:
# Calculate entropy for each candidate word
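        # Each candidate is scored with calculate_word_entropy (the masked LM's predictive
        # entropy at that position); within each region the highest-entropy candidates are
        # masked first, mirroring the 1 / up-to-2 / 1 region quotas of the random strategy.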
selected_indices = []
if candidate_indices:
# Organize candidates by position relative to ngrams
if ngram_positions:
# Group candidates by position
before_first = []
between_ngrams = {}
after_last = []
for idx in candidate_indices:
if idx < ngram_positions[0][0]:
before_first.append(idx)
elif idx > ngram_positions[-1][1]:
after_last.append(idx)
else:
# Find which ngram gap this belongs to
for i in range(len(ngram_positions) - 1):
if ngram_positions[i][1] < idx < ngram_positions[i+1][0]:
if i not in between_ngrams:
between_ngrams[i] = []
between_ngrams[i].append(idx)
                # Before first ngram: select the single highest-entropy word
if before_first:
entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in before_first]
entropies.sort(key=lambda x: x[1], reverse=True) # Sort by entropy (highest first)
num_to_select = min(1, len(entropies)) # Select 1 word
selected_indices.extend([idx for idx, _ in entropies[:num_to_select]])
# For each gap between ngrams: select 1-2 highest entropy words
for group, indices in between_ngrams.items():
if indices:
entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in indices]
entropies.sort(key=lambda x: x[1], reverse=True) # Sort by entropy (highest first)
num_to_select = min(2, len(entropies)) # Select between 1-2 words
selected_indices.extend([idx for idx, _ in entropies[:num_to_select]])
                # After last ngram: select the single highest-entropy word
if after_last:
entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in after_last]
entropies.sort(key=lambda x: x[1], reverse=True) # Sort by entropy (highest first)
num_to_select = min(1, len(entropies)) # Select 1 word
selected_indices.extend([idx for idx, _ in entropies[:num_to_select]])
else:
# If no ngrams, calculate entropy for all candidates
entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in candidate_indices]
# Sort by entropy (highest first)
entropies.sort(key=lambda x: x[1], reverse=True)
# Take top 6 or all if fewer
selected_indices = [idx for idx, _ in entropies[:min(6, len(entropies))]]
masked_words = original_words.copy()
for idx in selected_indices:
masked_words[idx] = self.tokenizer.mask_token
if has_punctuation:
masked_words.append(punctuation)
# Debug prints
print("Original sentence:", sentence)
print("Common ngrams:", common_ngrams)
print("Common ngram positions:", ngram_positions)
print("Candidate indices for masking:", candidate_indices)
print("Selected for masking:", selected_indices)
print("Masked sentence:", " ".join(masked_words))
return " ".join(masked_words), selected_indices
def calculate_mask_logits(self, original_sentence, original_mask_indices):
"""Calculate logits for masked positions."""
logger.info(f"Calculating mask logits for sentence: {original_sentence}")
words = original_sentence.split()
mask_logits = {}
for idx in original_mask_indices:
masked_words = words.copy()
masked_words[idx] = self.tokenizer.mask_token
masked_sentence = " ".join(masked_words)
input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
with torch.no_grad():
outputs = self.model(input_ids)
logits = outputs.logits
mask_logits_tensor = logits[0, mask_token_index, :]
top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)
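            # Over-fetch 100 candidates, then drop WordPiece continuation pieces ('##...') and
            # duplicate surface forms, keeping at most 50 whole-word candidates with raw logits.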
top_tokens = []
top_logits = []
seen_words = set()
for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
token = self.tokenizer.convert_ids_to_tokens(token_id.item())
if token.startswith('##'):
continue
word = self.tokenizer.convert_tokens_to_string([token]).strip()
if word and word not in seen_words:
seen_words.add(word)
top_tokens.append(word)
top_logits.append(logit.item())
if len(top_tokens) == 50:
break
mask_logits[idx] = {
"tokens": top_tokens,
"logits": top_logits
}
logger.info("Completed calculating mask logits.")
return mask_logits
def calculate_word_entropy(self, sentence, word_position):
"""Calculate entropy for a word at a specific position."""
logger.info(f"Calculating word entropy for position {word_position} in sentence: {sentence}")
words = sentence.split()
masked_words = words.copy()
masked_words[word_position] = self.tokenizer.mask_token
masked_sentence = " ".join(masked_words)
input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
with torch.no_grad():
outputs = self.model(input_ids)
logits = outputs.logits
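        # Shannon entropy of the MLM distribution at the mask position, H = -sum(p * log p);
        # the 1e-9 term guards against log(0).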
probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
entropy = -torch.sum(probs * torch.log(probs + 1e-9))
logger.info(f"Computed entropy: {entropy.item()}")
return entropy.item()
def process_sentences(self, sentences_list, common_grams, method="random"):
"""Process multiple sentences with the specified masking method."""
tqdm.write(f"[MaskingProcessor] Processing sentences using method: {method}")
results = {}
for sentence in tqdm(sentences_list, desc="Masking Sentences"):
try:
ngrams = common_grams.get(sentence, {})
if method == "random":
masked_sentence, original_mask_indices = self.mask_sentence_random(sentence, ngrams)
elif method == "pseudorandom":
masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(sentence, ngrams)
else: # entropy
masked_sentence, original_mask_indices = self.mask_sentence_entropy(sentence, ngrams)
# Skip if no masks were applied
if not original_mask_indices:
tqdm.write(f"[WARNING] No mask indices found for sentence with method {method}: {sentence[:50]}...")
# Create a fallback masked sentence with at least one mask
fallback_result = self.create_fallback_mask(sentence, ngrams)
if fallback_result:
masked_sentence, original_mask_indices = fallback_result
tqdm.write(f"[INFO] Created fallback mask for sentence")
else:
tqdm.write(f"[WARNING] Could not create fallback mask, skipping sentence")
continue
logits = self.calculate_mask_logits(sentence, original_mask_indices)
results[sentence] = {
"masked_sentence": masked_sentence,
"mask_logits": logits
}
logger.info(f"Processed sentence: {sentence}")
except Exception as e:
tqdm.write(f"[ERROR] Failed to process sentence with method {method}: {e}")
tqdm.write(f"Sentence: {sentence[:100]}...")
                tqdm.write(traceback.format_exc())
tqdm.write("[MaskingProcessor] Completed processing sentences.")
return results
@staticmethod
def identify_common_ngrams(sentences, entities):
"""Enhanced to handle possessive forms better"""
common_grams = {}
# Pre-process entities to handle variations
processed_entities = []
for entity in entities:
processed_entities.append(entity)
            # Add a possessive variant unless the entity already ends in "s" or "'s"
            if not entity.endswith("'s") and not entity.endswith("s"):
processed_entities.append(f"{entity}'s")
for sentence in sentences:
words = sentence.split()
common_grams[sentence] = {}
# Look for each entity in the sentence
for entity in processed_entities:
entity_words = entity.split()
entity_len = len(entity_words)
# Convert entity words for matching
clean_entity_words = [clean_word(word) for word in entity_words]
# Find all occurrences
for i in range(len(words) - entity_len + 1):
is_match = True
for j, entity_word in enumerate(clean_entity_words):
if clean_word(words[i + j]) != entity_word:
is_match = False
break
if is_match:
# Use canonical form from entity list for consistency
base_entity = entity
if entity.endswith("'s") and any(e == entity[:-2] for e in processed_entities):
base_entity = entity[:-2]
if base_entity not in common_grams[sentence]:
common_grams[sentence][base_entity] = []
common_grams[sentence][base_entity].append((i, i + entity_len - 1))
return common_grams
if __name__ == "__main__":
    # Example test sentence
    test_sentence = "Kevin De Bruyne scored for Manchester City as they won the 2019-20 Premier League title."
    # Entities to preserve
    entities = ["Kevin De Bruyne", "Manchester City", "Premier League"]
# Identify common n-grams
common_grams = MaskingProcessor.identify_common_ngrams([test_sentence], entities)
# Print detected n-grams
print(f"Detected common n-grams: {common_grams[test_sentence]}")
# Initialize the processor
processor = MaskingProcessor(
BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking"),
BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
)
# Test all three masking methods
print("\nTesting Random Masking:")
masked_random, indices_random = processor.mask_sentence_random(test_sentence, common_grams[test_sentence])
print("\nTesting Pseudorandom Masking:")
masked_pseudorandom, indices_pseudorandom = processor.mask_sentence_pseudorandom(test_sentence, common_grams[test_sentence])
print("\nTesting Entropy Masking:")
masked_entropy, indices_entropy = processor.mask_sentence_entropy(test_sentence, common_grams[test_sentence])
# Print results
print("\nResults:")
print(f"Original: {test_sentence}")
print(f"Random Masked: {masked_random}")
print(f"Pseudorandom Masked: {masked_pseudorandom}")
print(f"Entropy Masked: {masked_entropy}")