|
from typing import Tuple
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
def xlogx(x):
    """Return ``x * log(x)`` with the convention ``0 * log(0) == 0``.

    ``x`` is expected to be a torch tensor (or ``0``): the non-zero scalar
    branch calls ``torch.log(x)``, which requires a tensor argument.

    Multi-element tensors are now handled elementwise via ``torch.xlogy``
    (the scalar branch below would otherwise raise, because the truth value
    of a multi-element tensor is ambiguous in ``if x == 0``).
    """
    if torch.is_tensor(x) and x.numel() != 1:
        # Elementwise generalization; torch.xlogy already maps 0 -> 0.
        return torch.xlogy(x, x)
    if x == 0:
        return 0
    else:
        return x * torch.log(x)
|
|
|
|
def parse_summaries(path : Path):
    """Load the summaries CSV and validate its schema.

    Parameters
    ----------
    path : Path
        CSV file that must contain ``id``, ``text`` and ``summary`` columns.

    Returns
    -------
    pd.DataFrame
        The parsed file, unchanged.

    Raises
    ------
    ValueError
        If any required column is missing (first missing one is reported).
    """
    frame = pd.read_csv(path)

    # Validate the schema up front so downstream code can rely on it.
    for required in ('id', 'text', 'summary'):
        if required not in frame.columns:
            raise ValueError(f'{required} column not found in the summaries file')

    return frame
|
|
|
|
|
|
def embed_text_and_summaries(df : pd.DataFrame, model : SentenceTransformer) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode the ``text`` and ``summary`` columns with the given sentence encoder.

    Returns a pair of tensors ``(text_embeddings, summary_embeddings)`` whose
    first axis is aligned with the rows of ``df``.
    """
    texts = df.text.tolist()
    summaries = df.summary.tolist()

    return (
        model.encode(texts, convert_to_tensor=True),
        model.encode(summaries, convert_to_tensor=True),
    )
|
|
|
|
|
|
def compute_dot_products(df : pd.DataFrame, text_embeddings : torch.Tensor, summary_embeddings : torch.Tensor):
    """Score each summary against the texts sharing its sample ``id``.

    For every group of rows with the same ``id``, builds the matrix of dot
    products between the group's text and summary embeddings, applies a
    log-softmax over the text axis, and keeps the diagonal: the log-probability
    that each text is matched with its own summary.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain an ``id`` column; row order must match the first axis of
        both embedding tensors.
    text_embeddings, summary_embeddings : torch.Tensor
        Row-aligned embeddings of the ``text`` / ``summary`` columns.

    Returns
    -------
    pd.DataFrame
        A copy of ``df`` (index reset) with added columns ``index``
        (positional row id) and ``proba_of_success``.  NOTE: despite the
        name, the values are *log*-probabilities; the column name is kept
        for backward compatibility.
    """
    df = df.reset_index()
    df['index'] = df.index

    grouped = df.groupby('id')

    # Explicit column selection: attribute access (``grouped.index``) relies
    # on GroupBy's attribute fallback and is ambiguous with real attributes.
    ids_per_sample = grouped['index'].apply(list).tolist()

    # Fill results by row position: groupby iterates in sorted-id order, so
    # appending to a flat list (the previous approach) misaligned scores
    # whenever ids were not already sorted/contiguous.
    log_probas = pd.Series(0.0, index=df.index)

    for text_ids in ids_per_sample:
        text_embedding = text_embeddings[text_ids]
        summary_embedding = summary_embeddings[text_ids]

        # (k, k) similarity matrix for this sample's k text/summary pairs.
        dot_product = torch.matmul(text_embedding, summary_embedding.T)

        # Normalize over the text axis: entry (i, j) is log P(text_i | summary_j).
        log_softmax = torch.nn.functional.log_softmax(dot_product, dim=0)

        # torch.diagonal keeps a 1-D result even when k == 1; the previous
        # torch.diag(...).squeeze() collapsed singleton groups to a 0-dim
        # tensor whose .tolist() is a bare float, which broke accumulation.
        log_probas.iloc[text_ids] = torch.diagonal(log_softmax).tolist()

    df['proba_of_success'] = log_probas

    return df
|
|
|
|
def parse_args():
    """Build and run the command-line parser for this script.

    Options:
      --summaries  path to the input CSV (required)
      --model      SentenceTransformer model name
      --output     output directory (required)
      --device     torch device for the encoder
    """
    cli = argparse.ArgumentParser()

    cli.add_argument('--summaries', type=Path, required=True)
    cli.add_argument('--model', type=str, default='paraphrase-MiniLM-L6-v2')
    cli.add_argument('--output', type=Path, required=True)
    cli.add_argument('--device', type=str, default='cuda')

    return cli.parse_args()
|
|
|
|
def main():
    """Entry point: embed texts/summaries and write per-row match scores to CSV."""
    args = parse_args()

    # One encoder is shared for both the texts and the summaries.
    encoder = SentenceTransformer(args.model, device=args.device)

    summaries = parse_summaries(args.summaries)

    text_embeddings, summary_embeddings = embed_text_and_summaries(summaries, encoder)
    scored = compute_dot_products(summaries, text_embeddings, summary_embeddings)

    # Output file mirrors the input file's stem inside the output directory.
    args.output.mkdir(parents=True, exist_ok=True)
    out_path = args.output / f"{args.summaries.stem}.csv"

    scored.to_csv(out_path, index=False)


if __name__ == '__main__':
    main()