|
import pickle |
|
from functools import cache |
|
|
|
import polars as pl |
|
from huggingface_hub import hf_hub_download |
|
|
|
from utils.embed import embed |
|
from utils.paths import DATA |
|
|
|
|
|
@cache
def get_model():
    """Download the news-classifier model from the Hugging Face Hub and unpickle it.

    Memoized with ``functools.cache``: the download and unpickling happen at
    most once per process, and every later call returns the same object.

    Returns:
        The unpickled model object (``pred_record`` calls ``.predict`` on it).
    """
    # hf_hub_download returns the local filesystem path of the fetched file.
    file_name = hf_hub_download(
        repo_id="opale-ai/news-classifier",
        filename="model/model.pickle",
        revision="main",
    )
    # SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in
    # the file. The file comes from a remote repository and revision="main" is
    # a moving branch, so anyone able to push to that repo controls code that
    # runs here. Consider pinning `revision` to a specific commit hash.
    with open(file_name, "rb") as f:
        return pickle.load(f)
|
|
|
|
|
def get_record(seed=None):
    """Return one randomly sampled row of the evaluation CSV as a dict.

    Reads ``DATA / "eval.csv"`` on every call and samples a single row.

    Args:
        seed: Optional RNG seed forwarded to polars' ``DataFrame.sample`` so
            the chosen row is reproducible. With ``None`` (the default) a
            different, unpredictable row may be chosen on each call.

    Returns:
        A dict mapping each column name to its value in the sampled row.

    Raises:
        ValueError: If the CSV has no data rows to sample from.
    """
    path = DATA / "eval.csv"
    df = pl.read_csv(path)
    if df.is_empty():
        # Fail with a clear message instead of polars' sampling-size error.
        raise ValueError(f"no rows to sample in {path}")
    # named=True makes polars build the {column: value} dict directly.
    return df.sample(seed=seed).row(0, named=True)
|
|
|
|
|
def pred_record(rec):
    """Predict whether a record is a news article.

    The ``meta_title``, ``meta_description`` and ``content`` fields are joined
    with blank lines, embedded via ``embed``, and fed to the cached model.

    Args:
        rec: Mapping that must contain the three text fields above (for
            example a row returned by ``get_record``). Fields whose value is
            ``None`` (e.g. empty CSV cells) are treated as empty strings.

    Returns:
        The model's single prediction for this record (presumably comparable
        to the ``is_news_article`` label; confirm against the model).
    """
    text_fields = ("meta_title", "meta_description", "content")
    # CSV nulls arrive as None and str.join would raise TypeError on them;
    # substitute "" so field order and separators stay unchanged.
    text = "\n\n".join("" if rec[k] is None else rec[k] for k in text_fields)
    embeds = embed([text])
    # predict receives a batch of one, so unpack exactly one prediction.
    (pred,) = get_model().predict(embeds)
    return pred
|
|
|
|
|
def main():
    """Sample one evaluation record and print its label beside the model's prediction."""
    sample = get_record()
    actual = sample["is_news_article"]
    predicted = pred_record(sample)
    # One print call, two lines: ground truth first, then the prediction.
    print(
        f"is news (real): {actual}",
        f"is news (pred): {predicted}",
        sep="\n",
    )
|
|
|
|
|
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|