File size: 1,803 Bytes
f86d7f2
bea5044
2744d22
 
bea5044
2744d22
f86d7f2
bea5044
 
 
 
2744d22
bea5044
 
 
 
 
 
 
 
 
f86d7f2
2744d22
 
 
 
 
 
 
 
 
 
f86d7f2
 
 
 
 
 
 
 
2744d22
bea5044
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from typing import List, Dict

from tqdm.auto import tqdm

from transformers import AutoModelForMaskedLM, AutoTokenizer
from torch.utils.data import DataLoader
from torch.nn import functional as F
import torch


class SpladeEncoder:
    """Sparse lexical encoder built on the SPLADE v3 masked-LM checkpoint.

    Encodes texts into vocabulary-sized sparse activation vectors, used
    for cosine-similarity re-ranking and for query term expansion.
    """

    # Number of texts encoded per forward pass.
    batch_size = 4

    def __init__(self):
        model_id = "naver/splade-v3"

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForMaskedLM.from_pretrained(model_id)
        # eval() disables dropout so repeated encodings of the same text are
        # deterministic; @torch.no_grad() alone does not switch off dropout.
        self.model.eval()
        # Reverse vocabulary lookup: token id -> token string.
        self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}

    @torch.no_grad()
    def forward(self, texts: List[str]) -> torch.Tensor:
        """Encode *texts* into a (len(texts), vocab_size) tensor.

        Applies the SPLADE aggregation: log(1 + relu(logits)), zeroed at
        padding positions via the attention mask, then max-pooled over the
        sequence dimension.
        """
        vectors = []
        for batch in tqdm(DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size), desc="Re-ranking"):
            tokens = self.tokenizer(batch, return_tensors='pt', truncation=True, padding=True)
            output = self.model(**tokens)
            # Max-pool over the sequence dim. No .squeeze() here: keeping the
            # batch dim makes the shape stable even when a batch holds a
            # single text (torch.vstack then concatenates 2-D rows as-is).
            vec = torch.max(
                torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
                dim=1
            )[0]
            vectors.append(vec)
        return torch.vstack(vectors)

    def query_reranking(self, query: str, documents: List[str]) -> List[float]:
        """Return the cosine similarity of *query* against each document.

        Scores are returned in the same order as *documents*.
        """
        vec = self.forward([query, *documents])
        xQ = F.normalize(vec[:1], dim=-1, p=2.)
        xD = F.normalize(vec[1:], dim=-1, p=2.)
        # Row-wise dot product of unit vectors == cosine similarity.
        return (xQ * xD).sum(dim=-1).cpu().tolist()

    def token_expand(self, query: str) -> Dict[str, float]:
        """Return {token: weight} for every activated vocabulary token of
        *query*, sorted by descending weight (weights rounded to 3 decimals)."""
        vec = self.forward([query]).squeeze()
        # view(-1) instead of squeeze(): with exactly one nonzero entry,
        # squeeze() would produce a 0-d tensor whose tolist() is a bare
        # scalar, breaking the zip below.
        cols = vec.nonzero().view(-1).cpu().tolist()
        weights = vec[cols].cpu().tolist()

        sparse_dict_tokens = {self.idx2token[idx]: round(weight, 3) for idx, weight in zip(cols, weights) if weight > 0}
        return dict(sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True))