"""Chunk texts and compute mean-pooled sentence embeddings.

Requires Python 3.12+ (uses ``itertools.batched``).
"""

import os
from functools import cache
from itertools import batched
from typing import Generator, Iterator

import numpy as np
from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer


def split(text: str, max_tokens: int = 512) -> Generator[str, None, None]:
    """Yield whitespace-delimited chunks of *text*, each roughly within *max_tokens* tokens.

    Naive approach - use opale internal chunking techniques (special tokens count).
    Assumes ~2 tokens per word, so each chunk holds at most ``max_tokens // 2`` words.
    Yields nothing for empty or whitespace-only input.
    """
    words = text.split()
    if not words:
        return
    # Clamp to >= 1 word per chunk: batched() raises ValueError for n < 1,
    # which would make max_tokens=1 crash instead of degrading gracefully.
    words_per_chunk = max(1, max_tokens // 2)
    for batch in batched(words, words_per_chunk):
        yield " ".join(batch)


@cache
def get_model() -> SentenceTransformer:
    """Return the process-wide SentenceTransformer, built once from env config.

    Reads ``EMBEDDING_MODEL`` and ``EMBEDDING_MODEL_REV``; raises KeyError if
    either variable is unset.
    """
    return SentenceTransformer(
        os.environ["EMBEDDING_MODEL"], revision=os.environ["EMBEDDING_MODEL_REV"]
    )


def embed(texts: Iterator[str], max_tokens: int = 512) -> NDArray:
    """Embed each text as the mean of its chunk embeddings.

    Each text is split into chunks of at most *max_tokens* (approximate) tokens,
    every chunk is encoded, and the chunk vectors are mean-pooled into a single
    row. Returns an array of shape (num_texts, embedding_dim); an empty *texts*
    iterator yields an empty array.
    """
    model = get_model()  # hoisted: one cached lookup instead of one per text
    res: list[NDArray] = []
    for text in texts:
        chunks = list(split(text, max_tokens))
        if not chunks:
            # Empty/whitespace-only text produces no chunks; np.mean over an
            # empty array would warn and emit a NaN row. Encode the empty
            # string instead so the output row is a valid embedding.
            chunks = [""]
        embeddings = model.encode(chunks, show_progress_bar=False)
        res.append(np.mean(embeddings, axis=0))
    return np.array(res)