# ask_candid/retrieval/sparse_lexical.py
"""SPLADE sparse lexical encoding utilities for re-ranking and query token expansion."""

from typing import List, Dict

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer


class SpladeEncoder:
    """Wraps the naver/splade-v3 masked-LM to produce sparse lexical (SPLADE) vectors."""

    batch_size = 4

    def __init__(self):
        model_id = "naver/splade-v3"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForMaskedLM.from_pretrained(model_id)
        # Reverse vocabulary lookup: vocabulary index -> token string.
        self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}

    @torch.no_grad()
    def forward(self, texts: List[str]) -> torch.Tensor:
        """Encode texts into SPLADE sparse vectors of shape (len(texts), vocab_size)."""
        vectors = []
        for batch in tqdm(DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size), desc="Encoding"):
            tokens = self.tokenizer(batch, return_tensors="pt", truncation=True, padding=True)
            output = self.model(**tokens)
            # SPLADE pooling: log(1 + ReLU(logits)), masked to real (non-padding) tokens,
            # then a max over the sequence dimension gives one weight per vocabulary entry.
            vec = torch.max(
                torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
                dim=1
            )[0]
            vectors.append(vec)
        return torch.vstack(vectors)

    def query_reranking(self, query: str, documents: List[str]) -> List[float]:
        """Score documents against a query via cosine similarity of SPLADE vectors."""
        vec = self.forward([query, *documents])
        # L2-normalize the query row and the document rows so that the elementwise
        # product summed over the vocabulary axis equals a cosine similarity.
        xQ = F.normalize(vec[:1], dim=-1, p=2.0)
        xD = F.normalize(vec[1:], dim=-1, p=2.0)
        return (xQ * xD).sum(dim=-1).cpu().tolist()

    def token_expand(self, query: str) -> Dict[str, float]:
        """Expand a query into its non-zero SPLADE tokens, sorted by descending weight."""
        vec = self.forward([query]).squeeze()
        # Indices and values of the non-zero vocabulary dimensions.
        cols = vec.nonzero().squeeze().cpu().tolist()
        weights = vec[cols].cpu().tolist()
        sparse_dict_tokens = {
            self.idx2token[idx]: round(weight, 3)
            for idx, weight in zip(cols, weights)
            if weight > 0
        }
        return dict(sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True))
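

# Minimal usage sketch, assuming network access to download the naver/splade-v3
# weights from the Hugging Face Hub; the query and document strings below are
# hypothetical examples, not part of the module's public surface.
if __name__ == "__main__":
    encoder = SpladeEncoder()
    docs = [
        "Grants for community health programs.",
        "Scholarships for graduate students in public health.",
    ]
    # One cosine-similarity score per document, in input order.
    print(encoder.query_reranking("public health funding", docs))
    # Weighted token expansion of the query, highest-weight tokens first.
    print(encoder.token_expand("public health funding"))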