ResNet-50 Fine-Tuned Model for Vehicle Type Classification

This repository hosts a fine-tuned ResNet-50 model for Vehicle Type Classification, trained on a subset of the MIO-TCD Traffic Dataset. This model is designed for traffic management applications, enabling real-time and accurate recognition of different vehicle types, such as cars, trucks, buses, and motorcycles.

Model Details

  • Model Architecture: ResNet-50
  • Task: Vehicle Type Classification
  • Dataset: MIO-TCD (Subset from Kaggle: miotcd-dataset-50000-imagesclassification)
  • Number of Classes: 11 vehicle categories
  • Fine-tuning Framework: PyTorch (torchvision.models.resnet50)
  • Optimization: Trained with Adam optimizer and data augmentation for robust performance

Usage

Installation

Ensure you have the required dependencies installed:

pip install torch torchvision pillow

Loading the Model

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

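# Define the model architecture without pretrained weights
# (weights=None replaces the deprecated pretrained=False in torchvision >= 0.13)
resnet50 = models.resnet50(weights=None)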

# Modify the last layer to match the number of classes (11)
num_ftrs = resnet50.fc.in_features
resnet50.fc = torch.nn.Linear(num_ftrs, 11)

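# Load trained model weights (map_location="cpu" lets this work on machines without a GPU)
state_dict = torch.load("fine_tuned_model/pytorch_model.bin", map_location="cpu")
resnet50.load_state_dict(state_dict)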
resnet50.eval()  # Set model to evaluation mode

print("Model loaded successfully!")


# Load class names
with open("fine_tuned_model/classes.txt", "r") as f:
    class_names = f.read().splitlines()

print("Classes:", class_names)


# Define image transformations (same as training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match ResNet-50 input size
    transforms.ToTensor(),  # Convert to a float tensor with values in [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean and std
])

# Load the custom image
image_path = "/kaggle/input/sample-image-1/pickup_truck_sample_image.jpg"  # Change this to your test image path
image = Image.open(image_path).convert("RGB")  # Open image and convert to RGB
input_tensor = transform(image).unsqueeze(0)  # Add batch dimension



# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50 = resnet50.to(device)
input_tensor = input_tensor.to(device)

# Get predictions
with torch.no_grad():
    outputs = resnet50(input_tensor)
    _, predicted_class = torch.max(outputs, 1)  # Get the class with highest score

# Print the result
print(f"Predicted Vehicle Type: {class_names[predicted_class.item()]}")
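
The snippet above prints only the single best class. When a confidence estimate is useful, the raw scores can be turned into probabilities with a softmax and ranked. The following is a minimal, dependency-free sketch of that idea, not part of the original code; the helper name `top_k_probs` and the demo logits are hypothetical.

```python
import math

def top_k_probs(logits, class_names, k=3):
    """Return the k most probable (class_name, probability) pairs."""
    if len(logits) != len(class_names):
        raise ValueError("expected one logit per class name")
    # Subtract the largest logit before exponentiating for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank classes from most to least probable.
    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Hypothetical logits for three classes, for illustration only.
demo = top_k_probs([2.0, 0.5, -1.0], ["car", "bus", "bicycle"], k=2)
for name, prob in demo:
    print(f"{name}: {prob:.3f}")
```

With the model output from the snippet above, the call would look like `top_k_probs(outputs[0].cpu().tolist(), class_names)`.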

Performance Metrics

  • Validation Accuracy: Described as high on the held-out data; no numeric figure is reported in this card
  • Inference Speed: Intended for real-time classification; no latency or throughput figures are reported
  • Robustness: Trained with data augmentation to handle variations in lighting and angles
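
Since no latency figure is given here, one way to check whether the model is fast enough for a particular setup is to time repeated forward passes. The following is a rough sketch under that assumption; the helper name `measure_latency` and the stand-in workload are hypothetical, and the numbers it prints depend entirely on the hardware.

```python
import statistics
import time

def measure_latency(predict, warmup=3, runs=20):
    """Time a zero-argument callable and return (median_ms, max_ms)."""
    for _ in range(warmup):
        predict()  # Warm-up calls are excluded from the timings.
    timings_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        predict()
        timings_ms.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings_ms), max(timings_ms)

# Stand-in workload for illustration; replace it with a real forward pass.
median_ms, worst_ms = measure_latency(lambda: sum(range(10_000)))
print(f"median {median_ms:.3f} ms, worst {worst_ms:.3f} ms")
```

To time the model itself, pass a callable that runs `resnet50(input_tensor)` inside `torch.no_grad()`. On a GPU, CUDA kernels run asynchronously, so the callable should also call `torch.cuda.synchronize()` before returning, otherwise the timings will be too optimistic.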

Dataset Details

The dataset consists of 50,000 images across 11 classes, with one folder per class:

  • articulated_truck
  • bicycle
  • bus
  • car
  • motorcycle
  • non-motorized_vehicle
  • pedestrian
  • pickup_truck
  • single_unit_truck
  • work_van
  • unknown
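
Because the images are organized as one folder per class, the class distribution can be inspected before training. The sketch below counts image files per class folder; the helper name, the accepted file extensions, and the demo file names are assumptions, not details taken from the dataset.

```python
import tempfile
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # Assumed image extensions.

def count_images_per_class(root):
    """Map each class subfolder of `root` to its number of image files."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts

# Demo on a throwaway directory with hypothetical file names.
with tempfile.TemporaryDirectory() as tmp:
    for class_name, n_images in [("bus", 2), ("car", 3)]:
        class_dir = Path(tmp) / class_name
        class_dir.mkdir()
        for i in range(n_images):
            (class_dir / f"img_{i}.jpg").touch()
    demo_counts = count_images_per_class(tmp)

print(demo_counts)
```

Pointing the function at the real dataset folder gives a quick view of class balance. A strongly skewed distribution would be a reason to look at per-class accuracy and not only the overall figure.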

Training Details

  • Number of Epochs: 10
  • Batch Size: 32
  • Optimizer: Adam
  • Learning Rate: 1e-4
  • Loss Function: Cross-Entropy Loss
  • Data Augmentation: Horizontal flipping, random cropping, normalization

Repository Structure

.
├── fine_tuned_model/      # Contains the fine-tuned model files
│   ├── pytorch_model.bin  # Model weights
│   ├── classes.txt        # Class labels
├── dataset/               # Training dataset (MIO-TCD subset)
├── scripts/               # Training and evaluation scripts
├── README.md              # Model documentation

Limitations

  • The model is trained specifically on the MIO-TCD dataset and may not generalize well to images from different sources.
  • Accuracy may vary based on real-world conditions such as lighting, occlusion, and camera angles.
  • A GPU is recommended for faster inference; the usage code falls back to CPU when no GPU is available.

Contributing

Contributions are welcome! If you have suggestions for improvement, feel free to submit a pull request or open an issue.
