# Author: Ali Kefia
# Review status: ok
# Revision: 4c31c97
import os
from functools import cache
from itertools import batched
from typing import Generator, Iterator
import numpy as np
from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer
def split(text: str, max_tokens: int = 512) -> Generator[str, None, None]:
# Naive approach - use opale internal chunking techniques (special tokens count)
words = text.split()
if not words:
return
for batch in batched(words, max_tokens // 2): # Assuming 2 tokens per word
yield " ".join(batch)
@cache
def get_model():
    """Load the sentence-embedding model configured via the environment.

    Reads ``EMBEDDING_MODEL`` and ``EMBEDDING_MODEL_REV`` (raises
    ``KeyError`` if either is unset). Cached so the model is loaded
    at most once per process.
    """
    model_name = os.environ["EMBEDDING_MODEL"]
    model_rev = os.environ["EMBEDDING_MODEL_REV"]
    return SentenceTransformer(model_name, revision=model_rev)
def embed(texts: Iterator[str], max_tokens: int = 512) -> NDArray:
    """Embed each text as the mean of its chunk embeddings.

    Each text is split into token-budgeted chunks (see :func:`split`),
    every chunk is encoded, and the chunk embeddings are averaged into a
    single vector per text.

    Args:
        texts: Texts to embed.
        max_tokens: Token budget per chunk, forwarded to :func:`split`.

    Returns:
        Array of shape ``(len(texts), embedding_dim)``.

    NOTE(review): an empty/whitespace-only text yields no chunks, so
    ``np.mean`` over an empty axis produces a NaN row (with a runtime
    warning) — confirm callers never pass blank texts, or filter upstream.
    """
    # Hoist the (cached, but still a call per iteration) model lookup
    # out of the loop.
    model = get_model()
    res: list[NDArray] = []
    for text in texts:
        embeddings = model.encode(
            list(split(text, max_tokens)), show_progress_bar=False
        )
        res.append(np.mean(embeddings, axis=0))
    return np.array(res)