# Zero-Shot Text Classification using `facebook/bart-large-mnli`

This repository demonstrates how to use the [`facebook/bart-large-mnli`](https://huggingface.co/facebook/bart-large-mnli) model for **zero-shot text classification** based on **natural language inference (NLI)**.

We extend the base usage by:
- Using a labeled dataset for benchmarking
- Performing optional fine-tuning
- Quantizing the model to FP16
- Scoring model performance

---

## 📌 Model Description

- **Model:** `facebook/bart-large-mnli`
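- **Type:** `bart-large` fine-tuned on MultiNLI (MNLI), used as an NLI-based zero-shot classifier
- **Architecture:** BART (Bidirectional and Auto-Regressive Transformers)
- **Usage:** Scores each candidate label by how strongly the input text (the premise) entails a hypothesis built from that label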

---

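## 📂 Dataset

We use the [`yahoo_answers_topics`](https://huggingface.co/datasets/yahoo_answers_topics) dataset from Hugging Face for evaluation. It contains Yahoo! Answers questions, each with a `question_title`, `question_content`, `best_answer`, and an integer `topic` label for one of 10 topics.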

```python
from datasets import load_dataset

dataset = load_dataset("yahoo_answers_topics")
```

## 🧠 Zero-Shot Classification Logic
The model checks whether a text entails a hypothesis like:

"This text is about sports."

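Each candidate label (e.g., "sports", "education", "health") is inserted into a hypothesis template, the input text serves as the NLI premise, and the model scores how strongly the premise entails each resulting hypothesis. The Transformers pipeline uses the template `"This example is {}."` by default; pass `hypothesis_template="This text is about {}."` to match the wording above.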
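To make the scoring step concrete, here is a minimal pure-Python sketch of how per-hypothesis NLI logits can be turned into label scores. It assumes the `[contradiction, neutral, entailment]` logit order of the `facebook/bart-large-mnli` config, and the logits are invented numbers, not model output. In single-label mode the entailment logits are softmaxed across all labels; in multi-label mode (and, in the Transformers pipeline, when only one label is given) each label gets an independent softmax over contradiction vs. entailment. The helper names are hypothetical.

```python
import math

# Assumed logit order for facebook/bart-large-mnli: [contradiction, neutral, entailment]
CONTRADICTION, ENTAILMENT = 0, 2

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def build_hypotheses(labels, template="This example is {}."):
    # One hypothesis per candidate label, e.g. "This example is sports."
    return [template.format(label) for label in labels]

def zero_shot_scores(logits_per_label, multi_label=False):
    """Turn NLI logits (one row per hypothesis) into one score per candidate label."""
    if multi_label or len(logits_per_label) == 1:
        # Independent per label: softmax over [contradiction, entailment], keep P(entailment)
        return [softmax([row[CONTRADICTION], row[ENTAILMENT]])[1] for row in logits_per_label]
    # Single-label: softmax the entailment logits across all candidates (scores sum to 1)
    return softmax([row[ENTAILMENT] for row in logits_per_label])

labels = ["sports", "politics", "education"]
logits = [[-2.0, 0.5, 3.0],   # hypothetical logits for "This example is sports."
          [1.5, 0.8, -1.0],   # hypothetical logits for "This example is politics."
          [0.9, 1.1, -0.5]]   # hypothetical logits for "This example is education."

print(build_hypotheses(labels))
scores = zero_shot_scores(logits)
print(sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True))
```

Single-label mode suits mutually exclusive topics such as the Yahoo categories; multi-label mode suits texts that may match several labels at once.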

## ✅ Example: Inference with the Zero-Shot Pipeline
```python
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

sequence = "The team played well and won the championship."
labels = ["sports", "politics", "education", "technology"]

result = classifier(sequence, candidate_labels=labels)
print(result)
```
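Printing `result` shows a dict with `sequence`, `labels`, and `scores` keys, where `labels` is sorted by descending score. In the default single-label mode the scores sum to 1; with `multi_label=True` each label is scored independently. A minimal sketch of reading that dict, using hypothetical helper names and invented scores rather than real model output:

```python
# Hypothetical output shaped like the pipeline's result dict (scores are invented)
result = {
    "sequence": "The team played well and won the championship.",
    "labels": ["sports", "education", "technology", "politics"],
    "scores": [0.91, 0.04, 0.03, 0.02],
}

def top_prediction(result):
    """Return (label, score) for the highest-scoring candidate label."""
    # Labels arrive sorted by descending score, so index 0 is the prediction
    return result["labels"][0], result["scores"][0]

def scores_by_label(result):
    """Map each candidate label to its score, regardless of sort order."""
    return dict(zip(result["labels"], result["scores"]))

label, score = top_prediction(result)
print(f"{label}: {score:.2%}")
print(scores_by_label(result)["politics"])
```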

## 📊 Scoring / Evaluation
Evaluate zero-shot classification using accuracy or top-k accuracy:

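```python
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
dataset = load_dataset("yahoo_answers_topics")

# Candidate labels, ordered to match the dataset's integer "topic" ids (0-9)
labels = ["Society & Culture", "Science & Mathematics", "Health", "Education",
          "Computers & Internet", "Sports", "Business & Finance", "Entertainment & Music",
          "Family & Relationships", "Politics & Government"]

def evaluate_zero_shot(split, labels):
    y_true, y_pred = [], []
    for example in split:
        # question_content can be empty, so prepend the question title
        text = f'{example["question_title"]} {example["question_content"]}'.strip()
        result = classifier(text, candidate_labels=labels)
        y_pred.append(result["labels"][0])       # labels come sorted by score, best first
        y_true.append(labels[example["topic"]])  # map the integer topic id to its name
    return accuracy_score(y_true, y_pred)

# Score a 100-example slice of the test split (zero-shot inference is slow)
acc = evaluate_zero_shot(dataset["test"].select(range(100)), labels)
print(f"Accuracy: {acc:.2%}")
```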
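The evaluation above only measures top-1 accuracy. A minimal top-k variant, assuming you keep the raw pipeline outputs (for example, by appending each `result` to a list inside the loop) rather than only the top label; `top_k_accuracy` is a hypothetical helper, not part of Transformers or scikit-learn. scikit-learn's `top_k_accuracy_score` also works if you first rearrange each example's scores into a fixed label order.

```python
def top_k_accuracy(results, true_labels, k=3):
    """Fraction of examples whose true label is among the k highest-scoring labels.

    `results` are zero-shot pipeline outputs (dicts whose "labels" are sorted by
    descending score); `true_labels` are the gold label names in the same order.
    """
    if not results or len(results) != len(true_labels):
        raise ValueError("results and true_labels must be non-empty and the same length")
    hits = sum(true in result["labels"][:k] for result, true in zip(results, true_labels))
    return hits / len(results)

# Hypothetical outputs for two examples (invented scores, for illustration only)
results = [
    {"labels": ["Sports", "Health", "Education"], "scores": [0.7, 0.2, 0.1]},
    {"labels": ["Health", "Education", "Sports"], "scores": [0.5, 0.3, 0.2]},
]
gold = ["Sports", "Sports"]
print(top_k_accuracy(results, gold, k=1))  # 0.5
print(top_k_accuracy(results, gold, k=3))  # 1.0
```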