File size: 960 Bytes
231da5b 4c31c97 231da5b 4c31c97 231da5b 4c31c97 f2f47ac 6ffe4f3 f2f47ac 6ffe4f3 231da5b 4c31c97 6ffe4f3 231da5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import pickle
from functools import cache
import polars as pl
from huggingface_hub import hf_hub_download
from utils.embed import embed
from utils.paths import DATA
@cache
def get_model():
    """Download and unpickle the news-classifier model, memoized.

    Fetches ``model/model.pickle`` from the ``opale-ai/news-classifier``
    Hugging Face repository at revision ``"main"`` and unpickles it.
    Because of ``functools.cache``, only the first call runs
    ``hf_hub_download`` and ``pickle.load``. Every later call returns that
    same object.

    Returns:
        The unpickled object. Callers in this module invoke ``.predict`` on it,
        so it presumably follows a scikit-learn-style estimator API — TODO
        confirm.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file. This trusts whoever can push to the remote repo, and
    # revision="main" is a moving target; consider pinning a commit hash.
    model_path = hf_hub_download(
        repo_id="opale-ai/news-classifier",
        filename="model/model.pickle",
        revision="main",
    )
    with open(model_path, "rb") as handle:
        model = pickle.load(handle)
    return model
def get_record(*, seed=None):
    """Return one randomly sampled row of ``DATA / "eval.csv"`` as a dict.

    Args:
        seed: Optional RNG seed forwarded to polars' ``DataFrame.sample``, so
            the same row can be drawn again (e.g. to re-inspect a
            misprediction). ``None`` (the default) keeps the original
            unseeded, non-deterministic draw.

    Returns:
        A ``{column_name: value}`` dict for the sampled row. Elsewhere in this
        module the record is read for ``is_news_article``, ``meta_title``,
        ``meta_description`` and ``content``.
    """
    df = pl.read_csv(DATA / "eval.csv")
    # named=True lets polars build the column -> value mapping directly,
    # equivalent to zipping df.columns with the row tuple.
    return df.sample(n=1, seed=seed).row(0, named=True)
def pred_record(
    rec, *, text_fields=("meta_title", "meta_description", "content")
):
    """Predict the label for a single record with the cached model.

    The selected text fields are concatenated, separated by a blank line.
    The result is embedded with ``embed`` as a batch of one and passed to
    the model's ``predict``.

    Args:
        rec: Mapping of column name -> value, such as ``get_record()`` returns.
        text_fields: Keys of ``rec`` to concatenate, in order. The default is
            the field list this function always used. Presumably it must match
            the text layout the model was trained on — TODO confirm before
            changing it.

    Returns:
        The single prediction the model produced for this record.

    Raises:
        KeyError: If ``rec`` lacks one of ``text_fields``.
        ValueError: If ``predict`` does not return exactly one prediction.
    """
    # Empty CSV cells are read by polars as None. str.join raises TypeError
    # on None, so a missing field contributes empty text instead.
    parts = (rec[k] for k in text_fields)
    text = "\n\n".join("" if part is None else part for part in parts)
    embeds = embed([text])
    # One input in, so unpacking enforces exactly one prediction out.
    (pred,) = get_model().predict(embeds)
    return pred
def main():
    """Sample one evaluation record; print its true and predicted labels."""
    sample = get_record()
    # Read the ground-truth label before running the model, as before.
    truth, guess = sample["is_news_article"], pred_record(sample)
    print(f"is news (real): {truth}")
    print(f"is news (pred): {guess}")


if __name__ == "__main__":
    main()
|