# muppit/guidance_eval/amazon_polarity_eval.py
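"""Evaluates guided samples on the Amazon Polarity sentiment task.

Samples from a diffusion checkpoint, then reports n-gram diversity,
sentiment-classifier accuracy, and generative perplexity, writing the
samples to JSON and the metrics to CSV.
"""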
import collections
import json
import os

import hydra
import lightning as L
import omegaconf
import pandas as pd
import rdkit
import rich.syntax
import rich.tree
import spacy
import torch
import transformers
# from evaluate import load  # needed only by the commented-out MAUVE eval
from nltk.util import ngrams
from tqdm.auto import tqdm

import dataloader
import diffusion
import eval_utils

rdkit.rdBase.DisableLog('rdApp.error')
omegaconf.OmegaConf.register_new_resolver(
'cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver(
'device_count', torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver(
'eval', eval)
omegaconf.OmegaConf.register_new_resolver(
'div_up', lambda x, y: (x + y - 1) // y)
omegaconf.OmegaConf.register_new_resolver(
'if_then_else',
lambda condition, x, y: x if condition else y
)
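
# Example of how the resolvers above expand when a config is resolved
# (a minimal sketch; `total` and `shards` are hypothetical keys, not taken
# from this repo's configs):
#
#   cfg = omegaconf.OmegaConf.create(
#       {'total': 10, 'shards': '${div_up:${total},3}'})
#   assert cfg.shards == 4  # (10 + 3 - 1) // 3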


def _print_config(
    config: omegaconf.DictConfig,
    resolve: bool = True) -> None:
  """Prints the content of a DictConfig as a tree using the Rich library.

  Args:
    config (DictConfig): Configuration composed by Hydra.
    resolve (bool): Whether to resolve reference fields of the DictConfig.
  """
style = 'dim'
tree = rich.tree.Tree('CONFIG', style=style,
guide_style=style)
fields = config.keys()
for field in fields:
branch = tree.add(field, style=style, guide_style=style)
config_section = config.get(field)
branch_content = str(config_section)
if isinstance(config_section, omegaconf.DictConfig):
branch_content = omegaconf.OmegaConf.to_yaml(
config_section, resolve=resolve)
branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
rich.print(tree)


def compute_diversity(sentences):
  """Computes n-gram diversity: the product over n in {2, 3, 4} of the
  fraction of distinct n-grams across all tokenized sentences."""
  ngram_range = [2, 3, 4]
  tokenizer = spacy.load("en_core_web_sm").tokenizer
token_list = []
for sentence in sentences:
token_list.append(
[str(token) for token in tokenizer(sentence)])
ngram_sets = {}
ngram_counts = collections.defaultdict(int)
n_gram_repetition = {}
for n in ngram_range:
ngram_sets[n] = set()
for tokens in token_list:
ngram_sets[n].update(ngrams(tokens, n))
ngram_counts[n] += len(list(ngrams(tokens, n)))
n_gram_repetition[f"{n}gram_repetition"] = (
1 - len(ngram_sets[n]) / ngram_counts[n])
diversity = 1
for val in n_gram_repetition.values():
diversity *= (1 - val)
return diversity
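
# Example (hypothetical inputs): with no repeated n-grams every repetition
# term is 0 and the score is 1.0; heavy repetition drives it toward 0.
#
#   compute_diversity(['great product, works well',
#                      'terrible battery, would not buy'])  # -> 1.0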


def compute_sentiment_classifier_score(sentences, eval_model_name_or_path):
  """Returns the fraction of `sentences` that the sentiment classifier
  labels as positive (assumes class index 1 is the positive label)."""
  tokenizer = transformers.AutoTokenizer.from_pretrained(
      eval_model_name_or_path)
  eval_model = transformers.AutoModelForSequenceClassification.from_pretrained(
      eval_model_name_or_path).to('cuda')
  eval_model.eval()
total_pos = 0
total_neg = 0
pbar = tqdm(sentences, desc='Classifier eval')
for sen in pbar:
# Tokenize the input text
inputs = tokenizer(
sen,
return_tensors="pt",
truncation=True,
padding=True).to('cuda')
# Get the model predictions
with torch.no_grad():
outputs = eval_model(**inputs)
# Convert logits to probabilities
probs = torch.nn.functional.softmax(
outputs.logits, dim=-1)
# Get the predicted class
predicted_class = torch.argmax(probs, dim=1).item()
if predicted_class == 1:
total_pos += 1
else:
total_neg += 1
pbar.set_postfix(accuracy=total_pos / (total_pos + total_neg))
return total_pos / (total_pos + total_neg)
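
# A batched variant (a sketch, not wired in above; `batch_size` is a
# hypothetical parameter) would amortize the per-sentence call overhead:
#
#   def _classify_batched(sentences, tokenizer, eval_model, batch_size=32):
#     n_pos = 0
#     for i in range(0, len(sentences), batch_size):
#       inputs = tokenizer(sentences[i:i + batch_size], return_tensors='pt',
#                          truncation=True, padding=True).to('cuda')
#       with torch.no_grad():
#         logits = eval_model(**inputs).logits
#       n_pos += (logits.argmax(dim=-1) == 1).sum().item()
#     return n_pos / len(sentences)
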
# def compute_mauve(config, tokenizer, sentences):
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# # compute mauve
# torch.cuda.empty_cache()
# mauve = load("mauve")
# human_references = []
#
# valid_loader = dataloader.get_dataloaders(
# config, tokenizer, valid_seed=config.seed)
#
# # construct reference
#   valid_iter = iter(valid_loader)
#   for batch_id in range(config.sampling.num_sample_batches):
#     batch = next(valid_iter)  # advance one batch per iteration
# input_ids = batch['input_ids']
# for i in range(config.sampling.batch_size):
# idx = (
# input_ids[i] == tokenizer.eos_token_id).nonzero(
# as_tuple=True)
# if idx[0].numel() > 0:
# idx = idx[0][0].item()
# input_ids[i, (idx + 1):] = 0
# human_references.extend(
# tokenizer.batch_decode(
# input_ids, skip_special_tokens=True))
#
# assert len(sentences) == len(human_references)
#
# results = mauve.compute(predictions=sentences,
# references=human_references,
# featurize_model_name=config.data.mauve_model,
# max_text_length=256, device_id=0)
# return results.mauve


@hydra.main(version_base=None, config_path='../configs',
            config_name='config')
def main(config: omegaconf.DictConfig) -> None:
# Reproducibility
L.seed_everything(config.seed)
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
_print_config(config, resolve=True)
print(f"Checkpoint: {config.eval.checkpoint_path}")
tokenizer = dataloader.get_tokenizer(config)
pretrained = diffusion.Diffusion.load_from_checkpoint(
config.eval.checkpoint_path,
tokenizer=tokenizer,
config=config, logger=False)
pretrained.eval()
result_dicts = []
samples = []
for _ in tqdm(
range(config.sampling.num_sample_batches),
desc='Gen. batches', leave=False):
sample = pretrained.sample()
samples.extend(
pretrained.tokenizer.batch_decode(sample))
  # Strip tokenizer special tokens from the decoded samples.
  samples = [
      s.replace('[CLS]', '').replace('[SEP]', '')
      .replace('[PAD]', '').replace('[MASK]', '').strip()
      for s in samples
  ]
del pretrained # free up space for eval
diversity_score = compute_diversity(samples)
classifier_accuracy = compute_sentiment_classifier_score(
samples, eval_model_name_or_path=config.eval.classifier_model_name_or_path)
generative_ppl = eval_utils.compute_generative_ppl(
samples,
eval_model_name_or_path=config.eval.generative_ppl_model_name_or_path,
gen_ppl_eval_batch_size=8,
max_length=config.model.length)
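  # Merge the guidance hyper-parameters from the config into the results
  # row; the dict-union operator `|` requires Python >= 3.9.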
result_dicts.append({
'Seed': config.seed,
'T': config.sampling.steps,
'Num Samples': config.sampling.batch_size * config.sampling.num_sample_batches,
'Diversity': diversity_score,
'Accuracy': classifier_accuracy,
'Gen. PPL': generative_ppl,
} | {k.capitalize(): v for k, v in config.guidance.items()})
print("Guidance:", ", ".join([f"{k.capitalize()} - {v}" for k, v in config.guidance.items()]))
print(f"\tDiversity: {diversity_score:0.3f} ",
f"Accuracy: {classifier_accuracy:0.3f} ",
f"Gen. PPL: {generative_ppl:0.3f}")
print(f"Generated {len(samples)} sentences.")
with open(config.eval.generated_samples_path, 'w') as f:
json.dump(
{
'generated_seqs': samples,
},
f, indent=4) # type: ignore
results_df = pd.DataFrame.from_records(result_dicts)
results_df.to_csv(config.eval.results_csv_path)
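
# Example invocation via Hydra overrides (a sketch; these keys are the ones
# the script reads, but see ../configs for the full schema):
#
#   python amazon_polarity_eval.py \
#     eval.checkpoint_path=/path/to/model.ckpt \
#     sampling.batch_size=64 sampling.num_sample_batches=4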


if __name__ == '__main__':
  main()