Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.metrics import confusion_matrix, classification_report, precision_recall_curve | |
| from sklearn.preprocessing import label_binarize | |
| from transformers import BertTokenizer, BertForSequenceClassification | |
| from datasets import load_dataset | |
# Evaluation setup: choose the compute device, load the CLINC OOS data, and
# restore the saved BERT intent classifier together with its tokenizer.
PRETRAINED_NAME = "bert-base-uncased"
WEIGHTS_PATH = "intent_classifier.pth"

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Intent names come from the train split's feature metadata, so position i in
# label_names is the name of integer label i.
dataset = load_dataset("clinc_oos", "plus")
label_names = dataset["train"].features["intent"].names
num_labels = len(label_names)

# Build the architecture with a classification head of the right width, then
# overwrite its parameters with the locally saved state dict.
model = BertForSequenceClassification.from_pretrained(
    PRETRAINED_NAME,
    num_labels=num_labels,
)
# NOTE(review): torch.load unpickles the file, so WEIGHTS_PATH must come from
# a trusted source.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
model.to(device).eval()  # eval mode disables dropout for deterministic scoring

tokenizer = BertTokenizer.from_pretrained(PRETRAINED_NAME)
# Score every test example and collect the arrays used by the plots below.
#
# Produces:
#   true_labels: 1-D integer array with the gold intent id of each example.
#   pred_labels: 1-D integer array with the argmax prediction of each example.
#   all_probs:   2-D float array of softmax probabilities, one row per example
#                and one column per intent id.
#
# Examples are scored in mini-batches and padded only to the longest sentence
# in each batch, instead of one sentence at a time padded to 128 tokens. The
# tokenizer's attention mask is forwarded to the model, so predictions should
# match the per-example version up to floating-point rounding.
test_split = dataset["test"]
batch_size = 32

true_labels = []
prob_batches = []
with torch.no_grad():
    for start in range(0, len(test_split), batch_size):
        # Slicing a split returns a dict of column lists for the batch.
        batch = test_split[start:start + batch_size]
        true_labels.extend(batch["intent"])

        inputs = tokenizer(
            batch["text"],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        )
        inputs = {key: val.to(device) for key, val in inputs.items()}

        logits = model(**inputs).logits
        prob_batches.append(torch.softmax(logits, dim=1).cpu().numpy())

# Convert to numpy arrays
true_labels = np.array(true_labels)
if prob_batches:
    all_probs = np.concatenate(prob_batches, axis=0)
else:
    # Empty test split: keep the (n_examples, n_classes) layout.
    all_probs = np.empty((0, num_labels), dtype=np.float32)
pred_labels = all_probs.argmax(axis=1)
# Compute the confusion matrix and save it as a heatmap image.
#
# Passing the full label set fixes the matrix at num_labels x num_labels with
# row/column i always meaning intent id i. Without it, scikit-learn only uses
# the ids that occur in true_labels or pred_labels, so a class missing from
# both would shrink the matrix and shift every later row and column.
conf_matrix = confusion_matrix(
    true_labels,
    pred_labels,
    labels=np.arange(num_labels),
)

# Plot confusion matrix (cell annotations are off, so only colour is drawn).
plt.figure(figsize=(12, 10))
sns.heatmap(conf_matrix, annot=False, fmt="d", cmap="Blues")
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix for Intent Classification")
plt.savefig("confusion_matrix.png", dpi=300, bbox_inches="tight")
plt.close()  # release the figure before the next plot is created
print("Confusion matrix saved as confusion_matrix.png")
# --- Multi-Class Precision-Recall Curve ---
# One-vs-rest indicator matrix: column i marks the examples whose true label
# is intent id i, which is the binary target precision_recall_curve needs.
true_labels_bin = label_binarize(true_labels, classes=np.arange(num_labels))

# Draw one precision-recall curve per intent on a shared set of axes, using
# that intent's probability column as the score.
pr_fig, pr_ax = plt.subplots(figsize=(10, 8))
for class_idx, class_name in enumerate(label_names):
    precision, recall, _ = precision_recall_curve(
        true_labels_bin[:, class_idx],
        all_probs[:, class_idx],
    )
    pr_ax.plot(
        recall,
        precision,
        lw=1,
        alpha=0.7,
        label=f"Class {class_idx}: {class_name}",
    )

pr_ax.set_xlabel("Recall")
pr_ax.set_ylabel("Precision")
pr_ax.set_title("Multi-Class Precision-Recall Curve")
pr_ax.legend(loc="best", fontsize=6, ncol=2, frameon=True)
pr_ax.grid(True)

pr_fig.savefig("precision_recall_curve.png", dpi=300, bbox_inches="tight")
plt.close(pr_fig)
print("Precision-Recall curve saved as precision_recall_curve.png")