|
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification |
|
import tensorflow as tf |
|
from typing import List |
|
|
|
from logger_config import config_logger |
|
logger = config_logger(__name__) |
|
|
|
class CrossEncoderReranker:
    """
    Cross-Encoder Re-Ranker. Takes (query, candidate) pairs and outputs a relevance score [0...1].
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
        """
        Init the cross-encoder with a pretrained model.

        Args:
            model_name: Name of a HF cross-encoder model. Must be compatible with
                TFAutoModelForSequenceClassification.
        """
        # Lazy %-style args: the message is only formatted if the level is enabled.
        logger.info("Initializing CrossEncoderReranker with %s...", model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
        logger.info("Cross encoder model loaded successfully.")

    def rerank(
        self,
        query: str,
        candidates: List[str],
        max_length: int = 256
    ) -> List[float]:
        """
        Compute relevance scores for each candidate w.r.t. query.

        Args:
            query: User's query text.
            candidates: List of candidate response texts.
            max_length: Max token length for each (query, candidate) pair.

        Returns:
            A list of float scores [0...1]. One per candidate, indicating model's
            predicted relevance. Empty input yields an empty list.
        """
        # Guard: the HF tokenizer raises on an empty batch, and an empty
        # candidate set has nothing to rank anyway.
        if not candidates:
            return []

        # One (query, candidate) pair per candidate; the tokenizer encodes the
        # pair as a single cross-attended sequence.
        pair_texts = [(query, candidate) for candidate in candidates]
        encodings = self.tokenizer(
            pair_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="tf",
            verbose=False
        )

        # Forward exactly the tensors this tokenizer produced. Not every
        # tokenizer emits token_type_ids (e.g. RoBERTa/DistilBERT families),
        # so **encodings is safer than listing keys by hand.
        outputs = self.model(**encodings)

        # Map raw logits to [0, 1] via sigmoid, then flatten (batch, 1) -> (batch,).
        # NOTE(review): assumes a single-logit head (num_labels == 1, true for
        # the ms-marco MiniLM cross-encoders). A 2-label head would produce
        # 2 * len(candidates) scores after the reshape — confirm if swapping models.
        logits = outputs.logits
        scores = tf.reshape(tf.nn.sigmoid(logits), [-1])

        # Convert to plain Python floats for JSON-friendly downstream use.
        return scores.numpy().astype(float).tolist()