csc525_retrieval_based_chatbot / cross_encoder_reranker.py
JoeArmani
updates - new iteration with type token
7a0020b
raw
history blame
2.83 kB
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
from typing import List
import numpy as np
from logger_config import config_logger
# Module-level logger for this file, created via the project's
# logger_config.config_logger helper and keyed by the module name.
logger = config_logger(__name__)
class CrossEncoderReranker:
    """
    Cross-encoder re-ranker that scores (query, candidate) pairs.

    Each pair is tokenized jointly, run through a Hugging Face
    sequence-classification model, and the resulting logit is passed through
    a sigmoid so that every candidate receives a relevance score in [0, 1].
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
        """
        Load the tokenizer and TensorFlow model for the cross-encoder.

        Args:
            model_name: Name of a HF cross-encoder model. Must be
                compatible with TFAutoModelForSequenceClassification.
        """
        logger.info(f"Initializing CrossEncoderReranker with {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        logger.info("Cross encoder model loaded successfully.")

    def rerank(
        self,
        query: str,
        candidates: List[str],
        max_length: int = 256
    ) -> List[float]:
        """
        Compute a relevance score for each candidate w.r.t. the query.

        Args:
            query: User's query text.
            candidates: Candidate response texts to score against the query.
            max_length: Max token length for each (query, candidate) pair;
                longer pairs are truncated by the tokenizer.

        Returns:
            A list of float scores in [0, 1], in the same order as
            ``candidates``. An empty ``candidates`` yields an empty list.
        """
        # 1) Build (query, candidate) pairs.
        pair_texts = [(query, candidate) for candidate in candidates]

        # Nothing to score: return early instead of handing an empty batch to
        # the tokenizer/model, which cannot process it.
        if not pair_texts:
            return []

        # 2) Tokenize the entire batch at once, padding to the longest pair.
        encodings = self.tokenizer(
            pair_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="tf"
        )

        # 3) Forward pass. token_type_ids is passed only when the tokenizer
        #    produced it (.get returns None otherwise); some models need it.
        outputs = self.model(
            input_ids=encodings["input_ids"],
            attention_mask=encodings["attention_mask"],
            token_type_ids=encodings.get("token_type_ids")
        )
        # NOTE(review): assumes a single-logit head, i.e. logits of shape
        # [batch_size, 1]. With more labels the flattening below would yield
        # more than one score per candidate -- confirm for non-default models.
        logits = outputs.logits

        # 4) Map logits into the [0, 1] range via sigmoid.
        scores = tf.nn.sigmoid(logits)

        # 5) Flatten to 1D and convert to plain Python floats.
        scores = tf.reshape(scores, [-1])
        scores = scores.numpy().astype(float)
        return scores.tolist()