File size: 1,728 Bytes
f7b283c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from typing import List, Optional, Tuple

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from logger_config import config_logger
# Module-level logger obtained from the project's logger_config helper,
# keyed by this module's __name__.
logger = config_logger(__name__)
class CrossEncoderReranker:
    """
    Cross-Encoder Re-Ranker: Takes (query, candidate) pairs,
    outputs a single relevance score (one logit).

    The tokenizer and TensorFlow sequence-classification model are loaded
    with ``from_pretrained`` from the given model name or path. The model
    must produce exactly one logit per pair; ``rerank`` verifies this.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
        """
        Load the tokenizer and model.

        Args:
            model_name: Model identifier or local path passed to both
                ``AutoTokenizer.from_pretrained`` and
                ``TFAutoModelForSequenceClassification.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        # Model outputs shape [batch_size, 1] -> Interpret the logit as relevance score.

    def rerank(
        self,
        query: str,
        candidates: List[str],
        max_length: int = 256,
        batch_size: Optional[int] = None,
    ) -> List[float]:
        """
        Score how relevant each candidate is to ``query``.

        Args:
            query: Query text, paired with every candidate.
            candidates: Candidate texts to score.
            max_length: Maximum tokenized length of each (query, candidate)
                pair; longer pairs are truncated.
            batch_size: Maximum number of pairs per forward pass, to bound
                memory use on long candidate lists. ``None`` (the default)
                scores all pairs in a single pass.

        Returns:
            One float per candidate, in the same order as ``candidates``: the
            single logit the model assigns to that (query, candidate) pair.
            An empty list when ``candidates`` is empty.

        Raises:
            ValueError: If ``batch_size`` is not positive, or if the model does
                not output exactly one logit per pair.
        """
        # Materialize so generators/iterables work with len() and slicing.
        candidates = list(candidates)
        if not candidates:
            # Nothing to score; avoid handing an empty batch to the tokenizer/model.
            return []

        if batch_size is None:
            batch_size = len(candidates)
        elif batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        scores: List[float] = []
        for start in range(0, len(candidates), batch_size):
            chunk = candidates[start:start + batch_size]
            scores.extend(self._score_batch(query, chunk, max_length))
        return scores

    def _score_batch(self, query: str, candidates: List[str], max_length: int) -> List[float]:
        """Run one forward pass over (query, candidate) pairs; return one logit per pair."""
        # Build (query, candidate) pairs
        pair_texts = [(query, candidate) for candidate in candidates]
        # Tokenize the whole chunk, padding to its longest pair.
        encodings = self.tokenizer(
            pair_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="tf",
        )
        outputs = self.model(
            input_ids=encodings["input_ids"],
            attention_mask=encodings["attention_mask"],
            # Some tokenizers emit no segment ids; None is passed through then.
            token_type_ids=encodings.get("token_type_ids"),
        )
        logits = outputs.logits.numpy()
        # Expect [num_pairs, 1]. Flattening a multi-label output would interleave
        # labels and misalign scores with candidates, so reject it explicitly.
        expected_shape = (len(pair_texts), 1)
        if tuple(logits.shape) != expected_shape:
            raise ValueError(
                f"Expected model logits of shape {list(expected_shape)}, "
                f"got {list(logits.shape)}; CrossEncoderReranker requires a "
                f"single-logit model."
            )
        return logits[:, 0].tolist()