|
import logging |
|
import pickle |
|
|
|
import matplotlib.pyplot as plt |
|
import polars as pl |
|
import seaborn as sns |
|
from numpy.typing import NDArray |
|
from sklearn.metrics import auc, confusion_matrix, roc_curve |
|
from sklearn.svm import SVC |
|
|
|
from utils.paths import DATA, IMGS, MODEL |
|
|
|
# Route INFO-and-above records through the root logger (basicConfig's default
# handler writes to stderr). NOTE: this runs at import time, so merely
# importing this module also configures process-wide logging.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
def save_roc_curve(clf, X: NDArray, y: NDArray, path=None) -> None:
    """Plot the ROC curve of a fitted binary classifier and save it as a PNG.

    Args:
        clf: A fitted classifier exposing ``predict_proba`` and ``classes_``
            (e.g. an ``SVC`` trained with ``probability=True``).
        X: Feature matrix to score.
        y: True labels for ``X``.
        path: Output file. Defaults to ``IMGS / "roc_curve.png"``. Missing
            parent directories are created.
    """
    from pathlib import Path

    # Column 1 of predict_proba is the probability of clf.classes_[1], so that
    # label is the positive class. Passing it explicitly keeps scores and
    # labels aligned and also works for labels other than {0, 1} / {-1, 1}
    # (e.g. strings), where roc_curve's default pos_label raises.
    pos_label = clf.classes_[1]
    scores = clf.predict_proba(X)[:, 1]
    fpr, tpr, _ = roc_curve(y, scores, pos_label=pos_label)
    roc_auc = auc(fpr, tpr)

    out = Path(IMGS / "roc_curve.png" if path is None else path)
    out.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        ax.plot(
            fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})"
        )
        # Diagonal reference line of a random classifier.
        ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title("Receiver Operating Characteristic (ROC)")
        ax.legend(loc="lower right")
        fig.tight_layout()
        fig.savefig(out)
    finally:
        # Release the figure even if saving fails, so open figures do not pile up.
        plt.close(fig)
|
|
|
|
|
def save_confusion_matrix(y: NDArray, pred: NDArray, path=None, labels=None) -> None:
    """Plot the confusion matrix of ``pred`` against ``y`` and save it as a PNG.

    Rows are actual labels and columns are predicted labels. The tick labels
    are fixed to "Not Relevant" / "Relevant", which assumes exactly two classes
    with the not-relevant one first in ``confusion_matrix``'s ordering.
    NOTE(review): confirm against the ``is_news`` label values.

    Args:
        y: True labels.
        pred: Predicted labels, same length as ``y``.
        path: Output file. Defaults to ``IMGS / "confusion_matrix.png"``.
            Missing parent directories are created.
        labels: Optional explicit label order forwarded to
            ``confusion_matrix``. Pass both classes (e.g. ``[0, 1]``) to keep a
            2x2 matrix even when ``y`` and ``pred`` contain only one class.
            Otherwise the matrix shrinks and no longer matches the two tick
            labels.
    """
    from pathlib import Path

    matrix = confusion_matrix(y, pred, labels=labels)

    out = Path(IMGS / "confusion_matrix.png" if path is None else path)
    out.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        sns.heatmap(
            matrix,
            annot=True,
            fmt="d",  # confusion_matrix returns integer counts
            cmap="Blues",
            xticklabels=["Not Relevant", "Relevant"],
            yticklabels=["Not Relevant", "Relevant"],
            ax=ax,
        )
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
        ax.set_title("Confusion Matrix")
        fig.tight_layout()
        fig.savefig(out)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
def _load_xy(path) -> "tuple[NDArray, NDArray]":
    """Read a parquet split and return its ``(embeds, is_news)`` columns as arrays."""
    df = pl.read_parquet(path)
    # NOTE(review): assumes "embeds" is a fixed-width Array column, which polars
    # converts to a 2-D (rows, dim) matrix. A variable-length List column would
    # become an object array that SVC cannot fit. Confirm the schema.
    return df.get_column("embeds").to_numpy(), df.get_column("is_news").to_numpy()


def main() -> None:
    """Train an SVC on the train split, pickle it, and plot eval diagnostics.

    Reads ``DATA / "train.parquet"`` and ``DATA / "eval.parquet"``, writes the
    fitted model to ``MODEL / "model.pickle"``, and saves a confusion matrix and
    an ROC curve for the eval split.
    """
    log = logging.getLogger(__name__)

    train_X, train_y = _load_xy(DATA / "train.parquet")
    log.info("Training SVC on %d samples", len(train_y))
    # probability=True fits Platt scaling using an internal cross-validation
    # whose shuffling is random. A fixed random_state makes the pickled model
    # and the ROC curve reproducible across runs.
    clf = SVC(kernel="poly", probability=True, random_state=0)
    clf.fit(train_X, train_y)

    model_path = MODEL / "model.pickle"
    with open(model_path, "wb") as f:
        pickle.dump(clf, f)
    log.info("Saved model to %s", model_path)

    eval_X, eval_y = _load_xy(DATA / "eval.parquet")
    eval_pred = clf.predict(eval_X)
    log.info("Eval accuracy on %d samples: %.4f", len(eval_y), (eval_pred == eval_y).mean())

    save_confusion_matrix(eval_y, eval_pred)
    save_roc_curve(clf, eval_X, eval_y)
|
|
|
|
|
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":

    main()
|
|