🧠 SMSDetection-DistilBERT-SMS

A DistilBERT-based binary classifier fine-tuned on the SMS Spam Collection dataset. It classifies messages as either spam or ham (not spam). This model is suitable for real-world applications like mobile SMS spam filters, automated customer message triage, and telecom fraud detection.


✨ Model Highlights

  • 📌 Based on distilbert-base-uncased
  • 🔍 Fine-tuned on the SMS Spam Collection dataset
  • ⚡ Supports binary classification: Spam vs Not Spam
  • 💾 Lightweight and optimized for both CPU and GPU environments

🧠 Intended Uses

  • ✅ Mobile SMS spam filtering
  • ✅ Telecom customer service automation
  • ✅ Fraudulent message detection
  • ✅ User inbox categorization
  • ✅ Regulatory compliance monitoring

🚫 Limitations

  • ❌ Trained on English SMS messages only
  • ❌ May underperform on emails, social media texts, or non-English content
  • ❌ Not designed for multilingual datasets
  • ❌ Slight performance dip expected for long messages (>128 tokens); see the token-count check below
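
For the last limitation, it can help to count tokens before trusting a prediction on a long message. A minimal sketch; the exceeds_training_length helper is illustrative and not part of this repository:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("AventIQ-AI/SMS-Spam-Detection-Model")

def exceeds_training_length(text, max_length=128):
    # Count tokens exactly as the model sees them, including special tokens
    token_count = len(tokenizer(text, add_special_tokens=True)["input_ids"])
    return token_count > max_length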


πŸ‹οΈβ€β™‚οΈ Training Details

| Field      | Value                              |
|------------|------------------------------------|
| Base Model | distilbert-base-uncased            |
| Dataset    | SMS Spam Collection (UCI)          |
| Framework  | PyTorch with 🤗 Transformers       |
| Epochs     | 3                                  |
| Batch Size | 16                                 |
| Max Length | 128 tokens                         |
| Optimizer  | AdamW                              |
| Loss       | CrossEntropyLoss (sequence-level)  |
| Device     | CUDA-enabled GPU                   |
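
The exact training script is not published; the following sketch reproduces the setup in the table above, assuming the sms_spam dataset on the Hugging Face Hub (a mirror of the UCI SMS Spam Collection):

from datasets import load_dataset
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          Trainer, TrainingArguments)

# Load the UCI SMS Spam Collection mirror and hold out a test split
dataset = load_dataset("sms_spam", split="train").train_test_split(test_size=0.2)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize(batch):
    # Truncate to the 128-token limit used by this model
    return tokenizer(batch["sms"], truncation=True, max_length=128)

dataset = dataset.map(tokenize, batched=True)

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2)  # binary: ham (0) vs spam (1)

args = TrainingArguments(
    output_dir="sms-spam-distilbert",
    num_train_epochs=3,
    per_device_train_batch_size=16,
)

# Trainer uses AdamW by default and pads batches via the tokenizer
trainer = Trainer(model=model, args=args,
                  train_dataset=dataset["train"],
                  eval_dataset=dataset["test"],
                  tokenizer=tokenizer)
trainer.train()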

📊 Evaluation Metrics

| Metric    | Score |
|-----------|-------|
| Accuracy  | 0.99  |
| F1-Score  | 0.96  |
| Precision | 0.98  |
| Recall    | 0.93  |
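
The evaluation split is not published; as a rough illustration, scores like those above can be recomputed with scikit-learn given a held-out list of labeled messages and the predict_sms function from the Usage section below (the two example messages here are placeholders):

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

texts = ["WINNER!! Claim your free prize now", "See you at dinner tonight"]
labels = [1, 0]  # 1 = spam, 0 = ham

# Map string predictions back to integer labels
preds = [1 if predict_sms(t) == "spam" else 0 for t in texts]

accuracy = accuracy_score(labels, preds)
precision, recall, f1, _ = precision_recall_fscore_support(
    labels, preds, average="binary")
print(f"Accuracy {accuracy:.2f}  Precision {precision:.2f}  "
      f"Recall {recall:.2f}  F1 {f1:.2f}")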


🚀 Usage

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model_name = "AventIQ-AI/SMS-Spam-Detection-Model"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()


# Inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def predict_sms(text):
    # Tokenize with the same 128-token limit used during training
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits  # shape (1, 2): one score per class
        predicted = torch.argmax(logits, dim=1).item()
    return "spam" if predicted == 1 else "ham"

# Test example
print(predict_sms("You've won $1,000,000! Call now to claim your prize!"))
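
Alternatively, the pipeline API wraps tokenization, device placement, and softmax in a single call. Note the returned label names come from the checkpoint's config and may be generic (e.g. LABEL_0/LABEL_1):

from transformers import pipeline

classifier = pipeline("text-classification",
                      model="AventIQ-AI/SMS-Spam-Detection-Model")
print(classifier("You've won $1,000,000! Call now to claim your prize!"))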

🧩 Quantization

Post-training static quantization was applied with PyTorch to reduce model size and accelerate inference on edge devices.
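
The quantization script itself is not included in the repo. As a rough stand-in, PyTorch's post-training dynamic quantization (a simpler recipe than the static quantization used here, since it needs no calibration pass) gives a similar size/latency trade-off on CPU:

import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "AventIQ-AI/SMS-Spam-Detection-Model")
model.eval()

# Convert Linear-layer weights to int8; activations are quantized
# on the fly at inference time, so no calibration data is required
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8)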

🗂 Repository Structure

.
├── model/               # Quantized model files
├── tokenizer_config/    # Tokenizer and vocab files
├── model.safetensors    # Fine-tuned model weights in safetensors format
└── README.md            # Model card

🤝 Contributing

Open to improvements and feedback! Feel free to submit a pull request or open an issue if you find any bugs or want to enhance the model.