csc525_retrieval_based_chatbot / cross_encoder_reranker.py
JoeArmani
chat refinements
c7c1b4e
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
from typing import List
from logger_config import config_logger
# Module-level logger for this file, obtained from the project's config_logger helper.
logger = config_logger(__name__)
class CrossEncoderReranker:
    """
    Cross-Encoder Re-Ranker. Takes (query, candidate) pairs and outputs a relevance score [0...1].

    The query and each candidate are tokenized together as a single sequence pair and
    scored by a sequence-classification model; the score is the sigmoid of the model's
    single output logit for that pair.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
        """
        Init the cross-encoder with a pretrained model.

        Args:
            model_name: Name of a HF cross-encoder model. Must be compatible with
                TFAutoModelForSequenceClassification and produce exactly one logit per
                (query, candidate) pair; `rerank` raises ValueError otherwise.
        """
        logger.info(f"Initializing CrossEncoderReranker with {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        logger.info("Cross encoder model loaded successfully.")

    def rerank(
        self,
        query: str,
        candidates: List[str],
        max_length: int = 256
    ) -> List[float]:
        """
        Compute relevance scores for each candidate w.r.t. query.

        Args:
            query: User's query text.
            candidates: List of candidate response texts.
            max_length: Max token length for each (query, candidate) pair; longer
                pairs are truncated.

        Returns:
            A list of float scores in [0, 1], one per candidate and in the same order
            as `candidates`, indicating the model's predicted relevance. Returns an
            empty list when `candidates` is empty.

        Raises:
            ValueError: If the model emits more than one logit per pair, which would
                otherwise yield a score list misaligned with `candidates`.
        """
        # An empty batch cannot be tokenized into tensors and run through the model,
        # so short-circuit rather than failing deep inside the tokenizer/model.
        if not candidates:
            return []

        # Build (query, candidate) pairs, then tokenize them as sequence pairs.
        pair_texts = [(query, candidate) for candidate in candidates]
        encodings = self.tokenizer(
            pair_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="tf",
            verbose=False,
        )

        # Forward pass. token_type_ids are optional (not every tokenizer emits them);
        # .get() passes None instead of raising KeyError.
        outputs = self.model(
            input_ids=encodings["input_ids"],
            attention_mask=encodings["attention_mask"],
            token_type_ids=encodings.get("token_type_ids"),
        )
        logits = outputs.logits

        # Sigmoid + flatten is only valid for a single-logit head ([batch_size, 1]).
        # A multi-label head would flatten to num_labels * batch_size values and
        # silently misalign scores with candidates, so reject it explicitly.
        num_logits = logits.shape[-1]
        if num_logits != 1:
            raise ValueError(
                "Expected the model to output 1 logit per (query, candidate) pair, "
                f"got {num_logits}."
            )

        # Map logits to [0, 1], flatten to 1D, and return plain Python floats.
        scores = tf.nn.sigmoid(tf.reshape(logits, [-1]))
        return scores.numpy().astype(float).tolist()