File size: 3,581 Bytes
37b7c68
 
b0f54e3
31c3d5a
 
b0f54e3
37b7c68
31c3d5a
b0f54e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31c3d5a
 
 
b0f54e3
31c3d5a
 
37b7c68
 
 
31c3d5a
37b7c68
b0f54e3
 
 
37b7c68
 
31c3d5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37b7c68
31c3d5a
 
 
 
 
37b7c68
31c3d5a
 
 
 
 
37b7c68
 
31c3d5a
 
 
 
 
 
 
37b7c68
31c3d5a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os

# Define the same custom residual block and EfficientNetWithNovelty model
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages added to a 1x1 convolutional shortcut.

    Both 3x3 convolutions use padding=1 and the default stride of 1, and the
    shortcut is a 1x1 convolution, so the spatial size of the input is kept;
    only the channel count changes from ``in_channels`` to ``out_channels``.

    Output: ``relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x))``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main path (creation order kept identical so parameter init under a
        # fixed seed and the state_dict layout do not change).
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut path: 1x1 projection so channel counts match for the sum.
        self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # NOTE(review): skip_bn is registered (so its parameters/buffers appear
        # in the state_dict) but forward() never applies it. Kept as-is:
        # removing or using it would presumably break strict loading of, or
        # change predictions from, checkpoints trained with this definition.
        self.skip_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        shortcut = self.skip(x)
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return self.relu(main + shortcut)

class EfficientNetWithNovelty(nn.Module):
    """EfficientNet-B0 backbone with an extra ResidualBlock before the head.

    Pipeline: ``features`` -> ``ResidualBlock`` -> global average pooling over
    the spatial dims -> the backbone's classifier, whose final linear layer is
    replaced to emit ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        pretrained: forwarded to ``models.efficientnet_b0``. Defaults to True,
            which keeps the original behaviour of fetching ImageNet weights.
            Pass False when all weights will be overwritten from a checkpoint
            anyway, to avoid the download (and the need for network access).
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()

        self.model = models.efficientnet_b0(pretrained=pretrained)

        # Width of the pooled feature vector the classifier consumes; read it
        # from the backbone instead of hard-coding 1280 for EfficientNet-B0.
        feature_dim = self.model.classifier[1].in_features

        # Replace the final linear layer to match our number of classes.
        self.model.classifier[1] = nn.Linear(feature_dim, num_classes)

        # Channel-preserving residual block applied to the backbone's feature map.
        self.residual_block = ResidualBlock(feature_dim, feature_dim)

    def forward(self, x):
        x = self.model.features(x)
        x = self.residual_block(x)
        # Global average pooling over height and width -> (batch, channels).
        x = x.mean([2, 3])
        return self.model.classifier(x)

# Build the model and load the trained checkpoint onto the available device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Must match the number of entries in class_labels.
num_classes = 10
model = EfficientNetWithNovelty(num_classes)
# map_location remaps tensors that were saved from a GPU onto `device`, so a
# CUDA-trained checkpoint still loads on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True if the file holds only tensors/plain containers).
checkpoint = torch.load('best_model2.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# Inference mode: BatchNorm uses running stats, dropout is disabled.
model.eval()

# Standard ImageNet per-channel statistics (RGB order).
# NOTE(review): presumably these must match the normalization used at
# training time — confirm against the training script.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to each PIL image before inference: resize to a fixed
# 224x224, convert to a float tensor in [0, 1], then normalize per channel.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_preprocess_steps)

# Grip names indexed by the model's output class index: predict() uses the
# index of the largest logit to pick an entry from this list.
# NOTE(review): this order is not alphabetical; it is the alphabetically sorted
# list rotated by five entries. If training used torchvision's ImageFolder
# (indices assigned in sorted folder-name order), verify this against the
# training run's class_to_idx before trusting predictions.
class_labels = [
    "KNUCKLE", "LEGSPIN", "OFFSPIN", "OUTSWING", "STRAIGHT",
    "BACK_OF_HAND", "CARROM", "CROSSSEAM", "GOOGLY", "INSWING",
]

# Prediction function
def predict(image):
    """Classify one cricket-grip image and return its grip label.

    Args:
        image: the uploaded image — normally a numpy array from the Gradio
            ``gr.Image(type="numpy")`` input; a ``PIL.Image.Image`` is also
            accepted. ``None`` means no image is present.

    Returns:
        The entry of ``class_labels`` for the highest-scoring class, or an
        empty string when ``image`` is None.
    """
    # Gradio passes None when the input holds no image (e.g. after clearing it,
    # which re-triggers this function because the interface is live). Clear the
    # output instead of letting Image.fromarray(None) raise.
    if image is None:
        return ""

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # The Normalize step uses 3-channel mean/std, so grayscale or RGBA inputs
    # would fail there; coerce everything to RGB first.
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        logits = model(batch)
    predicted_index = int(logits.argmax(dim=1).item())
    return class_labels[predicted_index]

# Gradio UI: one uploaded image in, the predicted grip label out. live=True
# re-runs predict() whenever the input changes, without a submit button.
_grip_image_input = gr.Image(type="numpy", label="Upload Cricket Grip Image")
_grip_label_output = gr.Textbox(label="Predicted Grip Type")
iface = gr.Interface(
    fn=predict,
    inputs=_grip_image_input,
    outputs=_grip_label_output,
    live=True,
)

def main():
    """Start the Gradio app for the cricket-grip classifier."""
    iface.launch()


if __name__ == "__main__":
    main()