|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
from torchvision import models, transforms |
|
from PIL import Image |
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Two 3x3 conv/BN layers plus a 1x1 convolutional shortcut.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x))``. Both 3x3
    convolutions use padding 1 (stride 1), and the shortcut is a 1x1
    convolution with stride 1 and no padding. Spatial size is therefore
    preserved, and only the channel count changes from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Main path: conv -> BN -> ReLU -> conv -> BN.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut path: a 1x1 projection so the channel counts line up for the sum.
        self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # NOTE(review): skip_bn is never applied in forward(). Removing it would
        # drop its keys from state_dict() and presumably break strict loading of
        # checkpoints saved from this class, so it is kept as-is.
        self.skip_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return the block output for the 4-D feature map ``x``."""
        shortcut = self.skip(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        return self.relu(out + shortcut)
|
|
|
class EfficientNetWithNovelty(nn.Module):
    """torchvision EfficientNet-B0 with an extra residual block before pooling.

    Pipeline: ``model.features`` -> ``ResidualBlock`` -> global average pool
    over the spatial dims -> ``model.classifier``. The classifier's final
    layer (index 1) is resized to ``num_classes`` outputs.

    Args:
        num_classes: Number of logits produced per image.
        pretrained: Forwarded to ``models.efficientnet_b0``. Defaults to
            ``True``, which matches the previous behaviour. Pass ``False``
            when every weight will be overwritten from a checkpoint anyway.
            This skips downloading the ImageNet weights, which is wasted work
            and fails when offline.
    """

    def __init__(self, num_classes, pretrained=True):
        super(EfficientNetWithNovelty, self).__init__()

        self.model = models.efficientnet_b0(pretrained=pretrained)

        # Channel count of the backbone's final feature map. Read it from the
        # stock head before that head is replaced, so the residual block and the
        # new Linear layer always agree with the backbone instead of relying on
        # a hard-coded value.
        feature_dim = self.model.classifier[1].in_features

        # Replace only the last layer of the head; earlier entries of
        # `classifier` are kept and still run in forward().
        self.model.classifier[1] = nn.Linear(feature_dim, num_classes)

        self.residual_block = ResidualBlock(feature_dim, feature_dim)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Image batch. Presumably a normalized ``(batch, 3, H, W)``
                tensor, as produced by the preprocessing transform (TODO:
                confirm for other callers).
        """
        x = self.model.features(x)
        x = self.residual_block(x)
        # Global average pooling over the two spatial dims -> (batch, channels).
        x = x.mean([2, 3])
        return self.model.classifier(x)
|
|
|
|
|
device = torch.device('cpu')

# Must equal len(class_labels): predict() maps the argmax over these outputs
# to an index into that list.
num_classes = 10

# Build the architecture, then overwrite its weights from the fine-tuned checkpoint.
model = EfficientNetWithNovelty(num_classes)

# map_location uses `device` rather than a second hard-coded 'cpu', so the
# load target and the model's device cannot drift apart if `device` changes.
# NOTE(review): torch.load unpickles the file; only ever point it at a
# trusted checkpoint.
checkpoint = torch.load('best_model2.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

# Put BatchNorm (and any dropout) layers into inference behaviour.
model.eval()
|
|
|
|
|
# Preprocessing applied to every uploaded image before inference:
#   1. Resize to exactly 224x224. This does not preserve the aspect ratio.
#   2. Convert to a float CHW tensor.
#   3. Normalize per channel. These mean/std values are the commonly used
#      ImageNet statistics. NOTE(review): confirm they match the preprocessing
#      used in training.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
# Grip names indexed by the model's output class index: predict() returns
# class_labels[argmax of the logits]. The length must equal num_classes.
# NOTE(review): this order is not alphabetical, whereas e.g. torchvision's
# ImageFolder assigns indices in sorted folder-name order. Confirm it matches
# the label indices used in training, otherwise predictions will be
# mislabelled.
class_labels = [
    "KNUCKLE", "LEGSPIN", "OFFSPIN", "OUTSWING", "STRAIGHT",
    "BACK_OF_HAND", "CARROM", "CROSSSEAM", "GOOGLY", "INSWING",
]
|
|
|
|
|
def predict(image):
    """Classify the grip shown in ``image`` and return its label.

    Args:
        image: Image array as delivered by ``gr.Image(type="numpy")``, or
            ``None`` when the input is empty.

    Returns:
        The predicted entry of ``class_labels``, or an empty string when no
        image was supplied.
    """
    # Gradio passes None for an empty image input. With live=True this
    # happens whenever the user clears the image, and
    # Image.fromarray(None) would raise.
    if image is None:
        return ""

    # Force 3-channel RGB. The Normalize step has 3-channel mean/std, so a
    # grayscale (2-D) or RGBA array would otherwise fail inside the transform.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = transform(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    # Index of the highest logit for the single image in the batch.
    predicted_index = int(logits.argmax(dim=1).item())
    return class_labels[predicted_index]
|
|
|
|
|
# Gradio UI: one numpy image in, the predicted grip name out as text.
# live=True re-runs predict() every time the input image changes instead of
# waiting for a submit click.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Cricket Grip Image"),
    outputs=gr.Textbox(label="Predicted Grip Type"),
    live=True,
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()
|
|