news-classifier / prepare.py
Ali Kefia
ok
4c31c97
import logging
from pathlib import Path

import polars as pl

from utils.embed import embed as embed
from utils.paths import DATA
# Root logging emits INFO and above; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def load_dataset(file_name: "str | Path") -> "pl.DataFrame":
    """Load a raw article CSV and shape it for embedding.

    The CSV is scanned lazily. ``meta_title``, ``meta_description`` and
    ``content`` are joined (blank-line separated) into a single ``text``
    column, and ``date`` is parsed from string to a Date. Then
    ``is_news_article``, ``link_count`` and ``paragraph_count`` are renamed
    to ``is_news``, ``links`` and ``paragraphs``.

    Args:
        file_name: Path to the CSV file (a ``str`` or ``pathlib.Path``).

    Returns:
        An eager DataFrame with columns
        ``text, is_news, url, date, paragraphs, links`` in that order.
    """
    features = ("meta_title", "meta_description", "content")
    return (
        pl.scan_csv(file_name)
        .with_columns(
            # concat_str yields null for the whole row if ANY input is null,
            # which would hand None to the embedder. Treat a missing field as
            # empty text instead (cast first so fill_null("") is type-safe).
            pl.concat_str(
                [pl.col(c).cast(pl.Utf8).fill_null("") for c in features],
                separator="\n\n",
            ).alias("text"),
            # Format is inferred by polars; mixed date formats will raise.
            pl.col("date").str.to_date(),
        )
        .rename(
            {
                "is_news_article": "is_news",
                "link_count": "links",
                "paragraph_count": "paragraphs",
            }
        )
        .select("text", "is_news", "url", "date", "paragraphs", "links")
        .collect()
    )
def main() -> None:
    """Embed the train and eval splits and save each one as Parquet.

    For each split, this reads ``DATA/<split>.csv`` via ``load_dataset`` and
    embeds its ``text`` column with ``embed``. It adds the result as an
    ``embeds`` column and writes ``DATA/<split>.parquet``.
    """
    for name in ("train", "eval"):
        csv_path = DATA / f"{name}.csv"
        parquet_path = DATA / f"{name}.parquet"

        df = load_dataset(csv_path)
        logger.info("Embedding %d rows from %s", df.height, csv_path)
        # NOTE(review): assumes embed() returns exactly one embedding per
        # input text, in input order; TODO confirm against utils.embed.
        embeds = embed(df.get_column("text").to_list())

        # write_parquet returns None, so its result must not be bound back to df.
        df.with_columns(pl.Series("embeds", embeds)).write_parquet(parquet_path)
        logger.info("Wrote %s", parquet_path)


if __name__ == "__main__":
    main()