Spaces:
Runtime error
Runtime error
File size: 1,535 Bytes
8b7a023 eabd180 8b7a023 bf962b1 8b7a023 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import json
import asyncio
import logging
import time
from tqdm.asyncio import tqdm_asyncio
from huggingface_hub import get_inference_endpoint
from models import env_config, embed_config
# Configure root logging at import time so this module's INFO messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Handle to the TEI Inference Endpoint named by env_config.tei_name, shared by the
# functions below (embed_chunk uses its async_client; fetch/resume/wait are called too).
# NOTE(review): this is resolved when the module is imported; presumably
# get_inference_endpoint makes an API call, so import needs network access and a
# valid env_config.hf_token — confirm.
endpoint = get_inference_endpoint(env_config.tei_name, token=env_config.hf_token)
async def embed_chunk(sentence, semaphore, tmp_file):
    """Embed one text via the TEI endpoint and append the result to ``tmp_file``.

    Args:
        sentence: Text sent as the ``inputs`` field of the request payload.
        semaphore: asyncio semaphore held for the duration of the request; it
            bounds how many embedding requests are in flight at once.
        tmp_file: Writable text file. One JSON object per line is appended,
            holding the embedding under ``"vector"`` and the original text under
            the ``env_config.input_text_col`` key.

    Raises:
        RuntimeError: If posting to the endpoint fails. The underlying exception
            is chained as ``__cause__`` so its traceback is not lost.
    """
    async with semaphore:
        payload = {
            "inputs": sentence,
            # presumably asks TEI to truncate over-long inputs instead of erroring
            "truncate": True,
        }
        try:
            resp = await endpoint.async_client.post(json=payload)
        except Exception as e:
            # Chain the cause: the bare re-raise used to discard the original traceback.
            raise RuntimeError(str(e)) from e
        result = json.loads(resp)
        # assumes the response is a list with one vector per input, so [0] is the
        # embedding for this single sentence — TODO confirm against TEI response format
        tmp_file.write(
            json.dumps({"vector": result[0], env_config.input_text_col: sentence}) + "\n"
        )
async def embed_wrapper(input_ds, temp_file):
    """Embed every non-blank text in ``input_ds`` concurrently into ``temp_file``.

    Rows whose ``env_config.input_text_col`` value is ``None`` or whitespace-only
    are skipped. At most ``embed_config.semaphore_bound`` requests run at once.
    If any embedding task fails, the remaining tasks are cancelled and drained
    before the error propagates, so nothing keeps writing to ``temp_file``
    after this coroutine exits.

    Args:
        input_ds: Iterable of rows indexable by ``env_config.input_text_col``
            (presumably a ``datasets.Dataset`` — TODO confirm with caller).
        temp_file: Writable text file that receives one JSON line per embedding.
    """
    # endpoint.fetch() is a blocking HTTP call; run it in the default executor so
    # the event loop is not stalled while the endpoint state is refreshed.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, endpoint.fetch)

    text_col = env_config.input_text_col
    semaphore = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
    texts = (row[text_col] for row in input_ds)
    jobs = [
        asyncio.create_task(embed_chunk(text, semaphore, temp_file))
        for text in texts
        # Skip None as well as blank text; None previously crashed on .strip().
        if text and text.strip()
    ]
    logger.info("num chunks to embed: %d", len(jobs))

    tic = time.perf_counter()
    try:
        await tqdm_asyncio.gather(*jobs)
    except BaseException:
        # gather does not cancel sibling tasks when one fails; cancel and drain
        # them so they cannot keep writing to temp_file after we re-raise.
        for job in jobs:
            job.cancel()
        await asyncio.gather(*jobs, return_exceptions=True)
        raise
    logger.info("embed time: %s", time.perf_counter() - tic)
def wake_up_endpoint(timeout=None):
    """Ensure the TEI Inference Endpoint is running, resuming it if it is not.

    Refreshes the endpoint's state. If its status is anything other than
    ``'running'``, the endpoint is resumed and this call blocks until ``wait()``
    returns.

    Args:
        timeout: Optional maximum number of seconds to wait for the endpoint to
            come up. ``None`` (the default) waits indefinitely, matching the
            previous behavior.

    Raises:
        Whatever ``endpoint.resume()`` / ``endpoint.wait()`` raise, e.g. a
        huggingface_hub timeout error once ``timeout`` elapses
        (NOTE(review): confirm the exact exception type for the installed version).
    """
    endpoint.fetch()
    if endpoint.status == 'running':
        return
    logger.info("Starting up TEI endpoint")
    endpoint.resume().wait(timeout=timeout)
    logger.info("TEI endpoint is up")
|