from typing import List

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from logger_config import config_logger

logger = config_logger(__name__)


class CrossEncoderReranker:
    """
    Cross-Encoder re-ranker.

    Scores (query, candidate) text pairs with a sequence-classification
    cross-encoder whose head emits a single relevance logit per pair
    (model output shape [batch_size, 1]; higher logit = more relevant).
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
        """
        Load the tokenizer and TF model for the given cross-encoder.

        Args:
            model_name: Hugging Face model id of a cross-encoder whose
                classification head outputs one logit per input pair.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)

    def rerank(
        self,
        query: str,
        candidates: List[str],
        max_length: int = 256
    ) -> List[float]:
        """
        Score each candidate's relevance to the query.

        Args:
            query: The search query.
            candidates: Candidate texts to score against the query.
            max_length: Maximum token length per (query, candidate) pair;
                longer pairs are truncated.

        Returns:
            One float relevance score per candidate, in input order.
            Empty list if ``candidates`` is empty.
        """
        # Guard: the tokenizer raises on an empty batch, so short-circuit.
        if not candidates:
            return []

        # Build (query, candidate) pairs for pairwise scoring.
        pair_texts = [(query, candidate) for candidate in candidates]

        # Tokenize the entire batch at once (padded and truncated to max_length).
        encodings = self.tokenizer(
            pair_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="tf",
        )

        # Forward pass -> logits of shape [batch_size, 1].
        # Some tokenizers don't produce token_type_ids; .get() then yields
        # None, which the model accepts.
        outputs = self.model(
            input_ids=encodings["input_ids"],
            attention_mask=encodings["attention_mask"],
            token_type_ids=encodings.get("token_type_ids"),
        )
        logits = outputs.logits

        # Flatten [batch_size, 1] -> [batch_size] and convert to Python floats.
        scores = tf.reshape(logits, [-1]).numpy()
        return scores.tolist()