|
# Zero-Shot Text Classification using `facebook/bart-large-mnli` |
|
|
|
This repository demonstrates how to use the [`facebook/bart-large-mnli`](https://huggingface.co/facebook/bart-large-mnli) model for **zero-shot text classification** based on **natural language inference (NLI)**. |
|
|
|
We extend the base usage by: |
|
- Using a labeled dataset for benchmarking |
|
- Performing optional fine-tuning |
|
- Quantizing the model to FP16 |
|
- Scoring model performance |
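The FP16 step amounts to casting the model's weights to half precision with PyTorch's `Module.half()`. A minimal sketch of the pattern, shown on a small stand-in module rather than the full BART checkpoint (substitute the loaded `facebook/bart-large-mnli` model in practice):

```python
import torch

# Stand-in for the loaded model; the same call applies to the object returned by
# AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli").
model = torch.nn.Linear(16, 4)

# Cast all parameters and buffers to float16.
fp16_model = model.half()

for name, param in fp16_model.named_parameters():
    print(name, param.dtype)
```

Note that FP16 inference is typically only faster on GPU (`model.half().to("cuda")`); on CPU it mainly saves memory.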
|
|
|
--- |
|
|
|
## Model Description
|
|
|
- **Model:** `facebook/bart-large-mnli` |
|
- **Type:** NLI-based zero-shot classifier |
|
- **Architecture:** BART (Bidirectional and Auto-Regressive Transformers) |
|
- **Usage:** Classifies text by scoring label hypotheses as NLI entailment |
|
|
|
--- |
|
|
|
## Dataset
|
|
|
We use the [`yahoo_answers_topics`](https://huggingface.co/datasets/yahoo_answers_topics) dataset from Hugging Face for evaluation. It contains questions categorized into 10 topics. |
|
|
|
```python
from datasets import load_dataset

dataset = load_dataset("yahoo_answers_topics")
```
|
|
|
## Zero-Shot Classification Logic
|
The model checks whether a text entails a hypothesis like: |
|
|
|
"This text is about sports." |
|
|
|
Each candidate label (e.g., "sports", "education", "health") is converted into such a hypothesis, and the model scores how strongly the text entails it.
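As a minimal sketch, converting labels into hypotheses is plain string formatting (the template below follows the example above; the pipeline's own default template is `"This example is {}."`, configurable via `hypothesis_template`):

```python
# Build one NLI hypothesis per candidate label.
template = "This text is about {}."
labels = ["sports", "education", "health"]

# Each (premise, hypothesis) pair is fed to the NLI model; the entailment
# score for the hypothesis becomes the score for that label.
hypotheses = [template.format(label) for label in labels]
print(hypotheses)
```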
|
|
|
## Example: Inference with Zero-Shot Pipeline
|
```python
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

sequence = "The team played well and won the championship."
labels = ["sports", "politics", "education", "technology"]

result = classifier(sequence, candidate_labels=labels)
print(result)
```
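In single-label mode, the pipeline scores each label's hypothesis for entailment and normalizes the entailment logits across labels with a softmax. A sketch of that final step, using made-up logits (the numbers below are illustrative, not real model outputs):

```python
import math

# Hypothetical entailment logits, one per candidate label, in the order
# ["sports", "politics", "education", "technology"].
entailment_logits = [4.2, -1.3, -0.7, 0.1]

# Softmax across labels turns the logits into scores that sum to 1.
exps = [math.exp(x) for x in entailment_logits]
scores = [e / sum(exps) for e in exps]

print([round(s, 3) for s in scores])  # the highest score picks the label
```

With `multi_label=True` the pipeline instead normalizes each label independently (entailment vs. contradiction), so scores no longer sum to 1.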
|
|
|
## Scoring / Evaluation
|
Evaluate zero-shot classification using accuracy or top-k accuracy: |
|
|
|
```python
def evaluate_zero_shot(dataset, labels):
    correct = 0
    total = 0
    for example in dataset:
        result = classifier(example["question_content"], candidate_labels=labels)
        predicted = result["labels"][0]
        true = labels[example["topic"]]
        correct += int(predicted == true)
        total += 1
    return correct / total

labels = ["Society & Culture", "Science & Mathematics", "Health",
          "Education & Reference", "Computers & Internet", "Sports",
          "Business & Finance", "Entertainment & Music",
          "Family & Relationships", "Politics & Government"]

acc = evaluate_zero_shot(dataset["test"].select(range(100)), labels)
print(f"Accuracy: {acc:.2%}")
```
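Top-k accuracy counts a prediction as correct when the true label appears anywhere among the model's k highest-ranked labels, which is useful when nearby topics (e.g., "Education & Reference" vs. "Science & Mathematics") are easily confused. A sketch on toy ranked output (the ranked list below is made up; in practice it would be `result["labels"]`, which the pipeline returns sorted best-first):

```python
def top_k_correct(ranked_labels, true_label, k):
    # ranked_labels: candidate labels sorted from highest to lowest score.
    return true_label in ranked_labels[:k]

ranked = ["Sports", "Entertainment & Music", "Health"]
print(top_k_correct(ranked, "Health", k=1))  # top-1 miss
print(top_k_correct(ranked, "Health", k=3))  # top-3 hit
```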