from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
from typing import List
import numpy as np

from logger_config import config_logger

logger = config_logger(__name__)


class CrossEncoderReranker:
    """
    Cross-Encoder Re-Ranker that takes (query, candidate) pairs and
    outputs a single relevance score in [0,1] per pair.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
        """
        Initialize the cross-encoder with a pretrained model.

        Args:
            model_name: Name of a HF cross-encoder model. Must be compatible
                with TFAutoModelForSequenceClassification.
        """
        logger.info(f"Initializing CrossEncoderReranker with {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        logger.info("Cross encoder model loaded successfully.")

    def rerank(
        self,
        query: str,
        candidates: List[str],
        max_length: int = 256
    ) -> List[float]:
        """
        Compute relevance scores for each candidate w.r.t. the query.

        Args:
            query: User's query text.
            candidates: List of candidate response texts.
            max_length: Max token length for each (query, candidate) pair.

        Returns:
            A list of float scores in [0,1], one per candidate and in the
            same order as ``candidates``. An empty ``candidates`` list yields
            an empty list without invoking the tokenizer or the model.

        Raises:
            ValueError: If the model emits more than one logit per pair, in
                which case a single relevance score cannot be derived.
        """
        # Nothing to score: skip tokenization and the forward pass, which
        # are not meaningful for an empty batch.
        if not candidates:
            return []

        # 1) Build (query, candidate) pairs
        pair_texts = [(query, candidate) for candidate in candidates]

        # 2) Tokenize the entire batch
        encodings = self.tokenizer(
            pair_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="tf"
        )

        # 3) Forward pass -> logits expected to have shape [batch_size, 1]
        outputs = self.model(
            input_ids=encodings["input_ids"],
            attention_mask=encodings["attention_mask"],
            # Some tokenizers do not produce token_type_ids; .get() then
            # passes None and the model falls back to its default.
            token_type_ids=encodings.get("token_type_ids")
        )
        logits = outputs.logits

        # Flattening a multi-logit head would silently return
        # batch_size * num_labels scores, so reject it explicitly.
        if logits.shape[-1] != 1:
            raise ValueError(
                "CrossEncoderReranker expects a single-logit model, but got "
                f"logits with shape {tuple(logits.shape)}."
            )

        # 4) Convert logits -> [0,1] range via sigmoid
        #    If the cross-encoder is a single-logit regression to [0,1],
        #    this is a typical interpretation.
        scores = tf.nn.sigmoid(logits)  # shape [batch_size, 1]

        # 5) Flatten to a 1D list of Python floats
        scores = tf.reshape(scores, [-1])
        return scores.numpy().astype(float).tolist()