news-classifier / train.py
Ali Kefia
linear -> poly
f2f47ac
import logging
import pickle
import matplotlib.pyplot as plt
import polars as pl
import seaborn as sns
from numpy.typing import NDArray
from sklearn.metrics import auc, confusion_matrix, roc_curve
from sklearn.svm import SVC
from utils.paths import DATA, IMGS, MODEL
# Configure the root logger once at import time so INFO-level (and above)
# records emitted by this script are shown.
logging.basicConfig(level=logging.INFO)
def save_roc_curve(clf, X: NDArray, y: NDArray):
    """Plot the ROC curve of ``clf`` on ``(X, y)`` and save it as ``roc_curve.png``.

    Args:
        clf: A fitted binary classifier exposing ``predict_proba`` and
            ``classes_`` (e.g. ``sklearn.svm.SVC(probability=True)``).
        X: Feature matrix to score.
        y: True labels for ``X``.

    The image is written to ``IMGS / "roc_curve.png"``.
    """
    # predict_proba orders its columns like clf.classes_, so column 1 is the
    # probability of clf.classes_[1]. Use that same label as pos_label so the
    # curve is correct even when labels are not 0/1 (e.g. strings).
    pos_label = clf.classes_[1]
    probs = clf.predict_proba(X)[:, 1]
    fpr, tpr, _ = roc_curve(y, probs, pos_label=pos_label)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        ax.plot(
            fpr,
            tpr,
            color="darkorange",
            lw=2,
            label=f"ROC curve (AUC = {roc_auc:.2f})",
        )
        # Diagonal reference line: the ROC curve of a random classifier.
        ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title("Receiver Operating Characteristic (ROC)")
        ax.legend(loc="lower right")
        fig.tight_layout()
        fig.savefig(IMGS / "roc_curve.png")
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
def save_confusion_matrix(y: NDArray, pred: NDArray, labels: "NDArray | None" = None):
    """Plot the confusion matrix of ``pred`` against ``y`` and save it as ``confusion_matrix.png``.

    Args:
        y: True labels.
        pred: Predicted labels, aligned element-wise with ``y``.
        labels: Optional label values (negative first, positive second) that
            fix the row/column order of the matrix. When omitted, sklearn uses
            the sorted union of the values found in ``y`` and ``pred``.

    Raises:
        ValueError: If the resulting matrix is not 2x2. The hard-coded
            "Not Relevant"/"Relevant" tick labels would otherwise mislabel it,
            e.g. when only one class occurs in ``y`` and ``pred``.

    The image is written to ``IMGS / "confusion_matrix.png"``.
    """
    cm = confusion_matrix(y, pred, labels=labels)
    if cm.shape != (2, 2):
        raise ValueError(
            f"expected a 2x2 confusion matrix for binary labels, got shape {cm.shape}; "
            "pass labels=[negative, positive] to fix the class order"
        )

    tick_labels = ["Not Relevant", "Relevant"]
    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        # sklearn puts true labels on rows and predictions on columns.
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            ax=ax,
        )
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
        ax.set_title("Confusion Matrix")
        fig.tight_layout()
        fig.savefig(IMGS / "confusion_matrix.png")
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
def _load_xy(path):
    """Read a parquet split and return its ``(embeds, is_news)`` columns as arrays."""
    # Only the feature and label columns are needed, so skip reading the rest.
    df = pl.read_parquet(path, columns=["embeds", "is_news"])
    # NOTE(review): assumes "embeds" converts via to_numpy() into a 2-D numeric
    # matrix (e.g. a fixed-width Array column) that SVC can fit on - confirm.
    return df.get_column("embeds").to_numpy(), df.get_column("is_news").to_numpy()


def main() -> None:
    """Train the SVC news classifier, pickle it, and save evaluation plots.

    Reads ``DATA/train.parquet`` and ``DATA/eval.parquet``, writes the fitted
    model to ``MODEL/model.pickle``, and writes the confusion-matrix and ROC
    images via ``save_confusion_matrix`` / ``save_roc_curve``.
    """
    logger = logging.getLogger(__name__)

    train_X, train_y = _load_xy(DATA / "train.parquet")
    logger.info("Training SVC on %d samples", len(train_y))
    # probability=True enables predict_proba (needed for the ROC curve).
    # random_state makes the internal Platt-scaling CV reproducible.
    clf = SVC(kernel="poly", probability=True, random_state=0)
    clf.fit(train_X, train_y)

    model_path = MODEL / "model.pickle"
    # Pickle files can execute code on load: only unpickle from a trusted source.
    with open(model_path, "wb") as f:
        pickle.dump(clf, f)
    logger.info("Saved model to %s", model_path)

    eval_X, eval_y = _load_xy(DATA / "eval.parquet")
    eval_pred = clf.predict(eval_X)
    logger.info(
        "Eval accuracy: %.4f on %d samples",
        (eval_pred == eval_y).mean(),
        len(eval_y),
    )
    save_confusion_matrix(eval_y, eval_pred)
    save_roc_curve(clf, eval_X, eval_y)
    logger.info("Saved evaluation plots to %s", IMGS)
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()