import random
import torch
import logging
import string
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk
from tqdm import tqdm

# Set logging to WARNING for a cleaner terminal.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


def clean_word(word):
    """More robust cleaning for consistent matching."""
    # Remove possessive 's before stripping other punctuation
    if word.lower().endswith("'s"):
        word = word[:-2]
    return word.lower().strip().translate(str.maketrans('', '', string.punctuation))


class MaskingProcessor:
    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.stop_words = set(stopwords.words('english'))
        tqdm.write(f"[MaskingProcessor] Initialized on device: {self.device}")

    def is_stopword(self, word):
        """Check if a word is a stopword, handling punctuation and case."""
        return clean_word(word) in self.stop_words

    def verify_and_correct_ngram_positions(self, sentence, common_ngrams):
        """Verify that n-gram positions match actual words in the sentence and correct them if needed."""
        words = sentence.split()
        corrected_ngrams = {}

        for ngram, positions in common_ngrams.items():
            corrected_positions = []
            ngram_words = ngram.split()
            # Convert n-gram words to a cleaned form for matching
            clean_ngram_words = [clean_word(word) for word in ngram_words]

            # Scan the sentence for exact (cleaned) occurrences of the n-gram
            for i in range(len(words) - len(ngram_words) + 1):
                is_match = True
                for j, ngram_word in enumerate(clean_ngram_words):
                    if clean_word(words[i + j]) != ngram_word:
                        is_match = False
                        break
                if is_match:
                    corrected_positions.append((i, i + len(ngram_words) - 1))

            if corrected_positions:
                corrected_ngrams[ngram] = corrected_positions
            else:
                # Log the issue and fall back to a more flexible prefix search
                print(f"Warning: Could not find exact match for '{ngram}' in the sentence.")
                print("Attempting flexible matching...")
                for i in range(len(words)):
                    if clean_word(words[i]) == clean_ngram_words[0]:
                        # Found the first word of the n-gram
                        if len(ngram_words) == 1 or (
                            i + len(ngram_words) <= len(words) and
                            all(clean_word(words[i + j]).startswith(clean_ngram_words[j])
                                for j in range(len(ngram_words)))
                        ):
                            corrected_positions.append((i, i + len(ngram_words) - 1))

                if corrected_positions:
                    print(f"Found flexible matches for '{ngram}': {corrected_positions}")
                    corrected_ngrams[ngram] = corrected_positions
                else:
                    # If still no match, keep the original positions as a fallback
                    print(f"No matches found for '{ngram}'. Keeping original positions.")
                    corrected_ngrams[ngram] = positions

        # Log changes
        if corrected_ngrams != common_ngrams:
            print(f"Original ngram positions: {common_ngrams}")
            print(f"Corrected ngram positions: {corrected_ngrams}")

        return corrected_ngrams
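
    # Illustrative sketch (values are assumed examples, not recorded output):
    # clean_word normalises case, possessives and punctuation, which is what
    # lets the position check above match "Manchester City" against
    # "Manchester City's" or "City.":
    #   clean_word("City's")  -> "city"
    #   clean_word("title.")  -> "title"
    #   clean_word("The")     -> "the"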

    def in_any_ngram(self, idx, ngram_positions):
        """Check if an original sentence index is part of any n-gram span."""
        return any(start <= idx <= end for start, end in ngram_positions)

    def create_fallback_mask(self, sentence, ngrams):
        """Create a fallback mask when the normal strategies fail."""
        try:
            words = sentence.split()
            if not words:
                return None

            # Collect all n-gram spans
            ngram_positions = []
            for positions in ngrams.values():
                for start, end in positions:
                    ngram_positions.append((start, end))
            ngram_positions.sort()

            # Find the first non-stopword that is not inside an n-gram
            for idx, word in enumerate(words):
                if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
                    masked_words = words.copy()
                    masked_words[idx] = self.tokenizer.mask_token
                    tqdm.write(f"[INFO] Fallback mask created at position {idx}: '{word}'")
                    return " ".join(masked_words), [idx]

            # If no eligible word was found, just mask the first non-stopword
            for idx, word in enumerate(words):
                if not self.is_stopword(word):
                    masked_words = words.copy()
                    masked_words[idx] = self.tokenizer.mask_token
                    tqdm.write(f"[INFO] Last resort fallback mask created at position {idx}: '{word}'")
                    return " ".join(masked_words), [idx]

            # If still nothing, mask the first word
            if words:
                masked_words = words.copy()
                masked_words[0] = self.tokenizer.mask_token
                return " ".join(masked_words), [0]

            return None
        except Exception as e:
            tqdm.write(f"[ERROR] Error creating fallback mask: {e}")
            return None

    def mask_sentence_random(self, sentence, common_ngrams):
        """Mask random non-stopwords that are not part of common n-grams, with controlled positioning."""
        common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
        tqdm.write(f"[MaskingProcessor] Masking (random) sentence: {sentence}")
        original_words = sentence.split()

        # Handle trailing punctuation on the last word
        has_punctuation = False
        punctuation = ''
        if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words[-1] = original_words[-1][:-1]
            if not original_words[-1]:  # If the word was just punctuation
                original_words.pop()

        # Get flattened n-gram positions
        ngram_positions = []
        for positions in common_ngrams.values():
            for start, end in positions:
                ngram_positions.append((start, end))
        ngram_positions.sort()

        # Find all candidate indices (non-stopwords not in n-grams)
        candidate_indices = []
        for idx, word in enumerate(original_words):
            if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
                candidate_indices.append(idx)

        # Debug: print candidate words
        print("Candidate words for masking:")
        for idx in candidate_indices:
            print(f"  Position {idx}: '{original_words[idx]}'")

        selected_indices = []
        if ngram_positions:
            # Before the first n-gram: select up to 1 word
            before_first = [idx for idx in candidate_indices if idx < ngram_positions[0][0]]
            if before_first:
                num_to_select = min(1, len(before_first))
                if num_to_select > 0:
                    selected = random.sample(before_first, num_to_select)
                    selected_indices.extend(selected)

            # Between consecutive n-grams: select up to 2 words per gap
            for i in range(len(ngram_positions) - 1):
                between = [idx for idx in candidate_indices
                           if ngram_positions[i][1] < idx < ngram_positions[i + 1][0]]
                if between:
                    num_to_select = min(2, len(between))
                    if num_to_select > 0:
                        selected = random.sample(between, num_to_select)
                        selected_indices.extend(selected)

            # After the last n-gram: select up to 1 word
            after_last = [idx for idx in candidate_indices if idx > ngram_positions[-1][1]]
            if after_last:
                num_to_select = min(1, len(after_last))
                if num_to_select > 0:
                    selected = random.sample(after_last, num_to_select)
                    selected_indices.extend(selected)
        else:
            # If there are no n-grams, pick up to 6 random candidates
            if candidate_indices:
                selected_indices = random.sample(candidate_indices, min(6, len(candidate_indices)))

        masked_words = original_words.copy()
        for idx in selected_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if has_punctuation:
            masked_words.append(punctuation)

        # Debug prints
        print("Original sentence:", sentence)
        print("Common ngrams:", common_ngrams)
        print("Common ngram positions:", ngram_positions)
        print("Candidate indices for masking:", candidate_indices)
        print("Selected for masking:", selected_indices)
        print("Masked sentence:", " ".join(masked_words))

        return " ".join(masked_words), selected_indices
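
    # Positional selection policy used by mask_sentence_random above (and
    # mirrored by the pseudorandom strategy below) -- a summary of the code,
    # not additional behaviour: at most 1 mask before the first common n-gram,
    # at most 2 masks in each gap between consecutive n-grams, at most 1 mask
    # after the last n-gram, and up to 6 masks when the sentence has no common
    # n-grams at all.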

    def mask_sentence_pseudorandom(self, sentence, common_ngrams):
        """Mask specific non-stopwords based on their position relative to n-grams."""
        common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
        tqdm.write(f"[MaskingProcessor] Masking (pseudorandom) sentence: {sentence}")
        random.seed(3)  # Fixed seed for pseudorandom (reproducible) behavior
        original_words = sentence.split()

        # Handle trailing punctuation on the last word
        has_punctuation = False
        punctuation = ''
        if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words[-1] = original_words[-1][:-1]
            if not original_words[-1]:  # If the word was just punctuation
                original_words.pop()

        # Get flattened n-gram positions
        ngram_positions = []
        for positions in common_ngrams.values():
            for start, end in positions:
                ngram_positions.append((start, end))
        ngram_positions.sort()

        # Find all candidate indices (non-stopwords not in n-grams)
        candidate_indices = []
        for idx, word in enumerate(original_words):
            if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
                candidate_indices.append(idx)

        # Debug: print candidate words
        print("Candidate words for masking:")
        for idx in candidate_indices:
            print(f"  Position {idx}: '{original_words[idx]}'")

        # Pseudorandom-specific logic: same positional policy as the random
        # strategy, but the fixed seed makes the selection deterministic.
        selected_indices = []
        if ngram_positions:
            # Before the first n-gram: select up to 1 word
            before_first = [idx for idx in candidate_indices if idx < ngram_positions[0][0]]
            if before_first:
                num_to_select = min(1, len(before_first))
                if num_to_select > 0:
                    selected = random.sample(before_first, num_to_select)
                    selected_indices.extend(selected)

            # Between consecutive n-grams: select up to 2 words per gap
            for i in range(len(ngram_positions) - 1):
                between = [idx for idx in candidate_indices
                           if ngram_positions[i][1] < idx < ngram_positions[i + 1][0]]
                if between:
                    num_to_select = min(2, len(between))
                    if num_to_select > 0:
                        selected = random.sample(between, num_to_select)
                        selected_indices.extend(selected)

            # After the last n-gram: select up to 1 word
            after_last = [idx for idx in candidate_indices if idx > ngram_positions[-1][1]]
            if after_last:
                num_to_select = min(1, len(after_last))
                if num_to_select > 0:
                    selected = random.sample(after_last, num_to_select)
                    selected_indices.extend(selected)
        else:
            # If there are no n-grams, pick up to 6 random candidates
            if candidate_indices:
                selected_indices = random.sample(candidate_indices, min(6, len(candidate_indices)))

        masked_words = original_words.copy()
        for idx in selected_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if has_punctuation:
            masked_words.append(punctuation)

        # Debug prints
        print("Original sentence:", sentence)
        print("Common ngrams:", common_ngrams)
        print("Common ngram positions:", ngram_positions)
        print("Candidate indices for masking:", candidate_indices)
        print("Selected for masking:", selected_indices)
        print("Masked sentence:", " ".join(masked_words))

        return " ".join(masked_words), selected_indices

    def mask_sentence_entropy(self, sentence, common_ngrams):
        """Mask the words with the highest entropy that are not part of common n-grams."""
        common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
        tqdm.write(f"[MaskingProcessor] Masking (entropy) sentence: {sentence}")
        original_words = sentence.split()

        # Handle trailing punctuation on the last word
        has_punctuation = False
        punctuation = ''
        if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words[-1] = original_words[-1][:-1]
            if not original_words[-1]:  # If the word was just punctuation
                original_words.pop()

        # Get flattened n-gram positions
        ngram_positions = []
        for positions in common_ngrams.values():
            for start, end in positions:
                ngram_positions.append((start, end))
        ngram_positions.sort()

        # Find all candidate indices (non-stopwords not in n-grams)
        candidate_indices = []
        for idx, word in enumerate(original_words):
            if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
                candidate_indices.append(idx)

        # Debug: print candidate words
        print("Candidate words for masking:")
        for idx in candidate_indices:
            print(f"  Position {idx}: '{original_words[idx]}'")

        # Entropy-specific logic: within each positional group, mask the
        # candidates with the highest masked-LM entropy.
        selected_indices = []
        if candidate_indices:
            if ngram_positions:
                # Group candidates by position relative to the n-grams
                before_first = []
                between_ngrams = {}
                after_last = []
                for idx in candidate_indices:
                    if idx < ngram_positions[0][0]:
                        before_first.append(idx)
                    elif idx > ngram_positions[-1][1]:
                        after_last.append(idx)
                    else:
                        # Find which n-gram gap this index belongs to
                        for i in range(len(ngram_positions) - 1):
                            if ngram_positions[i][1] < idx < ngram_positions[i + 1][0]:
                                if i not in between_ngrams:
                                    between_ngrams[i] = []
                                between_ngrams[i].append(idx)

                # Before the first n-gram: select the highest-entropy word
                if before_first:
                    entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in before_first]
                    entropies.sort(key=lambda x: x[1], reverse=True)  # Highest entropy first
                    num_to_select = min(1, len(entropies))
                    selected_indices.extend([idx for idx, _ in entropies[:num_to_select]])

                # Each gap between n-grams: select up to 2 highest-entropy words
                for group, indices in between_ngrams.items():
                    if indices:
                        entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in indices]
                        entropies.sort(key=lambda x: x[1], reverse=True)  # Highest entropy first
                        num_to_select = min(2, len(entropies))
                        selected_indices.extend([idx for idx, _ in entropies[:num_to_select]])

                # After the last n-gram: select the highest-entropy word
                if after_last:
                    entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in after_last]
                    entropies.sort(key=lambda x: x[1], reverse=True)  # Highest entropy first
                    num_to_select = min(1, len(entropies))
                    selected_indices.extend([idx for idx, _ in entropies[:num_to_select]])
            else:
                # If there are no n-grams, rank all candidates by entropy
                entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in candidate_indices]
                entropies.sort(key=lambda x: x[1], reverse=True)  # Highest entropy first
                # Take the top 6, or all if there are fewer
                selected_indices = [idx for idx, _ in entropies[:min(6, len(entropies))]]

        masked_words = original_words.copy()
        for idx in selected_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if has_punctuation:
            masked_words.append(punctuation)

        # Debug prints
        print("Original sentence:", sentence)
        print("Common ngrams:", common_ngrams)
        print("Common ngram positions:", ngram_positions)
        print("Candidate indices for masking:", candidate_indices)
        print("Selected for masking:", selected_indices)
        print("Masked sentence:", " ".join(masked_words))

        return " ".join(masked_words), selected_indices
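
    # The entropy strategy ranks each candidate by the Shannon entropy of
    # BERT's masked-LM distribution at that position (see
    # calculate_word_entropy below):
    #   H = -sum_v p(v | context) * log p(v | context)
    # where p is the softmax over the vocabulary for the [MASK] token. A high
    # H means the model is least certain which word belongs at that position.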

    def calculate_mask_logits(self, original_sentence, original_mask_indices):
        """Calculate logits for the masked positions."""
        logger.info(f"Calculating mask logits for sentence: {original_sentence}")
        words = original_sentence.split()
        mask_logits = {}

        for idx in original_mask_indices:
            masked_words = words.copy()
            masked_words[idx] = self.tokenizer.mask_token
            masked_sentence = " ".join(masked_words)

            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            mask_logits_tensor = logits[0, mask_token_index, :]
            top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)

            top_tokens = []
            top_logits = []
            seen_words = set()
            for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
                token = self.tokenizer.convert_ids_to_tokens(token_id.item())
                # Skip subword pieces; keep whole-word candidates only
                if token.startswith('##'):
                    continue
                word = self.tokenizer.convert_tokens_to_string([token]).strip()
                if word and word not in seen_words:
                    seen_words.add(word)
                    top_tokens.append(word)
                    top_logits.append(logit.item())
                if len(top_tokens) == 50:
                    break

            mask_logits[idx] = {
                "tokens": top_tokens,
                "logits": top_logits
            }

        logger.info("Completed calculating mask logits.")
        return mask_logits

    def calculate_word_entropy(self, sentence, word_position):
        """Calculate entropy for the word at a specific position."""
        logger.info(f"Calculating word entropy for position {word_position} in sentence: {sentence}")
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)

        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))
        logger.info(f"Computed entropy: {entropy.item()}")
        return entropy.item()

    def process_sentences(self, sentences_list, common_grams, method="random"):
        """Process multiple sentences with the specified masking method."""
        tqdm.write(f"[MaskingProcessor] Processing sentences using method: {method}")
        results = {}

        for sentence in tqdm(sentences_list, desc="Masking Sentences"):
            try:
                ngrams = common_grams.get(sentence, {})
                if method == "random":
                    masked_sentence, original_mask_indices = self.mask_sentence_random(sentence, ngrams)
                elif method == "pseudorandom":
                    masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(sentence, ngrams)
                else:  # entropy
                    masked_sentence, original_mask_indices = self.mask_sentence_entropy(sentence, ngrams)

                # If no masks were applied, fall back to masking a single word
                if not original_mask_indices:
                    tqdm.write(f"[WARNING] No mask indices found for sentence with method {method}: {sentence[:50]}...")
                    fallback_result = self.create_fallback_mask(sentence, ngrams)
                    if fallback_result:
                        masked_sentence, original_mask_indices = fallback_result
                        tqdm.write("[INFO] Created fallback mask for sentence")
                    else:
                        tqdm.write("[WARNING] Could not create fallback mask, skipping sentence")
                        continue

                logits = self.calculate_mask_logits(sentence, original_mask_indices)
                results[sentence] = {
                    "masked_sentence": masked_sentence,
                    "mask_logits": logits
                }
                logger.info(f"Processed sentence: {sentence}")
            except Exception as e:
                tqdm.write(f"[ERROR] Failed to process sentence with method {method}: {e}")
                tqdm.write(f"Sentence: {sentence[:100]}...")
                import traceback
                tqdm.write(traceback.format_exc())

        tqdm.write("[MaskingProcessor] Completed processing sentences.")
        return results
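
    # Shape of the value returned by process_sentences (illustrative sketch,
    # with made-up tokens and logits):
    #   {
    #       "<original sentence>": {
    #           "masked_sentence": "... [MASK] ...",
    #           "mask_logits": {
    #               7: {"tokens": ["scored", "played", ...],   # up to 50 candidate words
    #                   "logits": [9.1, 8.7, ...]},            # matching raw logits
    #           },
    #       },
    #   }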

    @staticmethod
    def identify_common_ngrams(sentences, entities):
        """Find entity n-grams in each sentence; enhanced to handle possessive forms."""
        common_grams = {}

        # Pre-process entities to handle simple variations
        processed_entities = []
        for entity in entities:
            processed_entities.append(entity)
            # Add the possessive form if it is not already there
            if not entity.endswith("'s") and not entity.endswith("s"):
                processed_entities.append(f"{entity}'s")

        for sentence in sentences:
            words = sentence.split()
            common_grams[sentence] = {}

            # Look for each entity in the sentence
            for entity in processed_entities:
                entity_words = entity.split()
                entity_len = len(entity_words)
                # Convert entity words to a cleaned form for matching
                clean_entity_words = [clean_word(word) for word in entity_words]

                # Find all occurrences
                for i in range(len(words) - entity_len + 1):
                    is_match = True
                    for j, entity_word in enumerate(clean_entity_words):
                        if clean_word(words[i + j]) != entity_word:
                            is_match = False
                            break
                    if is_match:
                        # Use the canonical form from the entity list for consistency
                        base_entity = entity
                        if entity.endswith("'s") and any(e == entity[:-2] for e in processed_entities):
                            base_entity = entity[:-2]
                        if base_entity not in common_grams[sentence]:
                            common_grams[sentence][base_entity] = []
                        span = (i, i + entity_len - 1)
                        # Avoid duplicate spans when both the base and possessive
                        # forms of an entity match the same position
                        if span not in common_grams[sentence][base_entity]:
                            common_grams[sentence][base_entity].append(span)

        return common_grams
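

# Illustrative call to identify_common_ngrams (positions are word indices into
# sentence.split(); the values below are assumed examples, not a recorded run):
#   MaskingProcessor.identify_common_ngrams(
#       ["Kevin De Bruyne scored for Manchester City."],
#       ["Kevin De Bruyne", "Manchester City"],
#   )
#   -> {"Kevin De Bruyne scored for Manchester City.":
#          {"Kevin De Bruyne": [(0, 2)], "Manchester City": [(5, 6)]}}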

if __name__ == "__main__":
    # Example test
    test_sentence = "Kevin De Bruyne scored for Manchester City as they won the 2019-20 Premier League title."
    # Entities to preserve
    entities = ["Kevin De Bruyne", "Manchester City", "Premier League"]

    # Identify common n-grams
    common_grams = MaskingProcessor.identify_common_ngrams([test_sentence], entities)

    # Print detected n-grams
    print(f"Detected common n-grams: {common_grams[test_sentence]}")

    # Initialize the processor
    processor = MaskingProcessor(
        BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking"),
        BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
    )

    # Test all three masking methods
    print("\nTesting Random Masking:")
    masked_random, indices_random = processor.mask_sentence_random(test_sentence, common_grams[test_sentence])

    print("\nTesting Pseudorandom Masking:")
    masked_pseudorandom, indices_pseudorandom = processor.mask_sentence_pseudorandom(test_sentence, common_grams[test_sentence])

    print("\nTesting Entropy Masking:")
    masked_entropy, indices_entropy = processor.mask_sentence_entropy(test_sentence, common_grams[test_sentence])

    # Print results
    print("\nResults:")
    print(f"Original: {test_sentence}")
    print(f"Random Masked: {masked_random}")
    print(f"Pseudorandom Masked: {masked_pseudorandom}")
    print(f"Entropy Masked: {masked_entropy}")
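
    # Optional end-to-end sketch: run the full pipeline with the entropy
    # strategy and inspect the top replacement candidates BERT proposes for
    # each masked position (actual candidates depend on the model, so none
    # are shown here).
    results = processor.process_sentences([test_sentence], common_grams, method="entropy")
    for sentence, result in results.items():
        print(f"\nProcessed masked sentence: {result['masked_sentence']}")
        for position, candidates in result["mask_logits"].items():
            print(f"  Position {position}: top candidates {candidates['tokens'][:5]}")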