"""Score generated summaries with a sentence-embedding model: within each group of
rows sharing the same id, texts and summaries are embedded and the script records the
log-probability that each text is matched with its own summary."""

import argparse
from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer

def xlogx(x):
    """Numerically safe x * log(x), using the convention 0 * log(0) = 0 (currently unused helper)."""
    if x == 0:
        return 0
    return x * torch.log(x)

def parse_summaries(path: Path) -> pd.DataFrame:
    """Load the summaries CSV and check that the required columns are present."""
    df = pd.read_csv(path)

    for column in ('id', 'text', 'summary'):
        if column not in df.columns:
            raise ValueError(f"'{column}' column not found in the summaries file")

    return df
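
# The summaries file is expected to look roughly like this (illustrative example,
# not taken from the repository); rows sharing an id are grouped and scored together:
#
#   id,text,summary
#   0,"First document of group 0 ...","Summary of the first document ..."
#   0,"Second document of group 0 ...","Summary of the second document ..."
#   1,"Only document of group 1 ...","Its summary ..."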


def embed_text_and_summaries(df: pd.DataFrame, model: SentenceTransformer) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode every text and every summary with the sentence-embedding model."""
    text_embeddings = model.encode(df.text.tolist(), convert_to_tensor=True)
    summary_embeddings = model.encode(df.summary.tolist(), convert_to_tensor=True)

    return text_embeddings, summary_embeddings


def compute_dot_products(df: pd.DataFrame, text_embeddings: torch.Tensor, summary_embeddings: torch.Tensor) -> pd.DataFrame:
    """For every group of rows sharing an id, score how well each text is matched
    with its own summary and store the log-probability in a new column."""

    df = df.reset_index(drop=True)
    df['index'] = df.index

    # group the rows by id: each group holds matched (text, summary) pairs
    grouped = df.groupby('id')

    # for each id, gather the row positions of its texts and summaries
    ids_per_sample = grouped['index'].apply(list).tolist()

    # log-probability that each text is paired with its own summary, keyed by row position
    log_proba_per_row = {}
    for text_ids in ids_per_sample:
        # shape (num_text, embedding_dim)
        text_embedding = text_embeddings[text_ids]
        summary_embedding = summary_embeddings[text_ids]

        # shape (num_text, num_text = num_summary)
        dot_product = torch.matmul(text_embedding, summary_embedding.T)

        # log softmax over the texts, for each summary column
        log_softmax = torch.nn.functional.log_softmax(dot_product, dim=0)

        # diagonal: log-probability of matching text i with its own summary i
        log_proba_of_success = torch.diag(log_softmax)

        # entropy of the distribution over texts for each summary (currently unused)
        # probs = log_softmax.exp()
        # entropy = -(probs * log_softmax).sum(dim=0)

        for row, value in zip(text_ids, log_proba_of_success.tolist()):
            log_proba_per_row[row] = value

    # assign the scores back to the rows they belong to (robust to unsorted ids)
    df['proba_of_success'] = df['index'].map(log_proba_per_row)

    return df
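
# A minimal sketch of the scoring idea above (the numbers are illustrative):
# for a group of n matched (text, summary) pairs, the n x n dot-product matrix is
# log-softmax-normalised over the texts, and the diagonal holds the log-probability
# that each text retrieves its own summary.
#
#   dot_product = torch.tensor([[4.0, 1.0],
#                               [0.5, 3.0]])
#   log_softmax = torch.nn.functional.log_softmax(dot_product, dim=0)
#   torch.diag(log_softmax)  # log-proba that text i is matched with summary i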

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--summaries', type=Path, required=True)
    parser.add_argument('--model', type=str, default='paraphrase-MiniLM-L6-v2')
    parser.add_argument('--output', type=Path, required=True)
    parser.add_argument('--device', type=str, default='cuda')

    args = parser.parse_args()
    return args
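
# Example invocation (illustrative; the script and file names below are assumptions):
#   python compute_proba_of_success.py --summaries data/summaries.csv --output results --device cpu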

def main():

    args = parse_args()

    # load the model
    model = SentenceTransformer(args.model, device=args.device)

    # load the summaries
    df = parse_summaries(args.summaries)

    # embed the texts and the summaries
    text_embeddings, summary_embeddings = embed_text_and_summaries(df, model)

    # score how well each text is matched with its own summary within its id group
    df = compute_dot_products(df, text_embeddings, summary_embeddings)

    # create the output directory
    args.output.mkdir(parents=True, exist_ok=True)

    path = args.output / f"{args.summaries.stem}.csv"

    # save the results
    df.to_csv(path, index=False)


if __name__ == '__main__':
    main()