# Zero-Shot Text Classification using `facebook/bart-large-mnli`

This repository demonstrates how to use the [`facebook/bart-large-mnli`](https://huggingface.co/facebook/bart-large-mnli) model for **zero-shot text classification** based on **natural language inference (NLI)**.

We extend the base usage by:

- Using a labeled dataset for benchmarking
- Performing optional fine-tuning
- Quantizing the model to FP16
- Scoring model performance

---

## 📌 Model Description

- **Model:** `facebook/bart-large-mnli`
- **Type:** NLI-based zero-shot classifier
- **Architecture:** BART (Bidirectional and Auto-Regressive Transformers)
- **Usage:** Classifies text by scoring each candidate label, phrased as a hypothesis, for NLI entailment

---

## 📂 Dataset

We use the [`yahoo_answers_topics`](https://huggingface.co/datasets/yahoo_answers_topics) dataset from Hugging Face for evaluation. It contains questions categorized into 10 topics.

```python
from datasets import load_dataset

dataset = load_dataset("yahoo_answers_topics")
```

---

## 🧠 Zero-Shot Classification Logic

The model checks whether a text entails a hypothesis such as "This text is about sports." For each candidate label (e.g., "sports", "education", "health"), we convert the label into such a hypothesis and use the model's entailment score to rank the labels.

---

## ✅ Example: Inference with the Zero-Shot Pipeline

```python
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

sequence = "The team played well and won the championship."
labels = ["sports", "politics", "education", "technology"]

result = classifier(sequence, candidate_labels=labels)
print(result)
```

---

## 📊 Scoring / Evaluation

Evaluate zero-shot classification using accuracy (a top-k variant is sketched after the code block below):

```python
def evaluate_zero_shot(dataset, labels):
    correct = 0
    total = 0
    for example in dataset:
        # Score the question text against all candidate labels
        result = classifier(example["question_content"], candidate_labels=labels)
        predicted = result["labels"][0]      # highest-scoring label
        true = labels[example["topic"]]      # map the dataset's topic index to a label name
        correct += int(predicted == true)
        total += 1
    return correct / total

# Candidate labels in the same order as the dataset's topic indices
labels = ["Society & Culture", "Science & Mathematics", "Health",
          "Education & Reference", "Computers & Internet", "Sports",
          "Business & Finance", "Entertainment & Music",
          "Family & Relationships", "Politics & Government"]

acc = evaluate_zero_shot(dataset["test"].select(range(100)), labels)
print(f"Accuracy: {acc:.2%}")
```
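Top-k accuracy is mentioned above but not implemented; the sketch below shows one way to extend the same loop so a prediction counts as correct when the true label appears among the k highest-scoring labels. The function name `evaluate_zero_shot_topk` and the choice of `k=3` are illustrative, not part of the original code.

```python
def evaluate_zero_shot_topk(dataset, labels, k=3):
    """Top-k accuracy: correct if the true label is among the k best-scoring labels."""
    correct = 0
    total = 0
    for example in dataset:
        result = classifier(example["question_content"], candidate_labels=labels)
        top_k = result["labels"][:k]         # pipeline returns labels sorted by score
        true = labels[example["topic"]]
        correct += int(true in top_k)
        total += 1
    return correct / total

top3_acc = evaluate_zero_shot_topk(dataset["test"].select(range(100)), labels, k=3)
print(f"Top-3 accuracy: {top3_acc:.2%}")
```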
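---

## ⚡ FP16 Quantization

The introduction mentions quantizing the model to FP16. Below is a minimal sketch of one way to do this with the standard `transformers` API: load the weights in half precision via `torch_dtype=torch.float16` and build a zero-shot pipeline around the resulting model. It assumes a CUDA-capable GPU (`device=0`); the variable names are illustrative.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

model_name = "facebook/bart-large-mnli"

# Load the model weights directly in half precision (FP16)
fp16_model = AutoModelForSequenceClassification.from_pretrained(
    model_name, torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Zero-shot pipeline around the FP16 model (GPU assumed: device=0)
fp16_classifier = pipeline(
    "zero-shot-classification",
    model=fp16_model,
    tokenizer=tokenizer,
    device=0,
)

result = fp16_classifier(
    "The team played well and won the championship.",
    candidate_labels=["sports", "politics", "education", "technology"],
)
print(result)
```

Half precision roughly halves the memory footprint and typically speeds up GPU inference; scores may differ slightly from the FP32 model.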