import os
from functools import cache
from itertools import batched  # requires Python 3.12+
from typing import Generator, Iterator

import numpy as np
from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer


def split(text: str, max_tokens: int = 512) -> Generator[str, None, None]:
    """Yield whitespace-delimited chunks of ``text``.

    Packs ``max_tokens // 2`` words into each chunk as a rough heuristic
    for staying under the model's token limit, since a single word often
    maps to more than one token.
    """
    words = text.split()
    if not words:
        return
    for batch in batched(words, max_tokens // 2):
        yield " ".join(batch)


@cache
def get_model() -> SentenceTransformer:
    """Load the embedding model once and reuse it across calls."""
    return SentenceTransformer(
        os.environ["EMBEDDING_MODEL"], revision=os.environ["EMBEDDING_MODEL_REV"]
    )


def embed(texts: Iterator[str], max_tokens: int = 512) -> NDArray:
    """Embed each text by mean-pooling the embeddings of its chunks."""
    res: list[NDArray] = []
    for text in texts:
        embeddings = get_model().encode(
            list(split(text, max_tokens)), show_progress_bar=False
        )
        res.append(np.mean(embeddings, axis=0))
    return np.array(res)
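

# Usage sketch, not part of the module's required interface. Assumptions:
# EMBEDDING_MODEL / EMBEDDING_MODEL_REV must point to a valid
# sentence-transformers model and revision; the defaults below are purely
# illustrative.
if __name__ == "__main__":
    os.environ.setdefault("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    os.environ.setdefault("EMBEDDING_MODEL_REV", "main")  # hypothetical revision
    vectors = embed(iter(["a short document", "a much longer document " * 400]))
    print(vectors.shape)  # (2, embedding_dim)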