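"""Score generated summaries against their source texts.

Each text and summary is embedded with a SentenceTransformer model; for every
id, the dot products between its texts and summaries are turned into a
log-softmax, and the diagonal is stored as the log-probability that each
summary is matched to its own text.

The input CSV must contain `id`, `text` and `summary` columns. Example
invocation (script name and paths are placeholders):

    python score_summaries.py --summaries data/summaries.csv --output results/
"""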
import argparse
from pathlib import Path
from typing import Tuple

import pandas as pd
import torch
from sentence_transformers import SentenceTransformer

def xlogx(x):
    # x * log(x) with the convention that 0 * log(0) = 0.
    if x == 0:
        return 0
    return x * torch.log(x)

def parse_summaries(path: Path) -> pd.DataFrame:
    # Load the data and check that the expected columns are present
    df = pd.read_csv(path)
    if 'id' not in df.columns:
        raise ValueError('id column not found in the summaries file')
    if 'text' not in df.columns:
        raise ValueError('text column not found in the summaries file')
    if 'summary' not in df.columns:
        raise ValueError('summary column not found in the summaries file')
    return df

def embed_text_and_summaries(df: pd.DataFrame, model: SentenceTransformer) -> Tuple[torch.Tensor, torch.Tensor]:
    # encode() with convert_to_tensor=True returns a tensor of shape (num_rows, embedding_dim)
    text_embeddings = model.encode(df.text.tolist(), convert_to_tensor=True)
    summary_embeddings = model.encode(df.summary.tolist(), convert_to_tensor=True)
    return text_embeddings, summary_embeddings

def compute_dot_products(df: pd.DataFrame, text_embeddings: torch.Tensor, summary_embeddings: torch.Tensor) -> pd.DataFrame:
    df = df.reset_index(drop=True)
    # positional row index, used to slice the embedding tensors
    df['index'] = df.index
    # group by id
    grouped = df.groupby('id')
    # for each id gather the row indices of its texts and summaries
    ids_per_sample = grouped['index'].apply(list).tolist()
    # compute the dot product between the texts and the summaries of each id
    metrics = {'proba_of_success': [], 'index': []}
    for text_ids in ids_per_sample:
        # shape (num_text, embedding_dim)
        text_embedding = text_embeddings[text_ids]
        summary_embedding = summary_embeddings[text_ids]
        # shape (num_text, num_text=num_summary)
        dot_product = torch.matmul(text_embedding, summary_embedding.T)
        # apply log softmax over the texts for every summary
        log_softmax = torch.nn.functional.log_softmax(dot_product, dim=0)
        # diagonal: log-probability that summary i is matched to its own text i
        log_proba_of_success = torch.diag(log_softmax).reshape(-1)
        # entropy of the diagonal probabilities (computed but not stored below)
        entropy = -(log_proba_of_success.exp() * log_proba_of_success).sum()
        metrics['proba_of_success'].extend(log_proba_of_success.tolist())
        metrics['index'].extend(text_ids)
        # metrics['entropy'].extend(entropy.tolist())
    # align the log-probabilities with the original rows via the stored row indices,
    # so the order of the ids in the input file does not matter
    df['proba_of_success'] = pd.Series(metrics['proba_of_success'], index=metrics['index'])
    # df['entropy'] = metrics['entropy']
    return df

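# A toy illustration of the scoring above (a sketch, not part of the pipeline):
# with two rows for one id and perfectly aligned embeddings,
#     dot_product = [[1.0, 0.0],
#                    [0.0, 1.0]]
# the log-softmax over dim=0 has log(e / (e + 1)) ≈ -0.31 on the diagonal,
# i.e. each summary is matched to its own text with probability ≈ 0.73 rather
# than to the other text (≈ 0.27).
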
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--summaries', type=Path, required=True)
    parser.add_argument('--model', type=str, default='paraphrase-MiniLM-L6-v2')
    parser.add_argument('--output', type=Path, required=True)
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()
    return args

def main():
    args = parse_args()
    # load the model
    model = SentenceTransformer(args.model, device=args.device)
    # load the summaries
    df = parse_summaries(args.summaries)
    # embed the texts and the summaries
    text_embeddings, summary_embeddings = embed_text_and_summaries(df, model)
    # compute the dot products between the texts and the summaries
    df = compute_dot_products(df, text_embeddings, summary_embeddings)
    # create the output directory
    args.output.mkdir(parents=True, exist_ok=True)
    path = args.output / f"{args.summaries.stem}.csv"
    # save the results
    df.to_csv(path, index=False)

if __name__ == '__main__':
    main()