|
import logging |
|
|
|
import polars as pl |
|
|
|
from utils.embed import embed as embed |
|
from utils.paths import DATA |
|
|
|
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)



# Configure the root logger at import time so INFO messages are emitted
# when this file runs as a script. NOTE(review): calling basicConfig at
# module import also affects importers of this module — confirm intended.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
def load_dataset(file_name: str):
    """Read a page-metadata CSV and return a materialized Polars DataFrame.

    The title, description and body columns are concatenated (separated by
    blank lines) into a single ``text`` column, the ``date`` column is parsed
    from string to date, and the count columns are renamed to shorter names.

    Args:
        file_name: Path to the CSV file to load.

    Returns:
        A DataFrame with columns: text, is_news, url, date, paragraphs, links.
    """
    text_columns = ("meta_title", "meta_description", "content")

    # Build the query lazily so polars can push projections down to the scan.
    query = pl.scan_csv(file_name)
    query = query.with_columns(
        pl.concat_str(
            [pl.col(name) for name in text_columns], separator="\n\n"
        ).alias("text"),
        pl.col("date").str.to_date().alias("date"),
    )
    query = query.rename(
        {
            "is_news_article": "is_news",
            "link_count": "links",
            "paragraph_count": "paragraphs",
        }
    )
    query = query.select("text", "is_news", "url", "date", "paragraphs", "links")

    # Execute the lazy plan and return the concrete DataFrame.
    return query.collect()
|
|
|
|
|
def main() -> None:
    """Embed the text of each dataset split and persist it as Parquet.

    For each split ("train", "eval"): loads ``<split>.csv`` from DATA,
    embeds the concatenated ``text`` column, attaches the embeddings as a
    new ``embeds`` column, and writes the result to ``<split>.parquet``.
    """
    for name in ["train", "eval"]:
        logger.info("Embedding split %s", name)
        df = load_dataset(DATA / (name + ".csv"))
        embeds = embed(df.get_column("text").to_list())
        # Bug fix: the original assigned the result of write_parquet back to
        # `df`, but write_parquet returns None — the write is a side effect,
        # so call it as a statement instead of a (dead) assignment.
        df.with_columns(pl.Series(embeds).alias("embeds")).write_parquet(
            DATA / (name + ".parquet")
        )
|
|
|
|
|
# Script entry point: run the embedding pipeline only when executed
# directly, not when imported as a module.
if __name__ == "__main__":

    main()
|
|