developerPushkal commited on
Commit
29b2ecd
·
verified ·
1 Parent(s): f27d0c4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +36 -27
README.md CHANGED
@@ -11,14 +11,6 @@ This repository hosts a **fine-tuned ResNet-50 model** for **Vehicle Type Classi
11
  - **Fine-tuning Framework:** PyTorch (`torchvision.models.resnet50`)
12
  - **Optimization:** Trained with Adam optimizer and data augmentation for robust performance
13
 
14
- ## Downloading the Model
15
-
16
- You can download the fine-tuned model from the provided link:
17
-
18
- ```sh
19
- wget <download_link>/fine_tuned_model.zip
20
- unzip fine_tuned_model.zip
21
- ```
22
 
23
  ## Usage
24
 
@@ -38,37 +30,54 @@ import torchvision.models as models
38
  import torchvision.transforms as transforms
39
  from PIL import Image
40
 
41
- # Load the model architecture
42
- model = models.resnet50(pretrained=False)
43
- num_ftrs = model.fc.in_features
44
- model.fc = torch.nn.Linear(num_ftrs, 11) # 11 vehicle classes
 
 
 
 
 
 
 
 
45
 
46
- # Load fine-tuned weights
47
- model.load_state_dict(torch.load("fine_tuned_model/pytorch_model.bin", map_location=torch.device('cpu')))
48
- model.eval() # Set to evaluation mode
49
 
50
- # Load class labels
51
  with open("fine_tuned_model/classes.txt", "r") as f:
52
  class_names = f.read().splitlines()
53
 
54
- # Define preprocessing transformations
 
 
 
55
  transform = transforms.Compose([
56
- transforms.Resize((224, 224)),
57
- transforms.ToTensor(),
58
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59
  ])
60
 
61
- # Load and preprocess a test image
62
- image_path = "path_to_your_image.jpg" # Change this to your test image path
63
- image = Image.open(image_path).convert("RGB")
64
- input_tensor = transform(image).unsqueeze(0)
 
65
 
66
- # Make prediction
 
 
 
 
 
 
67
  with torch.no_grad():
68
- outputs = model(input_tensor)
69
- _, predicted_class = torch.max(outputs, 1)
70
 
 
71
  print(f"Predicted Vehicle Type: {class_names[predicted_class.item()]}")
 
72
  ```
73
 
74
  ## Performance Metrics
 
11
  - **Fine-tuning Framework:** PyTorch (`torchvision.models.resnet50`)
12
  - **Optimization:** Trained with Adam optimizer and data augmentation for robust performance
13
 
 
 
 
 
 
 
 
 
14
 
15
  ## Usage
16
 
 
30
  import torchvision.transforms as transforms
31
  from PIL import Image
32
 
33
+ # Define the model architecture
34
+ resnet50 = models.resnet50(pretrained=False)
35
+
36
+ # Modify the last layer to match the number of classes (11)
37
+ num_ftrs = resnet50.fc.in_features
38
+ resnet50.fc = torch.nn.Linear(num_ftrs, 11)
39
+
40
+ # Load trained model weights
41
+ resnet50.load_state_dict(torch.load("fine_tuned_model/pytorch_model.bin"))
42
+ resnet50.eval() # Set model to evaluation mode
43
+
44
+ print("Model loaded successfully!")
45
 
 
 
 
46
 
47
+ # Load class names
48
  with open("fine_tuned_model/classes.txt", "r") as f:
49
  class_names = f.read().splitlines()
50
 
51
+ print("Classes:", class_names)
52
+
53
+
54
+ # Define image transformations (same as training)
55
  transform = transforms.Compose([
56
+ transforms.Resize((224, 224)), # Resize to match ResNet-50 input size
57
+ transforms.ToTensor(),
58
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Normalization
59
  ])
60
 
61
+ # Load the custom image
62
+ image_path = "/kaggle/input/sample-image-1/pickup_truck_sample_image.jpg" # Change this to your test image path
63
+ image = Image.open(image_path).convert("RGB") # Open image and convert to RGB
64
+ input_tensor = transform(image).unsqueeze(0) # Add batch dimension
65
+
66
 
67
+
68
+ # Move to GPU if available
69
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
+ resnet50 = resnet50.to(device)
71
+ input_tensor = input_tensor.to(device)
72
+
73
+ # Get predictions
74
  with torch.no_grad():
75
+ outputs = resnet50(input_tensor)
76
+ _, predicted_class = torch.max(outputs, 1) # Get the class with highest score
77
 
78
+ # Print the result
79
  print(f"Predicted Vehicle Type: {class_names[predicted_class.item()]}")
80
+
81
  ```
82
 
83
  ## Performance Metrics