File size: 691 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
class SentenceTransformersCrossEncoder:
    """Wrapper for sentence-transformers cross-encoder model.

    An instance is callable: given a query and a list of passages it builds
    one ``[query, passage]`` pair per passage, scores the pairs with the
    wrapped ``CrossEncoder`` and returns the scores as a plain Python list.
    """

    def __init__(
        self, model_name_or_path: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"
    ):
        """Load the cross-encoder model.

        Args:
            model_name_or_path: Identifier handed unchanged to
                ``sentence_transformers.cross_encoder.CrossEncoder``.

        Raises:
            ModuleNotFoundError: If ``sentence-transformers`` cannot be
                imported.
        """
        # Imported lazily so that the optional dependency is only required
        # when this wrapper is actually instantiated.
        try:
            from sentence_transformers.cross_encoder import CrossEncoder
        except ImportError as err:
            # Chain the original error so the underlying import failure stays
            # visible in the traceback.
            raise ModuleNotFoundError(
                "You need to install sentence-transformers library to use SentenceTransformersCrossEncoder."
            ) from err
        self.model = CrossEncoder(model_name_or_path)

    def __call__(self, query: str, passage: list[str]) -> list[float]:
        """Score every passage against ``query``.

        Args:
            query: The query text placed first in every pair.
            passage: The passages to score.

        Returns:
            The output of ``self.model.predict`` on the ``[query, passage]``
            pairs, converted with ``.tolist()``. An empty list is returned
            when there are no passages.
        """
        pairs = [[query, text] for text in passage]
        if not pairs:
            # Nothing to score; skip the model call, which may not accept an
            # empty batch.
            return []
        return self.model.predict(pairs).tolist()
|