from functools import cache
from typing import List
import torch
import pandas as pd
from tqdm import tqdm
def kl_divergence(p, q):
    """
    Compute the KL divergence KL(p || q) between two (batches of) probability
    distributions over the last dimension. 0 * log(0) terms are mapped to 0.
    """
    return torch.nan_to_num(p * (p / q).log(), nan=0.0).sum(-1)
def jensen_shannon_divergence(p, q):
    """
    Compute the Jensen-Shannon divergence between two distributions:
    JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m = 0.5 * (p + q).
    """
    m = 0.5 * (p + q)
    return 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))
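# Illustrative example (not executed at import time): for two categorical
# distributions p = torch.tensor([0.9, 0.1]) and q = torch.tensor([0.1, 0.9]),
# jensen_shannon_divergence(p, q) returns a positive scalar tensor, and it
# returns 0 when p == q.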
class RSAReranking:
    """
    Rerank a list of candidate summaries according to the Rational Speech Acts (RSA) model.
    """
def __init__(
self,
model,
tokenizer,
candidates: List[str],
source_texts: List[str],
batch_size: int = 32,
rationality: int = 1,
device="cuda",
):
"""
:param model: hf model used to compute the likelihoods (supposed to be a seq2seq model), is S0 in the RSA model
:param tokenizer:
:param candidates: list of candidates summaries
:param source_texts: list of source texts
:param batch_size: batch size used to compute the likelihoods (can be high since we don't need gradients and
it's a single forward pass)
:param rationality: rationality parameter of the RSA model
:param device: device used to compute the likelihoods
"""
        self.device = device
        self.model = model.to(self.device)
self.tokenizer = tokenizer
self.candidates = candidates
self.source_texts = source_texts
self.batch_size = batch_size
self.rationality = rationality
def compute_conditionned_likelihood(
self, x: List[str], y: List[str], mean: bool = True
) -> torch.Tensor:
"""
Compute the likelihood of y given x
:param x: list of source texts len(x) = batch_size
:param y: list of candidates summaries len(y) = batch_size
:param mean: average the likelihoods over the tokens of y or take the sum
:return: tensor of shape (batch_size) containing the likelihoods of y given x
"""
# Ensure x,y are pure Python lists of strings (not pandas.Series, np.ndarray, etc.)
x = [str(item) for item in list(x)]
y = [str(item) for item in list(y)]
assert len(x) == len(y), "x and y must have the same length"
        # ignore_index prevents padded target positions from contributing to the loss
        loss_fn = torch.nn.CrossEntropyLoss(
            reduction="none", ignore_index=self.tokenizer.pad_token_id
        )
batch_size = len(x)
x = self.tokenizer(
x,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024,
)
y = self.tokenizer(
y,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024,
)
# Move all tensors to the correct device
x = {k: v.to(self.device) for k, v in x.items()}
y = {k: v.to(self.device) for k, v in y.items()}
        # Run the seq2seq model: the source is the encoder input, the candidate is the decoder input
x_ids = x["input_ids"]
y_ids = y["input_ids"]
logits = self.model(
input_ids=x_ids,
decoder_input_ids=y_ids,
attention_mask=x["attention_mask"],
decoder_attention_mask=y["attention_mask"],
).logits
        # Shift so that the logits at position n are scored against the target token at position n + 1
        shifted_logits = logits[..., :-1, :].contiguous()
        shifted_ids = y_ids[..., 1:].contiguous()
likelihood = -loss_fn(
shifted_logits.view(-1, shifted_logits.size(-1)), shifted_ids.view(-1)
)
likelihood = likelihood.view(batch_size, -1).sum(-1)
if mean:
likelihood /= (y_ids != self.tokenizer.pad_token_id).float().sum(-1)
return likelihood
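    # Illustrative call (the instance and text names below are examples, not part of the module):
    #   scores = reranker.compute_conditionned_likelihood(
    #       ["source text 1", "source text 2"],
    #       ["candidate summary 1", "candidate summary 2"],
    #       mean=True,
    #   )
    # yields one (average per-token) log-likelihood per (source, candidate) pair.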
def score(self, x: List[str], y: List[str], **kwargs):
return self.compute_conditionned_likelihood(x, y, **kwargs)
    def likelihood_matrix(self) -> torch.Tensor:
        """
        :return: likelihood matrix of shape (num_source_texts, num_candidates); likelihood[i, j] is the
            (average per-token) log-likelihood of candidate j being a summary of source text i.
        """
likelihood_matrix = torch.zeros(
(len(self.source_texts), len(self.candidates))
).to(self.device)
pairs = []
for i, source_text in enumerate(self.source_texts):
for j, candidate in enumerate(self.candidates):
pairs.append((i, j, source_text, candidate))
# split the pairs into batches
batches = [
pairs[i: i + self.batch_size]
for i in range(0, len(pairs), self.batch_size)
]
for batch in tqdm(batches):
# get the source texts and candidates
source_texts = [pair[2] for pair in batch]
candidates = [pair[3] for pair in batch]
# compute the likelihoods
with torch.no_grad():
likelihoods = self.score(
source_texts, candidates, mean=True
)
# fill the matrix
for k, (i, j, _, _) in enumerate(batch):
likelihood_matrix[i, j] = likelihoods[k].detach()
return likelihood_matrix
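    # RSA recursion (all quantities are log-probabilities over the
    # (source_text, candidate) matrix computed by likelihood_matrix):
    #   S(0): the literal speaker, i.e. the raw log-likelihood matrix.
    #   L(t): listener obtained by normalizing S(t) over source texts (dim=-2).
    #   S(t): speaker obtained by normalizing rationality * L(t - 1) over candidates (dim=-1).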
@cache
def S(self, t):
if t == 0:
return self.initial_speaker_probas
else:
listener = self.L(t - 1)
prod = listener * self.rationality # + self.initial_speaker_probas.sum(0, keepdim=True)
return torch.log_softmax(prod, dim=-1)
@cache
def L(self, t):
speaker = self.S(t)
return torch.log_softmax(speaker, dim=-2)
def mk_listener_dataframe(self, t):
self.initial_speaker_probas = self.likelihood_matrix()
initial_listener_probas = self.L(0)
        # Consensuality score: per-candidate KL divergence between the listener
        # distribution over source texts and the uniform distribution
        # (a low score means the candidate is roughly equally plausible for every source text)
        uniform_distribution_over_source_texts = torch.ones_like(
            initial_listener_probas
        ) / len(self.source_texts)
initital_consensuality_score = (
torch.exp(initial_listener_probas)
* (
initial_listener_probas - torch.log(uniform_distribution_over_source_texts)
)
).sum(0).cpu().numpy()
initital_consensuality_score = pd.Series(initital_consensuality_score, index=self.candidates)
initial_listener_probas = initial_listener_probas.cpu().numpy()
initial_listener_probas = pd.DataFrame(initial_listener_probas)
initial_listener_probas.index = self.source_texts
initial_listener_probas.columns = self.candidates
initial_speaker_probas = self.S(0).cpu().numpy()
initial_speaker_probas = pd.DataFrame(initial_speaker_probas)
initial_speaker_probas.index = self.source_texts
initial_speaker_probas.columns = self.candidates
listener_df = pd.DataFrame(self.L(t).cpu().numpy())
consensuality_scores = (
torch.exp(self.L(t))
* (self.L(t) - torch.log(uniform_distribution_over_source_texts))
).sum(0).cpu().numpy()
consensuality_scores = pd.Series(consensuality_scores, index=self.candidates)
S = self.S(t).cpu().numpy()
speaker_df = pd.DataFrame(S)
# add the source texts and candidates as index
listener_df.index = self.source_texts
speaker_df.index = self.source_texts
listener_df.columns = self.candidates
speaker_df.columns = self.candidates
        return (
            listener_df,
            speaker_df,
            initial_listener_probas,
            initial_speaker_probas,
            initital_consensuality_score,
            consensuality_scores,
        )
    def rerank(self, t=1):
        """
        Return, for each source text, the best candidate according to the RSA speaker S(t)
        (best_rsa) and according to the literal listener L(0) (best_base), along with the
        intermediate distributions and consensuality scores.
        """
(
listener_df,
speaker_df,
initial_listener_proba,
initial_speaker_proba,
initital_consensuality_score,
consensuality_scores,
) = self.mk_listener_dataframe(t=t)
best_rsa = speaker_df.idxmax(axis=1).values
best_base = initial_listener_proba.idxmax(axis=1).values
return (
best_rsa,
best_base,
speaker_df,
listener_df,
initial_listener_proba,
initial_speaker_proba,
initital_consensuality_score,
consensuality_scores,
)
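
# RSARerankingEmbedder is a drop-in variant of RSAReranking that scores
# (source, candidate) pairs with embedding dot products instead of seq2seq
# conditional log-likelihoods. It assumes the wrapped model exposes an
# encode(texts, **model_kwargs) method returning one embedding per input
# (e.g. a sentence-transformers style encoder); this expectation is not
# enforced by the class.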
class RSARerankingEmbedder(RSAReranking):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
    def compute_embeddings(self, x: List[str], y: List[str], **kwargs):
        # default to an empty dict so the call also works when no model_kwargs are passed
        model_kwargs = kwargs.get("model_kwargs") or {}
        # shape: (batch_size, embedding_dim)
        x_embeddings = self.model.encode(x, **model_kwargs)
        y_embeddings = self.model.encode(y, **model_kwargs)
        # dot product between the embeddings: shape (batch_size,)
        dot_products = (x_embeddings * y_embeddings).sum(-1)
        return dot_products
def score(self, x: List[str], y: List[str], **kwargs):
return self.compute_embeddings(x, y, **kwargs)
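
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The checkpoint name, texts and
# candidates below are assumptions chosen for the example, not part of the
# original module; any HF seq2seq summarizer and its tokenizer should work.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    model_name = "facebook/bart-large-cnn"  # illustrative seq2seq checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    reranker = RSAReranking(
        model=model,
        tokenizer=tokenizer,
        candidates=["candidate summary A", "candidate summary B"],
        source_texts=["first source document", "second source document"],
        batch_size=4,
        rationality=1,
        device="cpu",
    )
    best_rsa, best_base, *_ = reranker.rerank(t=2)
    print("RSA picks:", best_rsa)
    print("Base picks:", best_base)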