Commit 929f617 by nimishgarg · verified · 1 Parent(s): 7d1b62e

Create README.md

Files changed (1)
  1. README.md +112 -0
README.md ADDED
@@ -0,0 +1,112 @@
# DistilBERT-Base-Uncased Quantized Model for Spam Detection

This repository hosts a quantized version of DistilBERT (base, uncased), fine-tuned for spam classification on a labeled SMS dataset. The weights were converted to FP16 (half precision) for efficient deployment without significant accuracy loss.

## Model Details

- **Model Architecture:** DistilBERT Base Uncased
- **Task:** Binary Spam Classification (Spam/Ham)
- **Dataset:** SMS Spam Collection
- **Quantization:** Float16
- **Fine-tuning Framework:** Hugging Face Transformers

---

## Installation

```bash
pip install torch transformers datasets scikit-learn
```

---

## Loading the Model

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Load tokenizer and model.
# NOTE: "distilbert-base-uncased" is the base checkpoint, whose classification
# head is randomly initialized. Point model_path at this repository's
# fine-tuned weights to get meaningful predictions.
model_path = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# Map class indices to human-readable labels
label_map = {0: "Ham", 1: "Spam"}

# Define test messages
texts = [
    "Congratulations! You have won a free iPhone. Click here to claim your prize.",
    "Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat..."
]

# Tokenize and predict
for text in texts:
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Token IDs and attention mask must stay integer tensors (never FP16)
    inputs = {k: v.long() for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    print(f"Text: {text}")
    print(f"Predicted Label: {label_map[predicted_class]}\n")
```

---

## Performance Metrics

- **Accuracy:** 0.9994
- **Precision:** 1.0000
- **Recall:** 0.9955
- **F1 Score:** 0.9978

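For reference, the four metrics above follow from confusion-matrix counts as sketched below. This is a minimal, standard-library illustration: the function name `binary_metrics` and the toy labels are assumptions made for demonstration, and the example does not reproduce the evaluation behind the reported numbers. The toy data only mirrors the same pattern (perfect precision, recall below 1), which arises when no ham is flagged as spam but some spam is missed.

```python
def binary_metrics(y_true, y_pred, positive=1):
    """Return accuracy, precision, recall and F1 for binary labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    tn = len(pairs) - tp - fp - fn

    accuracy = (tp + tn) / len(pairs)
    # Guard against empty denominators (no predicted or no actual positives)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1}


# Toy labels (hypothetical): 1 = spam, 0 = ham
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 0, 0, 0, 0]

metrics = binary_metrics(y_true, y_pred)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```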
---

## Fine-Tuning Details

### Dataset

The model was fine-tuned on the SMS Spam Collection dataset, in which each message is labeled either "spam" or "ham".
The data was cleaned with custom preprocessing, then split 80/20 into training and validation sets with stratification.

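A stratified 80/20 split keeps the spam-to-ham ratio roughly the same in both sets. The sketch below shows the idea using only the standard library; it is not the preprocessing or split code used for this model, and the function name `stratified_split`, the seed, and the toy messages are all assumptions for illustration.

```python
import random
from collections import defaultdict


def stratified_split(texts, labels, val_fraction=0.2, seed=42):
    """Split (text, label) pairs so each label keeps its share in both sets."""
    by_label = defaultdict(list)
    for text, label in zip(texts, labels):
        by_label[label].append(text)

    rng = random.Random(seed)
    train, val = [], []
    for label, items in by_label.items():
        rng.shuffle(items)
        # Take the validation share from each class separately
        n_val = round(len(items) * val_fraction)
        val.extend((text, label) for text in items[:n_val])
        train.extend((text, label) for text in items[n_val:])

    # Shuffle again so the classes are interleaved
    rng.shuffle(train)
    rng.shuffle(val)
    return train, val


# Toy data (hypothetical): 10 spam and 40 ham messages
texts = ([f"spam message {i}" for i in range(10)]
         + [f"ham message {i}" for i in range(40)])
labels = ["spam"] * 10 + ["ham"] * 40

train_set, val_set = stratified_split(texts, labels)
print(f"train: {len(train_set)}, validation: {len(val_set)}")
```

In practice, a library routine such as scikit-learn's `train_test_split` with its `stratify` argument does the same job with less code.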
### Training

- **Epochs:** 5
- **Batch size:** 12 (train) / 16 (eval)
- **Learning rate:** 3e-5
- **Evaluation strategy:** `epoch`
- **FP16 Training:** Enabled
- **Trainer:** Hugging Face `Trainer` API

---

## Quantization

After fine-tuning, the model weights were cast to half precision with `model.to(dtype=torch.float16)`. This roughly halves the size of the stored weights relative to FP32 and can speed up inference on hardware with native FP16 support.

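To make the size and precision trade-off concrete, the sketch below uses Python's standard `struct` module, whose `e` format packs a value as an IEEE 754 half-precision float. It only illustrates FP16 storage and is not the conversion path used for this model; the helper name `roundtrip` and the sample values are assumptions.

```python
import struct


def roundtrip(value, fmt):
    """Pack a Python float into the given struct format and unpack it again."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


# Bytes needed per value: FP32 ("f") versus FP16 ("e")
fp32_bytes = struct.calcsize("<f")
fp16_bytes = struct.calcsize("<e")
print(f"FP32: {fp32_bytes} bytes per value, FP16: {fp16_bytes} bytes per value")

# Rounding error introduced by each format for a few sample values
for weight in (0.1, 0.001234, 3.14159):
    err32 = abs(roundtrip(weight, "<f") - weight)
    err16 = abs(roundtrip(weight, "<e") - weight)
    print(f"{weight}: FP32 error {err32:.2e}, FP16 error {err16:.2e}")
```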

---

## Repository Structure

```bash
.
├── quantized-model/               # Contains the quantized model files
│   ├── config.json
│   ├── model.safetensors
│   ├── tokenizer_config.json
│   ├── vocab.txt
│   └── special_tokens_map.json
└── README.md                      # Project documentation
```

---

## Limitations

- The model is trained specifically for binary spam classification on SMS data.
- Performance may degrade on emails or social media text without domain adaptation.
- FP16 inference may show slight numerical instability on edge cases.

---

## Contributing

Feel free to open issues or submit pull requests to improve the model, training process, or documentation.