# glimpse/rsasumm/rsa_reranker.py
from functools import cache
from typing import List
import numpy as np
import torch
import pandas as pd
from tqdm import tqdm

def kl_divergence(p, q):
    """
    Compute the KL divergence KL(p || q) between two (batched) probability distributions.
    The nan_to_num call maps 0 * log(0) terms to 0.
    """
    return torch.nan_to_num(p * (p / q).log(), nan=0.0).sum(-1)


def jensen_shannon_divergence(p, q):
    """
    Compute the Jensen-Shannon divergence between two (batched) probability distributions.
    """
    m = 0.5 * (p + q)
    return 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))
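

# Minimal usage sketch (illustrative only): the Jensen-Shannon divergence between two toy
# categorical distributions; it is symmetric and bounded above by log(2).
#
#   p = torch.tensor([0.9, 0.1])
#   q = torch.tensor([0.2, 0.8])
#   jensen_shannon_divergence(p, q)  # scalar tensor in [0, log 2]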


class RSAReranking:
    """
    Rerank a list of candidate summaries according to the Rational Speech Act (RSA) model.
    """

def __init__(
self,
model,
tokenizer,
candidates: List[str],
source_texts: List[str],
batch_size: int = 32,
rationality: int = 1,
device="cuda",
):
"""
        :param model: HuggingFace model used to compute the likelihoods (expected to be a seq2seq model); it acts as the literal speaker S0 in the RSA model
        :param tokenizer: tokenizer associated with the model
        :param candidates: list of candidate summaries
:param source_texts: list of source texts
:param batch_size: batch size used to compute the likelihoods (can be high since we don't need gradients and
it's a single forward pass)
:param rationality: rationality parameter of the RSA model
:param device: device used to compute the likelihoods
"""
        self.device = device
        self.model = model.to(self.device)
self.tokenizer = tokenizer
self.candidates = candidates
self.source_texts = source_texts
self.batch_size = batch_size
self.rationality = rationality

    def compute_conditionned_likelihood(
        self, x: List[str], y: List[str], mean: bool = True
    ) -> torch.Tensor:
        """
        Compute the likelihood of y given x.

        :param x: list of source texts, len(x) = batch_size
        :param y: list of candidate summaries, len(y) = batch_size
        :param mean: average the likelihoods over the tokens of y instead of taking the sum
        :return: tensor of shape (batch_size,) containing the likelihoods of y given x
        """
# Ensure x,y are pure Python lists of strings (not pandas.Series, np.ndarray, etc.)
x = [str(item) for item in list(x)]
y = [str(item) for item in list(y)]
assert len(x) == len(y), "x and y must have the same length"
        # Per-token cross-entropy; pad positions are ignored so padding does not pollute the scores
        loss_fn = torch.nn.CrossEntropyLoss(
            reduction="none", ignore_index=self.tokenizer.pad_token_id
        )
        batch_size = len(x)
x = self.tokenizer(
x,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024,
)
y = self.tokenizer(
y,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024,
)
# Move all tensors to the correct device
x = {k: v.to(self.device) for k, v in x.items()}
y = {k: v.to(self.device) for k, v in y.items()}
        # Feed x to the encoder and y to the decoder to score y given x
        x_ids = x["input_ids"]
        y_ids = y["input_ids"]
logits = self.model(
input_ids=x_ids,
decoder_input_ids=y_ids,
attention_mask=x["attention_mask"],
decoder_attention_mask=y["attention_mask"],
).logits
        # Shift logits and targets for next-token prediction
        shifted_logits = logits[..., :-1, :].contiguous()
        shifted_ids = y_ids[..., 1:].contiguous()
        # Negative cross-entropy gives per-token log-likelihoods
        likelihood = -loss_fn(
            shifted_logits.view(-1, shifted_logits.size(-1)), shifted_ids.view(-1)
        )
        # Sum the per-token log-likelihoods of each sequence in the batch
        likelihood = likelihood.view(batch_size, -1).sum(-1)
        if mean:
            # Length-normalise by the number of non-pad tokens in y
            likelihood /= (y_ids != self.tokenizer.pad_token_id).float().sum(-1)
return likelihood
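
    # Usage sketch (illustrative only): x and y must be paired lists of equal length, and the
    # model/tokenizer are assumed to be a seq2seq pair (e.g. BART):
    #
    #   reranker = RSAReranking(model, tokenizer, candidates, source_texts)
    #   scores = reranker.compute_conditionned_likelihood(source_texts[:2], candidates[:2], mean=True)
    #   # scores[i] is the length-normalised log-likelihood of candidates[i] given source_texts[i]
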
def score(self, x: List[str], y: List[str], **kwargs):
return self.compute_conditionned_likelihood(x, y, **kwargs)

    def likelihood_matrix(self) -> torch.Tensor:
        """
        :return: likelihood matrix of shape (num_source_texts, num_candidates); likelihood[i, j] is the
            log-likelihood of candidate j being a summary of source text i.
        """
likelihood_matrix = torch.zeros(
(len(self.source_texts), len(self.candidates))
).to(self.device)
pairs = []
for i, source_text in enumerate(self.source_texts):
for j, candidate in enumerate(self.candidates):
pairs.append((i, j, source_text, candidate))
# split the pairs into batches
batches = [
pairs[i: i + self.batch_size]
for i in range(0, len(pairs), self.batch_size)
]
for batch in tqdm(batches):
# get the source texts and candidates
source_texts = [pair[2] for pair in batch]
candidates = [pair[3] for pair in batch]
# compute the likelihoods
with torch.no_grad():
likelihoods = self.score(
source_texts, candidates, mean=True
)
# fill the matrix
for k, (i, j, _, _) in enumerate(batch):
likelihood_matrix[i, j] = likelihoods[k].detach()
return likelihood_matrix
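
    # Example (illustrative): with 3 source texts and 5 candidate summaries,
    # likelihood_matrix() returns a (3, 5) tensor of length-normalised log-likelihoods.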

    @cache
    def S(self, t):
        """Speaker at RSA step t: log scores over candidates for each source text (cached per t)."""
        if t == 0:
            return self.initial_speaker_probas
        else:
            listener = self.L(t - 1)
            prod = listener * self.rationality  # + self.initial_speaker_probas.sum(0, keepdim=True)
            # Renormalise over candidates (last dimension)
            return torch.log_softmax(prod, dim=-1)

    @cache
    def L(self, t):
        """Listener at RSA step t: the speaker scores log-normalised over source texts (cached per t)."""
        speaker = self.S(t)
        return torch.log_softmax(speaker, dim=-2)
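
    # RSA recursion implemented by S and L above (sketch):
    #   S_0(c | s) : raw log-likelihood of candidate c given source text s (literal speaker)
    #   L_t(s | c) = softmax over source texts of S_t(c | s)                  (pragmatic listener)
    #   S_t(c | s) = softmax over candidates of rationality * L_{t-1}(s | c), for t >= 1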

    def mk_listener_dataframe(self, t):
        """
        Run the RSA recursion up to step t and collect the listener/speaker distributions
        and consensuality scores as pandas objects indexed by source texts and candidates.
        """
        self.initial_speaker_probas = self.likelihood_matrix()
        initial_listener_probas = self.L(0)
        # Consensuality score: KL divergence between the listener distribution over source
        # texts and a uniform distribution. Low values mean a candidate is about equally
        # likely under every source text (consensual); high values mean it is source-specific.
        uniform_distribution_over_source_texts = torch.ones_like(
            initial_listener_probas
        ) / len(self.source_texts)
        initial_consensuality_score = (
            torch.exp(initial_listener_probas)
            * (
                initial_listener_probas - torch.log(uniform_distribution_over_source_texts)
            )
        ).sum(0).cpu().numpy()
        initial_consensuality_score = pd.Series(initial_consensuality_score, index=self.candidates)
initial_listener_probas = initial_listener_probas.cpu().numpy()
initial_listener_probas = pd.DataFrame(initial_listener_probas)
initial_listener_probas.index = self.source_texts
initial_listener_probas.columns = self.candidates
initial_speaker_probas = self.S(0).cpu().numpy()
initial_speaker_probas = pd.DataFrame(initial_speaker_probas)
initial_speaker_probas.index = self.source_texts
initial_speaker_probas.columns = self.candidates
listener_df = pd.DataFrame(self.L(t).cpu().numpy())
consensuality_scores = (
torch.exp(self.L(t))
* (self.L(t) - torch.log(uniform_distribution_over_source_texts))
).sum(0).cpu().numpy()
consensuality_scores = pd.Series(consensuality_scores, index=self.candidates)
S = self.S(t).cpu().numpy()
speaker_df = pd.DataFrame(S)
# add the source texts and candidates as index
listener_df.index = self.source_texts
speaker_df.index = self.source_texts
listener_df.columns = self.candidates
speaker_df.columns = self.candidates
        return (
            listener_df,
            speaker_df,
            initial_listener_probas,
            initial_speaker_probas,
            initial_consensuality_score,
            consensuality_scores,
        )

    def rerank(self, t=1):
        """
        Return, for each source text, the best candidate according to the RSA speaker after t steps,
        together with the best candidate according to the initial (literal) listener.
        """
(
listener_df,
speaker_df,
initial_listener_proba,
initial_speaker_proba,
            initial_consensuality_score,
consensuality_scores,
) = self.mk_listener_dataframe(t=t)
best_rsa = speaker_df.idxmax(axis=1).values
best_base = initial_listener_proba.idxmax(axis=1).values
return (
best_rsa,
best_base,
speaker_df,
listener_df,
initial_listener_proba,
initial_speaker_proba,
            initial_consensuality_score,
consensuality_scores,
)
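

# Usage sketch (illustrative only; the checkpoint name is an assumption, any seq2seq
# summarisation model and matching tokenizer from transformers should work):
#
#   from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
#   model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
#   source_texts = ["first review text ...", "second review text ..."]
#   candidates = ["candidate summary a", "candidate summary b", "candidate summary c"]
#   reranker = RSAReranking(model, tokenizer, candidates, source_texts, device="cpu")
#   best_rsa, best_base, *rest = reranker.rerank(t=2)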


class RSARerankingEmbedder(RSAReranking):
    """
    Variant of RSAReranking that scores (source text, candidate) pairs with a dot product of
    embeddings instead of a seq2seq likelihood; the model is expected to expose an `encode`
    method (for example a sentence-transformers model).
    """

    def compute_embeddings(self, x: List[str], y: List[str], **kwargs):
        # Optional keyword arguments forwarded to the encoder's `encode` method
        model_kwargs = kwargs.get("model_kwargs") or {}
        # Embed sources and candidates: shape (batch_size, embedding_dim)
        x_embeddings = self.model.encode(x, **model_kwargs)
        y_embeddings = self.model.encode(y, **model_kwargs)
        # Dot product between paired embeddings: shape (batch_size,)
        dot_products = (x_embeddings * y_embeddings).sum(-1)
        # Return a torch tensor so callers such as likelihood_matrix() can call .detach()
        if not torch.is_tensor(dot_products):
            dot_products = torch.as_tensor(dot_products)
        return dot_products

    def score(self, x: List[str], y: List[str], **kwargs):
        return self.compute_embeddings(x, y, **kwargs)
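

# Usage sketch for RSARerankingEmbedder (illustrative only; the sentence-transformers
# checkpoint name is an assumption, any object exposing an `encode` method should work):
#
#   from sentence_transformers import SentenceTransformer
#
#   encoder = SentenceTransformer("all-MiniLM-L6-v2")
#   reranker = RSARerankingEmbedder(encoder, tokenizer=None, candidates=candidates,
#                                   source_texts=source_texts, device="cpu")
#   best_rsa, best_base, *rest = reranker.rerank(t=1)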