Spaces:
Runtime error
Runtime error
"""Embed input.""" | |
# pylint: disable=invalid-name | |
import multiprocessing as mp | |
from math import ceil | |
import more_itertools as mit | |
import numpy | |
from hf_model_s_cpu import model_s | |
from logzero import logger | |
# Load the embedding model once, at import time; radio_embed reads this
# module-level ``model``.
try:
    model = model_s()
except Exception as exc:
    # The module is unusable without a model: log the traceback and exit
    # with status 1, keeping the original error as the cause.
    logger.exception(exc)
    raise SystemExit(1) from exc

# CPU count reported by multiprocessing; radio_embed_p uses it to size its
# chunks and its worker pool.
num_cpus = mp.cpu_count()
def radio_embed(
    text: str,
) -> numpy.ndarray:
    """Embed every line of *text* with the module-level ``model``.

    Leading and trailing whitespace is stripped from *text*, the remainder
    is split into lines, and the resulting list is passed to
    ``model.encode`` in a single call.

    Args:
        text: Newline-separated input.

    Returns:
        Whatever ``model.encode`` returns for the list of lines
        (annotated as ``numpy.ndarray``).

    Raises:
        Exception: Any error raised while splitting or encoding is logged
            with its traceback and then re-raised unchanged.
    """
    try:
        lines = text.strip().splitlines()
        embeddings = model.encode(lines)
    except Exception as exc:
        logger.exception(exc)
        raise
    else:
        return embeddings
def radio_embed_p(
    text: str,
) -> "list[numpy.ndarray]":
    """Embed input in parallel.

    *text* is stripped and split into lines exactly as ``radio_embed`` does,
    the lines are divided into at most ``num_cpus`` evenly sized chunks, and
    each chunk is encoded in a worker process.

    The chunks are sent to the workers as lists of lines rather than being
    re-joined into strings. Re-joining and then stripping each chunk would
    drop blank lines that happen to fall on a chunk boundary, making the
    number of embeddings depend on the CPU count.

    Args:
        text: Newline-separated input.

    Returns:
        A list with one ``model.encode`` result per chunk, in input order
        (``Pool.map`` preserves order). The list is empty when *text*
        contains no lines. The per-chunk results are NOT concatenated.

    Raises:
        Exception: Any error from the pool or a worker is logged with its
            traceback and re-raised.
    """
    lines = text.strip().splitlines()
    if not lines:
        # Nothing to embed: avoid a chunk size of 0 and a pointless pool.
        return []

    # ceil() guarantees no more than num_cpus chunks.
    chunk_size = ceil(len(lines) / num_cpus)
    chunks = list(mit.chunked_even(lines, chunk_size))

    try:
        # Never start more workers than there are chunks to process.
        with mp.Pool(min(num_cpus, len(chunks))) as pool:
            ret = pool.map(_embed_lines, chunks)
    except Exception as exc:
        # logger.exception keeps the traceback, matching radio_embed.
        logger.exception(exc)
        raise
    return ret


def _embed_lines(lines: "list[str]"):
    """Encode an already-split list of lines; pool worker for radio_embed_p.

    Kept at module level so multiprocessing can pickle a reference to it.
    Errors are logged with their traceback and re-raised, as in radio_embed.
    """
    try:
        return model.encode(lines)
    except Exception as exc:
        logger.exception(exc)
        raise