# ResNet-50 Fine-Tuned Model for Vehicle Type Classification
This repository hosts a **fine-tuned ResNet-50 model** for **Vehicle Type Classification**, trained on a subset of the **MIO-TCD Traffic Dataset**. This model is designed for **traffic management applications**, enabling real-time and accurate recognition of different vehicle types, such as cars, trucks, buses, and motorcycles.
## Model Details
- **Model Architecture:** ResNet-50
- **Task:** Vehicle Type Classification
- **Dataset:** MIO-TCD (Subset from Kaggle: `miotcd-dataset-50000-imagesclassification`)
- **Number of Classes:** 11 (mostly vehicle types, plus `pedestrian` and `unknown`)
- **Fine-tuning Framework:** PyTorch (`torchvision.models.resnet50`)
- **Optimization:** Trained with Adam optimizer and data augmentation for robust performance
## Usage
### Installation
Ensure you have the required dependencies installed:
```sh
pip install torch torchvision pillow
```
### Loading the Model
```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Run on GPU if available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model architecture (no pretrained weights; the fine-tuned weights are loaded below)
resnet50 = models.resnet50(weights=None)

# Replace the final fully connected layer to output 11 classes
num_ftrs = resnet50.fc.in_features
resnet50.fc = torch.nn.Linear(num_ftrs, 11)

# Load the fine-tuned weights; map_location lets GPU-saved weights load on a CPU-only machine
state_dict = torch.load("fine_tuned_model/pytorch_model.bin", map_location=device)
resnet50.load_state_dict(state_dict)
resnet50 = resnet50.to(device)
resnet50.eval()  # Use batch-norm running statistics for inference
print("Model loaded successfully!")

# Load class names (one per line, in the model's output index order)
with open("fine_tuned_model/classes.txt", "r") as f:
    class_names = f.read().splitlines()
print("Classes:", class_names)

# Inference preprocessing: resize, then apply the same ImageNet normalization used in training
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to the ResNet-50 input size
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Load a test image
image_path = "/kaggle/input/sample-image-1/pickup_truck_sample_image.jpg"  # Change this to your test image path
image = Image.open(image_path).convert("RGB")  # Ensure 3 color channels
input_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension and move to device

# Get predictions
with torch.no_grad():
    outputs = resnet50(input_tensor)
    predicted_class = outputs.argmax(dim=1)  # Index of the highest-scoring class

# Print the result
print(f"Predicted Vehicle Type: {class_names[predicted_class.item()]}")
```
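The loading example above prints only the top-1 label. A softmax over the logits shows how confident the model is in each class. The following is a minimal sketch of that step. `top_k_predictions` is a hypothetical helper and is not part of this repository. It assumes `outputs` holds raw logits of shape `(1, num_classes)`.

```python
import torch


def top_k_predictions(outputs: torch.Tensor, class_names: list[str], k: int = 3) -> list[tuple[str, float]]:
    """Return the k most likely (class name, probability) pairs for a single image.

    Hypothetical helper: `outputs` is assumed to be raw logits of shape (1, num_classes).
    """
    probs = torch.softmax(outputs, dim=1)[0]      # Convert logits to probabilities
    top_probs, top_idx = torch.topk(probs, k)     # k highest probabilities, sorted descending
    return [(class_names[i], p) for i, p in zip(top_idx.tolist(), top_probs.tolist())]


# Usage with `outputs` and `class_names` from the loading example:
# for name, p in top_k_predictions(outputs, class_names):
#     print(f"{name}: {p:.1%}")
```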
## Performance Metrics
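- **Accuracy:** High accuracy on the held-out evaluation split
- **Inference Speed:** Suitable for real-time classification, particularly on a GPU (one ResNet-50 forward pass per 224×224 image)
- **Robustness:** Trained with data augmentation (horizontal flipping, random cropping) to improve robustness to variations in viewpoint and framing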
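No numeric accuracy is listed above, so you may want to measure top-1 accuracy on your own held-out split. Below is a minimal sketch of how to do that. `evaluate_accuracy` is a hypothetical helper. The `DataLoader` is assumed to yield `(images, labels)` batches that use the same preprocessing as the loading example.

```python
import torch
from torch.utils.data import DataLoader


def evaluate_accuracy(model: torch.nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Fraction of samples in `loader` whose top-1 prediction matches the label (hypothetical helper)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)           # Top-1 class per sample
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0
```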
## Dataset Details
The dataset consists of **50,000 images** across **11 classes**, organized into one folder per class:
- **articulated_truck**
- **bicycle**
- **bus**
- **car**
- **motorcycle**
- **non-motorized_vehicle**
- **pedestrian**
- **pickup_truck**
- **single_unit_truck**
- **work_van**
- **unknown**
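If the training data was loaded with `torchvision.datasets.ImageFolder`, label indices come from the sorted subfolder names, not from the order listed above. With these folders, `unknown` gets index 9 and `work_van` gets index 10. We assume `classes.txt` was written in that same order; verify this against your copy. The stdlib-only sketch below mirrors ImageFolder's sorting rule. `class_index_map` is a hypothetical helper.

```python
import os


def class_index_map(root: str) -> dict[str, int]:
    """Map each class subfolder under `root` to an integer label, in sorted-name order.

    Hypothetical helper mirroring how torchvision's ImageFolder assigns class indices.
    """
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}
```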
## Training Details
- **Number of Epochs:** 10
- **Batch Size:** 32
- **Optimizer:** Adam
- **Learning Rate:** 1e-4
- **Loss Function:** Cross-Entropy Loss
- **Data Augmentation:** Horizontal flipping and random cropping, followed by ImageNet normalization
## Repository Structure
```
.
├── fine_tuned_model/        # Contains the fine-tuned model files
│   ├── pytorch_model.bin    # Model weights
│   └── classes.txt          # Class labels
├── dataset/                 # Training dataset (MIO-TCD subset)
├── scripts/                 # Training and evaluation scripts
└── README.md                # Model documentation
```
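The following sketch shows one way to produce the two files under `fine_tuned_model/` after training, in the format the loading example reads back. `export_model` is a hypothetical helper and is not taken from `scripts/`.

```python
import os

import torch


def export_model(model: torch.nn.Module, class_names: list[str], out_dir: str = "fine_tuned_model") -> None:
    """Write weights and class labels in the layout shown above (hypothetical helper)."""
    os.makedirs(out_dir, exist_ok=True)
    # Save only the state dict, which is what load_state_dict() expects when loading
    torch.save(model.state_dict(), os.path.join(out_dir, "pytorch_model.bin"))
    # One class name per line, ordered by output index
    with open(os.path.join(out_dir, "classes.txt"), "w") as f:
        f.write("\n".join(class_names) + "\n")
```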
## Limitations
- The model is trained specifically on the **MIO-TCD dataset** and may not generalize well to images from different sources.
- Accuracy may vary based on real-world conditions such as lighting, occlusion, and camera angles.
- A GPU is recommended for real-time throughput; CPU inference works but is slower.
## Contributing
Contributions are welcome! If you have suggestions for improvement, feel free to submit a pull request or open an issue.