Spaces:
Running
Running
File size: 1,076 Bytes
a3ee979 07a2d78 a3ee979 07a2d78 a3ee979 07a2d78 a3ee979 fcf6714 a3ee979 07a2d78 a3ee979 07a2d78 a3ee979 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from pathlib import Path
from typing import Optional
import numpy as np
import pyarrow as pa
class ArrowMetadataProvider:
    """Serve metadata rows, addressed by contiguous integer ids, from Arrow files.

    Every regular file found recursively under ``arrow_folder`` is memory-mapped,
    read with the Arrow IPC file reader and concatenated, in sorted path order,
    into a single in-memory table. An id is a row position in that table.

    Code taken from: https://github.dev/rom1504/clip-retrieval
    """

    def __init__(self, arrow_folder: Path):
        """Load and concatenate every Arrow file found under ``arrow_folder``.

        Args:
            arrow_folder: Directory that is searched recursively; every regular
                file below it is opened as an Arrow IPC file.
        """
        # Sorted paths keep the concatenation order, and therefore the
        # id -> row mapping, stable from one run to the next.
        arrow_files = [str(a) for a in sorted(arrow_folder.glob("**/*")) if a.is_file()]
        self.table = pa.concat_tables(
            [
                pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
                for arrow_file in arrow_files
            ]
        )

    def get(self, ids: np.ndarray, cols: Optional[list] = None):
        """Return the metadata records stored at the given row ids.

        Args:
            ids: Row positions to fetch. Any sequence of integers is accepted,
                not only a NumPy array.
            cols: Column names to keep. ``None`` keeps every column. Names that
                are not in the table schema are ignored.

        Returns:
            A list with one ``dict`` per fetched row, in the order of ``ids``.
            Keys follow the table's schema order. An empty ``ids`` yields ``[]``.
        """
        ids = np.asarray(ids)
        if ids.size == 0:
            # concat_tables() rejects an empty list, so short-circuit here.
            return []

        schema_names = self.table.schema.names
        if cols is None:
            cols = schema_names
        else:
            # Filter in schema order so the result is deterministic; a set
            # intersection would return the columns in arbitrary order.
            requested = set(cols)
            cols = [name for name in schema_names if name in requested]

        # One single-row slice per id keeps the output in the caller's order
        # (ids may repeat or be unsorted).
        rows = pa.concat_tables([self.table[i : i + 1] for i in ids.tolist()])
        return rows.select(cols).to_pandas().to_dict("records")
|