# csc525_retrieval_based_chatbot / cross_encoder_reranker.py
# Author: JoeArmani
# Summarization, reranker, environment setup, and response quality checker.
# (Original upload metadata: commit f7b283c, 1.73 kB)
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
from typing import List, Tuple
from logger_config import config_logger
logger = config_logger(__name__)
class CrossEncoderReranker:
    """
    Cross-Encoder re-ranker: scores (query, candidate) pairs.

    The underlying model emits a single logit per pair ([batch_size, 1]);
    that logit is interpreted directly as a relevance score — higher means
    more relevant.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
        """
        Load the tokenizer and TF sequence-classification model.

        Args:
            model_name: Hugging Face model id of a cross-encoder checkpoint
                with a single-logit classification head.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        # Model outputs shape [batch_size, 1] -> interpret the logit as relevance score.

    def rerank(
        self,
        query: str,
        candidates: List[str],
        max_length: int = 256
    ) -> List[float]:
        """
        Score each candidate's relevance to the query.

        Args:
            query: The user query.
            candidates: Candidate response texts to score.
            max_length: Truncation length for each tokenized (query, candidate) pair.

        Returns:
            One float relevance score per candidate, in input order.
            Returns [] when ``candidates`` is empty.
        """
        # Guard: HF tokenizers raise on an empty batch, and there is
        # nothing to score anyway.
        if not candidates:
            return []

        # Build (query, candidate) pairs so the tokenizer encodes them
        # as two segments of one sequence.
        pair_texts = [(query, candidate) for candidate in candidates]

        # Tokenize the entire batch at once (padded to the longest pair).
        encodings = self.tokenizer(
            pair_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="tf"
        )

        # Forward pass -> logits shape [batch_size, 1].
        # token_type_ids may be absent for some tokenizers; .get() yields
        # None then, which the model accepts.
        outputs = self.model(
            input_ids=encodings["input_ids"],
            attention_mask=encodings["attention_mask"],
            token_type_ids=encodings.get("token_type_ids")
        )
        logits = outputs.logits

        # Flatten [batch_size, 1] -> [batch_size] and convert to plain floats.
        scores = tf.reshape(logits, [-1]).numpy()
        return scores.tolist()