File size: 1,152 Bytes
4c31c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import logging

import polars as pl

from utils.embed import embed as embed
from utils.paths import DATA

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Configure the root logger to emit INFO and above. This call sits at module
# top level, so it runs whenever the module is imported, not only when the
# file is executed as a script.
logging.basicConfig(level=logging.INFO)


def load_dataset(file_name: str):
    """Read a CSV file and return the columns used downstream.

    The CSV is scanned lazily, transformed, and then collected, so the
    function returns a materialised frame rather than a lazy one.

    Transformations applied, in order:

    * a ``text`` column is built by joining ``meta_title``,
      ``meta_description`` and ``content`` with a blank line between them;
    * the ``date`` column is parsed from strings into dates;
    * ``is_news_article``, ``link_count`` and ``paragraph_count`` are renamed
      to ``is_news``, ``links`` and ``paragraphs``;
    * only ``text``, ``is_news``, ``url``, ``date``, ``paragraphs`` and
      ``links`` are kept, in that order.

    NOTE(review): ``concat_str`` is called with its default null handling,
    which presumably yields a null ``text`` when any of the three source
    columns is null -- confirm against the polars version in use.
    NOTE(review): the parameter is annotated ``str`` but ``main`` passes the
    result of ``DATA / ...`` -- confirm the intended type.

    Args:
        file_name: Location of the CSV file handed to ``pl.scan_csv``.

    Returns:
        The collected result of the lazy query described above.
    """
    text_sources = ("meta_title", "meta_description", "content")

    # Expressions evaluated together in a single ``with_columns`` call.
    combined_text = pl.concat_str(
        [pl.col(column) for column in text_sources], separator="\n\n"
    ).alias("text")
    parsed_date = pl.col("date").str.to_date().alias("date")

    short_names = {
        "is_news_article": "is_news",
        "link_count": "links",
        "paragraph_count": "paragraphs",
    }
    kept_columns = ("text", "is_news", "url", "date", "paragraphs", "links")

    query = pl.scan_csv(file_name)
    query = query.with_columns(combined_text, parsed_date)
    query = query.rename(short_names)
    query = query.select(*kept_columns)
    return query.collect()


def main() -> None:
    """Embed the ``train`` and ``eval`` datasets and save each as Parquet.

    For each split name, this loads ``<name>.csv`` from ``DATA`` through
    ``load_dataset``, passes the ``text`` column (as a Python list) to
    ``embed``, attaches the result as an ``embeds`` column, and writes the
    frame to ``<name>.parquet`` in ``DATA``.
    """
    for name in ("train", "eval"):
        frame = load_dataset(DATA / f"{name}.csv")
        embeds = embed(frame.get_column("text").to_list())
        # ``write_parquet`` returns None, so its result is not bound to a
        # name; the previous code rebound the frame variable to that None.
        frame.with_columns(pl.Series(embeds).alias("embeds")).write_parquet(
            DATA / f"{name}.parquet"
        )


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()