from transformers import AutoTokenizer, TFAutoModelForSequenceClassification | |
import tensorflow as tf | |
from typing import List, Tuple | |
from logger_config import config_logger | |
logger = config_logger(__name__) | |
class CrossEncoderReranker:
    """
    Cross-Encoder Re-Ranker: takes (query, candidate) pairs and scores each
    pair with a sequence-classification model that emits a single relevance
    logit per pair.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
        """
        Load the tokenizer and TF model weights for *model_name*.

        Args:
            model_name: Hugging Face model id of a cross-encoder checkpoint
                whose classification head outputs a single logit.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        # Model outputs shape [batch_size, 1] -> interpret the logit as relevance score.

    def rerank(
        self,
        query: str,
        candidates: List[str],
        max_length: int = 256
    ) -> List[float]:
        """
        Score every candidate against the query.

        Args:
            query: The search query.
            candidates: Candidate passages to score.
            max_length: Token-length cap applied to each (query, candidate) pair.

        Returns:
            One relevance score (raw logit) per candidate, in input order.
            Empty input yields an empty list.
        """
        # Guard: tokenizing an empty batch raises inside the tokenizer.
        if not candidates:
            return []

        # Build (query, candidate) pairs; the tokenizer encodes each tuple
        # as a two-segment input.
        pair_texts = [(query, candidate) for candidate in candidates]

        # Tokenize the entire batch at once.
        encodings = self.tokenizer(
            pair_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="tf"
        )

        # Forward the tokenizer output as-is: models without token_type_ids
        # (e.g. DistilBERT/RoBERTa cross-encoders) then still work, since we
        # never pass an explicit token_type_ids=None for a key the tokenizer
        # did not produce. training=False makes inference explicit
        # (dropout/batch-norm in eval mode) so scores are deterministic.
        outputs = self.model(dict(encodings), training=False)

        # logits shape [batch_size, 1] -> flatten to [batch_size].
        logits = outputs.logits
        scores = tf.reshape(logits, [-1]).numpy()
        return scores.tolist()