from transformers import BertTokenizer, BertForMaskedLM
import torch
import random
from masking_methods import MaskingProcessor

class SamplingProcessorWithModel:
    def __init__(self, model_name='bert-base-uncased'):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForMaskedLM.from_pretrained(model_name)
        self.model.eval()  # Set the model to evaluation mode
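        # Note: from_pretrained downloads and caches the checkpoint on first
        # use; the model stays on CPU here unless moved to another device.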
    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): Sampling technique to use
                ("inverse_transform", "exponential_minimum", "temperature", "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.
        """
        input_ids = self.tokenizer.encode(masked_sentence, return_tensors="pt")
        while self.tokenizer.mask_token_id in input_ids[0]:
            # Find indices of all [MASK] tokens
            mask_indices = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
            # Process the first [MASK] token in the sequence
            mask_index = mask_indices[0].item()
            # Get logits from the model
            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits
            # Extract logits for the [MASK] token
            mask_logits = logits[0, mask_index]
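            # mask_logits has shape [vocab_size]: one score per candidate
            # token for this [MASK] position.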
            if sampling_technique == "inverse_transform":
                probs = torch.softmax(mask_logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
elif sampling_technique == "exponential_minimum": | |
probs = torch.softmax(mask_logits / temperature, dim=-1) | |
exp_probs = torch.exp(-torch.log(probs)) | |
random_probs = torch.rand_like(exp_probs) | |
sampled_index = torch.argmax(random_probs * exp_probs).item() | |
elif sampling_technique == "temperature": | |
mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8) | |
probs = torch.softmax(mask_logits / temperature, dim=-1) | |
if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)): | |
raise ValueError("The computed probabilities contain NaN or inf values.") | |
probs = torch.max(probs, torch.tensor(1e-8, device=mask_logits.device)) | |
probs = probs / torch.sum(probs) | |
probs = probs.flatten() | |
if probs.size(0) > 1: | |
sampled_index = torch.multinomial(probs, 1).item() | |
else: | |
sampled_index = torch.argmax(probs).item() | |
            elif sampling_technique == 'greedy':
                sampled_index = torch.argmax(mask_logits).item()
            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")
            # Replace the first [MASK] with the selected token
            input_ids[0, mask_index] = sampled_index
        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
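
# A minimal, self-contained sketch of the four sampling strategies on a toy
# logits vector (illustrative only; the tensor values and the helper name
# `demo_sampling` are hypothetical and not part of the class above).
def demo_sampling(technique, temperature=1.0):
    logits = torch.tensor([2.0, 1.0, 0.1])  # toy scores for three candidate tokens
    probs = torch.softmax(logits / temperature, dim=-1)
    if technique == "inverse_transform":
        # Walk the CDF until it first exceeds a uniform draw
        cumulative = torch.cumsum(probs, dim=-1)
        return torch.where(cumulative >= random.random())[0][0].item()
    elif technique == "exponential_minimum":
        # argmin(E_i / p_i) with E_i ~ Exp(1) samples i with probability p_i
        exp_samples = -torch.log(torch.rand_like(probs))
        return torch.argmin(exp_samples / probs).item()
    elif technique == "temperature":
        return torch.multinomial(probs, 1).item()
    else:  # greedy
        return torch.argmax(logits).item()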

# Example usage
if __name__ == "__main__":
    # Define sentences and result_dict
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)

    # Use SamplingProcessor
    sampling_processor = SamplingProcessorWithModel()
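    # For reproducible runs, both RNGs could be seeded here, e.g. with
    # random.seed(0) and torch.manual_seed(0) (optional; as written, the
    # example samples nondeterministically).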
    # Iterate through masking results to apply sampling
    for sentence, result in masking_results.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {result['masked_sentence']}")
        masked_sentence = result["masked_sentence"]

        # Apply different sampling techniques
        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
            print(f"Sampling Technique: {technique}")
            filled_sentence = sampling_processor.fill_masked_sentence(
                masked_sentence=masked_sentence,
                sampling_technique=technique,
                temperature=1.0  # Adjust temperature as needed
            )
            print(f"Filled Sentence: {filled_sentence}\n")
        print('--------------------------------')
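
# The commented-out block below is an earlier, pipeline-based variant of the
# same sampler, kept for reference.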
# from transformers import pipeline
# import torch
# import random
# from masking_methods import MaskingProcessor

# class SamplingProcessorWithPipeline:
#     def __init__(self, model_name='bert-base-uncased'):
#         self.unmasker = pipeline('fill-mask', model=model_name)
#         self.tokenizer = self.unmasker.tokenizer

#     def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
#         """
#         Fills each mask in the masked sentence using the specified sampling technique.

#         Args:
#             masked_sentence (str): Sentence with [MASK] tokens.
#             sampling_technique (str): Sampling technique to use
#                 ("inverse_transform", "exponential_minimum", "temperature", "greedy").
#             temperature (float): Temperature parameter for sampling methods.

#         Returns:
#             str: Sentence with the masks filled.
#         """
#         while '[MASK]' in masked_sentence:
#             # Get predictions for the first [MASK]
#             predictions = self.unmasker(masked_sentence)
#             print(f' predictions : {predictions}')
#             print(f' type of predictions : {type(predictions)}')
#             # Ensure predictions is a list of dictionaries for the first [MASK]
#             if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
#                 raise ValueError("Unexpected structure in predictions from the pipeline.")
#             # Extract logits (scores) from the predictions
#             logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)
#             if sampling_technique == "inverse_transform":
#                 probs = torch.softmax(logits / temperature, dim=-1)
#                 cumulative_probs = torch.cumsum(probs, dim=-1)
#                 random_prob = random.random()
#                 sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
#             elif sampling_technique == "exponential_minimum":
#                 probs = torch.softmax(logits / temperature, dim=-1)
#                 exp_probs = torch.exp(-torch.log(probs))
#                 random_probs = torch.rand_like(exp_probs)
#                 sampled_index = torch.argmax(random_probs * exp_probs).item()
#             elif sampling_technique == "temperature":
#                 logits = torch.clamp(logits, min=-1e8, max=1e8)
#                 probs = torch.softmax(logits / temperature, dim=-1)
#                 if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
#                     raise ValueError("The computed probabilities contain NaN or inf values.")
#                 probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
#                 probs = probs / torch.sum(probs)
#                 probs = probs.flatten()
#                 if probs.size(0) > 1:
#                     sampled_index = torch.multinomial(probs, 1).item()
#                 else:
#                     sampled_index = torch.argmax(probs).item()
#             elif sampling_technique == 'greedy':
#                 sampled_index = torch.argmax(logits).item()
#             else:
#                 raise ValueError(f"Unknown sampling technique: {sampling_technique}")
#             # Replace the first [MASK] with the selected word
#             sampled_token = predictions[sampled_index]['token_str']
#             masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)
#         return masked_sentence

# # Example usage
# if __name__ == "__main__":
#     # Define sentences and result_dict
#     sentences = [
#         "The quick brown fox jumps over the lazy dog.",
#         "A quick brown dog outpaces a lazy fox.",
#         "Quick brown animals leap over lazy obstacles."
#     ]
#     result_dict = {
#         "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
#         "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
#         "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
#     }
#     masking_processor = MaskingProcessor()
#     masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)

#     # Use SamplingProcessor
#     sampling_processor = SamplingProcessorWithPipeline()

#     # Iterate through masking results to apply sampling
#     for sentence, result in masking_results.items():
#         print(f"Original Sentence (Random): {sentence}")
#         print(f"Masked Sentence (Random): {result['masked_sentence']}")
#         masked_sentence = result["masked_sentence"]

#         # Apply different sampling techniques
#         for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
#             print(f"Sampling Technique: {technique}")
#             filled_sentence = sampling_processor.fill_masked_sentence(
#                 masked_sentence=masked_sentence,
#                 sampling_technique=technique,
#                 temperature=1.0  # Adjust temperature as needed
#             )
#             print(f"Filled Sentence: {filled_sentence}\n")
#         print('--------------------------------')