# 🧠 Image Classification AI Model (CIFAR-100)

This repository contains a Vision Transformer (ViT) model fine-tuned for **image classification** on the CIFAR-100 dataset. It is built on `google/vit-base-patch16-224`, converted to **FP16** (half precision) for efficient inference, and reaches roughly 70–80% accuracy on the 100-class task.

---

## 🚀 Features

- 🖼️ **Task**: Image classification
- 🧠 **Base Model**: `google/vit-base-patch16-224` (Vision Transformer)
- 🧪 **Precision**: FP16 (half-precision weights) for faster, lower-memory inference
- 🎯 **Dataset**: CIFAR-100, 100 object categories
- ⚡ **CUDA Enabled**: Runs on CUDA GPUs for accelerated training and inference
- 📈 **Accuracy**: ~70–80% after fine-tuning, measured on a held-out split
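
To make the FP16 point concrete: half precision stores each value in 2 bytes instead of the 4 used by FP32, so weight memory roughly halves, at the cost of reduced precision. The sketch below uses only Python's standard `struct` module (its `'e'` format code is IEEE 754 half precision) to show both effects. The parameter count is an assumed round figure for ViT-Base, not a measurement of this repository's checkpoint.

```python
import struct

# Bytes per value: 'f' is IEEE 754 single precision (FP32), 'e' is half precision (FP16)
FP32_BYTES = struct.calcsize("f")
FP16_BYTES = struct.calcsize("e")


def round_trip_fp16(x: float) -> float:
    """Store x as FP16 and read it back, exposing the rounding error."""
    return struct.unpack("e", struct.pack("e", x))[0]


def weight_memory_mb(num_params: int, bytes_per_param: int) -> float:
    """Approximate weight storage in megabytes (1 MB = 10**6 bytes)."""
    return num_params * bytes_per_param / 1e6


# Assumption: ~86M parameters, the commonly cited size of ViT-Base/16
NUM_PARAMS = 86_000_000

fp32_mb = weight_memory_mb(NUM_PARAMS, FP32_BYTES)
fp16_mb = weight_memory_mb(NUM_PARAMS, FP16_BYTES)

print(f"FP32 weights: {fp32_mb:.0f} MB, FP16 weights: {fp16_mb:.0f} MB")
print(f"0.1 stored as FP16 reads back as {round_trip_fp16(0.1)!r}")
```

The rounding error shown for `0.1` is typically small enough for inference, but it is worth re-checking accuracy after converting a model to FP16.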

---

## 📊 Dataset Used

**Hugging Face Dataset**: [`tanganke/cifar100`](https://huggingface.co/datasets/tanganke/cifar100)

- **Description**: CIFAR-100 is a dataset of 60,000 32×32 color images in 100 classes (600 images per class)
- **Split**: 50,000 training images and 10,000 test images (500 training and 100 test images per class)
- **Categories**: Animals, vehicles, food, household items, etc.
- **License**: MIT License (from source)

```python
from datasets import load_dataset

# Load CIFAR-100 from the Hugging Face Hub as a DatasetDict of splits
dataset = load_dataset("tanganke/cifar100")
```
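
The split sizes above follow from a few multiplications. A quick sanity check in plain Python: the per-class figures (500 train, 100 test) are the standard CIFAR-100 split, and the raw byte count assumes uncompressed 8-bit RGB pixels rather than the on-disk format used by the Hub.

```python
NUM_CLASSES = 100
TRAIN_PER_CLASS = 500  # standard CIFAR-100 split
TEST_PER_CLASS = 100
HEIGHT = WIDTH = 32
CHANNELS = 3  # RGB, assumed one byte per channel value

train_images = NUM_CLASSES * TRAIN_PER_CLASS
test_images = NUM_CLASSES * TEST_PER_CLASS
total_images = train_images + test_images

# Uncompressed pixel storage, ignoring any file-format overhead
bytes_per_image = HEIGHT * WIDTH * CHANNELS
raw_dataset_mb = total_images * bytes_per_image / 1e6

print(f"train={train_images}, test={test_images}, total={total_images}")
print(f"{bytes_per_image} bytes per image, ~{raw_dataset_mb:.0f} MB of raw pixels")
```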

## 🛠️ Model & Training Configuration

- Model: `google/vit-base-patch16-224`
- Image Size: 224×224 (upscaled from the native 32×32)
- Framework: Hugging Face Transformers & Datasets
- Training Environment: Kaggle Notebook with CUDA
- Epochs: 5–10 (with early stopping)
- Batch Size: 32
- Optimizer: AdamW
- Loss Function: CrossEntropyLoss
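
Why resize to 224×224? `vit-base-patch16-224` was pre-trained on 224×224 inputs split into non-overlapping 16×16 patches, so feeding that resolution keeps the token count, and therefore the learned position embeddings, exactly what the model expects. The arithmetic as a small Python check (the constants mirror the model name; nothing here loads the model, and `vit_sequence_length` is an illustrative helper):

```python
def vit_sequence_length(image_size: int, patch_size: int) -> int:
    """Tokens seen by a ViT encoder: one per non-overlapping patch, plus the [CLS] token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1


NATIVE_SIZE = 32   # CIFAR-100 resolution
MODEL_SIZE = 224   # input resolution of google/vit-base-patch16-224
PATCH_SIZE = 16

upscale_factor = MODEL_SIZE // NATIVE_SIZE  # 7x along each side
print(f"Upscale factor per side: {upscale_factor}")
print(f"Tokens at 224x224: {vit_sequence_length(MODEL_SIZE, PATCH_SIZE)}")
# At native resolution, a 32x32 image would yield only a 2x2 grid of patches
print(f"Tokens at 32x32:   {vit_sequence_length(NATIVE_SIZE, PATCH_SIZE)}")
```

Upscaling adds no new image detail; its purpose is to let the pre-trained patch and position embeddings be reused unchanged, which is the usual reason for fine-tuning at the pre-training resolution.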

## ✅ Evaluation & Scoring

- Accuracy: ~70–80% (varies by configuration)
- Validation Tool: `evaluate` or `sklearn.metrics`
- Metrics: Top-1 and Top-5 accuracy
- Inference Speed: Faster and lower-memory after FP16 conversion than in FP32
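
Top-1 accuracy counts a prediction as correct when the highest-scoring class is the true label; Top-5 accepts the true label anywhere among the five highest scores. A dependency-free sketch (the function name and the toy scores are illustrative, not taken from this repository's code):

```python
from typing import Sequence


def top_k_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices of the k largest scores for this sample
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label in top_k:
            hits += 1
    return hits / len(labels)


# Toy example: 3 samples, 6 classes (real CIFAR-100 logits would have 100 columns)
scores = [
    [0.1, 0.7, 0.05, 0.05, 0.05, 0.05],  # true label 1 is ranked first
    [0.3, 0.2, 0.25, 0.1, 0.1, 0.05],    # true label 2 is ranked second
    [0.05, 0.1, 0.1, 0.15, 0.2, 0.4],    # true label 0 is ranked last
]
labels = [1, 2, 0]

print(f"top-1: {top_k_accuracy(scores, labels, k=1):.3f}")
print(f"top-5: {top_k_accuracy(scores, labels, k=5):.3f}")
```

With NumPy arrays, `sklearn.metrics.top_k_accuracy_score(y_true, y_score, k=5)` computes the same metric, though its tie-breaking between equal scores may differ from this sketch.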
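
## 🔍 Inference Example

```python
from PIL import Image
import torch
from transformers import ViTForImageClassification, ViTImageProcessor

# Assumes the fine-tuned FP16 weights were saved to ./vit-cifar100-fp16 (see Folder Structure)
model = ViTForImageClassification.from_pretrained("vit-cifar100-fp16", torch_dtype=torch.float16).to("cuda").eval()
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")


def predict(image_path):
    image = Image.open(image_path).convert("RGB")
    # Resize to 224x224 and normalize, then cast pixels to FP16 to match the model weights
    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to("cuda", dtype=torch.float16)
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    predicted_class = logits.argmax(-1).item()
    # Assumes id2label was populated with the CIFAR-100 class names during fine-tuning
    return model.config.id2label[predicted_class]


print(predict("sample_image.jpg"))
```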

## 📁 Folder Structure

📦image-classification-vit

┣ 📂vit-cifar100-fp16

┣ 📜train.py

┣ 📜inference.py

┣ 📜README.md

┗ 📜requirements.txt