File size: 2,132 Bytes
6505378
 
 
 
 
 
 
 
 
 
4c31c97
6505378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c31c97
6505378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c31c97
6505378
 
 
 
4c31c97
f2f47ac
4c31c97
 
 
 
6505378
 
 
4c31c97
 
 
6505378
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import logging
import pickle

import matplotlib.pyplot as plt
import polars as pl
import seaborn as sns
from numpy.typing import NDArray
from sklearn.metrics import auc, confusion_matrix, roc_curve
from sklearn.svm import SVC

from utils.paths import DATA, IMGS, MODEL

# Configure the root logger so INFO-level messages are emitted.
# NOTE(review): this executes whenever the module is imported, not only when
# run as a script; consider moving it under the __main__ guard if this module
# is ever imported as a library.
logging.basicConfig(level=logging.INFO)


def save_roc_curve(clf, X: NDArray, y: NDArray):
    """Plot the ROC curve of ``clf`` on ``(X, y)`` and save it to ``IMGS/roc_curve.png``.

    Args:
        clf: A fitted binary classifier exposing ``predict_proba``. Column 1 of
            its output (the second entry of ``clf.classes_``) is used as the
            positive-class score.
        X: Feature matrix to score.
        y: Ground-truth binary labels aligned with ``X``.
    """
    scores = clf.predict_proba(X)[:, 1]  # probability of the positive class
    fpr, tpr, _ = roc_curve(y, scores)  # thresholds are not plotted
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        ax.plot(
            fpr,
            tpr,
            color="darkorange",
            lw=2,
            label=f"ROC curve (AUC = {roc_auc:.2f})",
        )
        # Diagonal reference line: the curve of a random classifier.
        ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
        ax.set_xlim(0.0, 1.0)
        ax.set_ylim(0.0, 1.05)
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title("Receiver Operating Characteristic (ROC)")
        ax.legend(loc="lower right")
        fig.tight_layout()
        fig.savefig(IMGS / "roc_curve.png")
    finally:
        # Always release the figure, even if plotting/saving fails, so open
        # figures don't accumulate in pyplot's global state.
        plt.close(fig)


def save_confusion_matrix(y: NDArray, pred: NDArray):
    """Render the confusion matrix of ``pred`` against ``y`` to ``IMGS/confusion_matrix.png``.

    Rows are actual labels and columns are predicted labels. Both are in
    sorted label order, which the tick names assume to be
    (not relevant, relevant).

    Args:
        y: Ground-truth labels.
        pred: Predicted labels aligned with ``y``.
    """
    class_names = ["Not Relevant", "Relevant"]
    # NOTE(review): labels are inferred from y/pred; if only one class appears
    # in both, the matrix is 1x1 and won't line up with the two tick names.
    matrix = confusion_matrix(y, pred)

    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        sns.heatmap(
            matrix,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=class_names,
            yticklabels=class_names,
            ax=ax,
        )
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
        ax.set_title("Confusion Matrix")
        fig.tight_layout()
        fig.savefig(IMGS / "confusion_matrix.png")
    finally:
        # Always release the figure, even if rendering/saving fails.
        plt.close(fig)


# Seed for SVC's internal probability-calibration shuffling (see main()).
_RANDOM_STATE = 42


def _load_xy(path):
    """Read a parquet split and return its (``embeds``, ``is_news``) columns as numpy arrays."""
    df = pl.read_parquet(path)
    return df.get_column("embeds").to_numpy(), df.get_column("is_news").to_numpy()


def main() -> None:
    """Train an SVC, pickle it, and save evaluation plots.

    Reads ``train.parquet`` and ``eval.parquet`` from ``DATA``. Fits an SVC on
    the ``embeds`` column against the ``is_news`` label and writes the fitted
    model to ``MODEL/model.pickle``. Then saves a confusion matrix and a ROC
    curve computed on the eval split.
    """
    logger = logging.getLogger(__name__)

    train_X, train_y = _load_xy(DATA / "train.parquet")
    # probability=True makes SVC fit its probability model with an internal
    # cross-validation that shuffles the data; pinning random_state makes
    # predict_proba (and therefore the saved ROC curve) reproducible.
    clf = SVC(kernel="poly", probability=True, random_state=_RANDOM_STATE)
    logger.info("Fitting SVC on %d training samples", len(train_y))
    clf.fit(train_X, train_y)

    model_path = MODEL / "model.pickle"
    with open(model_path, "wb") as f:
        pickle.dump(clf, f)
    logger.info("Saved model to %s", model_path)

    eval_X, eval_y = _load_xy(DATA / "eval.parquet")
    eval_pred = clf.predict(eval_X)
    save_confusion_matrix(eval_y, eval_pred)
    save_roc_curve(clf, eval_X, eval_y)
    logger.info("Saved evaluation plots for %d eval samples", len(eval_y))


# Run the train/evaluate pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()