class SentenceTransformersCrossEncoder:
    """Wrapper for a sentence-transformers cross-encoder model.

    The instance is callable: given a query and a collection of passages it
    returns one score per passage, in the same order as the passages.
    """

    def __init__(
        self, model_name_or_path: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"
    ):
        """Load the underlying ``CrossEncoder``.

        Args:
            model_name_or_path: Value passed unchanged to ``CrossEncoder``.

        Raises:
            ModuleNotFoundError: If ``sentence_transformers`` cannot be
                imported. The original ``ImportError`` is attached as the
                cause.
        """
        # Imported lazily so that this module can be imported without the
        # optional sentence-transformers dependency being installed.
        try:
            from sentence_transformers.cross_encoder import CrossEncoder
        except ImportError as err:
            raise ModuleNotFoundError(
                "You need to install sentence-transformers library to use SentenceTransformersCrossEncoder."
            ) from err
        self.model = CrossEncoder(model_name_or_path)

    def __call__(self, query: str, passage: list[str]) -> list[float]:
        """Score every passage against ``query``.

        Args:
            query: The query text paired with each passage.
            passage: The passages to score.

        Returns:
            The model's predictions converted with ``tolist()``, one entry
            per passage in input order. An empty list when ``passage`` is
            empty.
        """
        pairs = [[query, p] for p in passage]
        if not pairs:
            # Nothing to score. Skip the model call: predict() implementations
            # that inspect the first input element fail on an empty batch.
            return []
        return self.model.predict(pairs).tolist()