radio-embed-p / radio_embed_p /radio_embed_p.py
freemt
Update radio_embed_p
fb4c8c6
"""Embed input."""
# pylint: disable=invalid-name
import multiprocessing as mp
from math import ceil
import more_itertools as mit
import numpy
from hf_model_s_cpu import model_s
from logzero import logger
# Instantiate the embedding model once, at import time. If that fails the
# module is unusable, so log the traceback and terminate with exit status 1.
try:
    model = model_s()
except Exception as load_error:
    logger.exception(load_error)
    raise SystemExit(1) from load_error

# CPU count reported by multiprocessing; radio_embed_p uses it both to size
# its chunks and to size its worker pool.
num_cpus = mp.cpu_count()
def radio_embed(
    text: str,
) -> numpy.ndarray:
    """Embed the lines of *text* with the module-level ``model``.

    Leading and trailing whitespace is stripped from ``text``, the remainder
    is split into lines, and that list is handed to ``model.encode``.

    Args:
        text: Input text, one item per line.

    Returns:
        Whatever ``model.encode`` returns for the list of lines.

    Raises:
        Exception: Any error raised while splitting or encoding is logged
            with its traceback and then re-raised unchanged.
    """
    try:
        lines = text.strip().splitlines()
        embeddings = model.encode(lines)
    except Exception as exc:
        logger.exception(exc)
        raise
    else:
        return embeddings
def radio_embed_p(
    text: str,
) -> numpy.ndarray:
    """Embed input in parallel.

    ``text`` is split into lines, the lines are divided into at most
    ``num_cpus`` evenly sized chunks, and each chunk is joined back into a
    string and passed to ``radio_embed`` in a worker process.

    Args:
        text: Input text, one item per line.

    Returns:
        The list produced by ``Pool.map``: one ``radio_embed`` result per
        chunk, in input order. Note that this is a list of per-chunk results
        rather than a single array, despite the return annotation. An empty
        ``text`` yields an empty list.

    Raises:
        Exception: Any error raised by the pool or by a worker is logged
            with its traceback and then re-raised unchanged.
    """
    lines = text.splitlines()
    if not lines:
        # No lines means a chunk size of ceil(0 / num_cpus) == 0, which is
        # not a valid size for chunked_even, and there is nothing to map
        # anyway. Return the empty result without starting a pool.
        return []

    # split evenly to (at most) num_cpus parts
    chunk_size = ceil(len(lines) / num_cpus)
    # back to str for radio_embed
    args = ["\n".join(chunk) for chunk in mit.chunked_even(lines, chunk_size)]

    try:
        # Never start more workers than there are chunks to process.
        with mp.Pool(min(num_cpus, len(args))) as pool:
            ret = pool.map(radio_embed, args)
    except Exception as exc:
        # logger.exception keeps the traceback, matching radio_embed.
        logger.exception(exc)
        raise
    return ret