Spaces:
Sleeping
Sleeping
File size: 2,371 Bytes
fbf7e95 c19ce61 fbf7e95 c19ce61 fbf7e95 c19ce61 fbf7e95 c19ce61 fbf7e95 5381b52 fbf7e95 5381b52 d29e6b9 fbf7e95 d29e6b9 fbf7e95 d29e6b9 fbf7e95 d29e6b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import argparse
import csv
import pandas as pd
from tqdm import tqdm
from marcai.predict import predict_onnx
from marcai.process import multiprocess_pairs
from marcai.utils import load_config
from marcai.utils.parsing import load_records, record_dict
def args_parser():
    """Build the command-line parser for the record-matching script.

    Options, in registration order:
      -i/--inputs        one or more MARC files (required)
      -p/--pair-indices  file containing indices of comparisons (required)
      -C/--chunksize     chunk size, int, default 50000
      -P/--processes     number of processes, int, default 1
      -m/--model-dir     directory with the model ONNX and YAML files (required)
      -o/--output        output file (required)
      -t/--threshold     float match threshold; no default, so it is None
                         when omitted

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    cli = argparse.ArgumentParser()
    # (flags, add_argument keyword options); order matters for --help output.
    option_specs = (
        (("-i", "--inputs"), {"nargs": "+", "help": "MARC files", "required": True}),
        (
            ("-p", "--pair-indices"),
            {"help": "File containing indices of comparisons", "required": True},
        ),
        (("-C", "--chunksize"), {"help": "Chunk size", "type": int, "default": 50000}),
        (
            ("-P", "--processes"),
            {"help": "Number of processes", "type": int, "default": 1},
        ),
        (
            ("-m", "--model-dir"),
            {
                "help": "Directory containing model ONNX and YAML files",
                "required": True,
            },
        ),
        (("-o", "--output"), {"help": "Output file", "required": True}),
        (("-t", "--threshold"), {"help": "Threshold for matching", "type": float}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli
def main(args):
config_path = f"{args.model_dir}/config.yaml"
model_onnx = f"{args.model_dir}/model.onnx"
config = load_config(config_path)
# Load records
print("Loading records...")
records = []
for path in args.inputs:
records.extend([record_dict(r) for r in load_records(path)])
records_df = pd.DataFrame(records)
print(f"Loaded {len(records)} records.")
print("Processing and comparing records...")
written = False
with open(args.pair_indices, "r") as indices_file:
reader = csv.reader(indices_file)
# Process records
for df in tqdm(
multiprocess_pairs(records_df, reader, args.chunksize, args.processes)
):
input_df = df[config["model"]["features"]]
prediction = predict_onnx(model_onnx, input_df)
df.loc[:, "prediction"] = prediction.squeeze()
df = df[df["prediction"] >= args.threshold]
if not df.empty:
if not written:
df.to_csv(args.output, index=False)
written = True
else:
df.to_csv(args.output, index=False, mode="a", header=False)
if __name__ == "__main__":
    # Script entry point: build the CLI, parse sys.argv, then run the pipeline.
    parser = args_parser()
    args = parser.parse_args()
    main(args)
|