from typing import List

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from logger_config import config_logger

logger = config_logger(__name__)


class CrossEncoderReranker:
    """
    Cross-encoder re-ranker: takes (query, candidate) pairs and outputs a
    relevance score in [0, 1] for each pair.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
        """
        Initialize the cross-encoder from a pretrained model.

        Args:
            model_name: Name of a Hugging Face cross-encoder model. Must be
                compatible with TFAutoModelForSequenceClassification.
        """
        logger.info(f"Initializing CrossEncoderReranker with {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        logger.info("Cross-encoder model loaded successfully.")

    def rerank(
        self,
        query: str,
        candidates: List[str],
        max_length: int = 256,
    ) -> List[float]:
        """
        Compute a relevance score for each candidate with respect to the query.

        Args:
            query: User's query text.
            candidates: List of candidate response texts.
            max_length: Maximum token length for each (query, candidate) pair;
                longer pairs are truncated.

        Returns:
            A list of float scores in [0, 1], one per candidate, indicating the
            model's predicted relevance.
        """
        # Guard against an empty batch: the tokenizer cannot encode zero pairs.
        if not candidates:
            return []

        # Build (query, candidate) pairs, then tokenize them as a single batch.
        pair_texts = [(query, candidate) for candidate in candidates]
        encodings = self.tokenizer(
            pair_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="tf",
            verbose=False,
        )

        # Forward pass; logits have shape [batch_size, 1].
        # token_type_ids are optional (not every tokenizer produces them), so
        # .get() avoids a KeyError and passes None when they are absent.
        outputs = self.model(
            input_ids=encodings["input_ids"],
            attention_mask=encodings["attention_mask"],
            token_type_ids=encodings.get("token_type_ids"),
        )
        logits = outputs.logits            # shape [batch_size, 1]

        # Squash logits into [0, 1] with a sigmoid, then flatten to a 1-D
        # NumPy array and convert to plain Python floats.
        scores = tf.nn.sigmoid(logits)     # shape [batch_size, 1]
        scores = tf.reshape(scores, [-1])
        scores = scores.numpy().astype(float)
        return scores.tolist()
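

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the original
    # module; the query and candidate strings below are made up). It scores
    # each candidate against the query and prints them best-first. Note that
    # the first run downloads the pretrained model from the Hugging Face Hub.
    reranker = CrossEncoderReranker()
    query = "How do I reset my password?"
    candidates = [
        "Click 'Forgot password' on the login page to receive a reset email.",
        "Our office is open Monday through Friday, 9am to 5pm.",
        "You can change your password from the account settings menu.",
    ]
    scores = reranker.rerank(query, candidates)
    for score, candidate in sorted(zip(scores, candidates), reverse=True):
        print(f"{score:.4f}  {candidate}")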