# app.py -- author: MoinulwithAI; commit 31c3d5a ("Update app.py"); 3.58 kB.
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
# Define the same custom residual block and EfficientNetWithNovelty model
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with a 1x1-conv projection shortcut.

    The main path is conv1 -> bn1 -> ReLU -> conv2 -> bn2. The shortcut path
    is a 1x1 convolution of the block input. Both paths are summed and passed
    through a final ReLU. All convolutions keep the spatial size (3x3 with
    padding 1, 1x1 with padding 0, stride 1 throughout).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main path.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Shortcut path: 1x1 projection so the channel counts line up.
        self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # NOTE(review): skip_bn is registered but never applied in forward().
        # Presumably it has to stay so that saved state_dict keys keep
        # matching -- confirm before removing or wiring it in.
        self.skip_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ReLU(main_path(x) + skip(x))."""
        shortcut = self.skip(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + shortcut  # Add skip connection
        return self.relu(out)
class EfficientNetWithNovelty(nn.Module):
    """EfficientNet-B0 with an extra residual block before the classifier.

    The forward pass runs the backbone's feature extractor, the custom
    residual block, a mean over the two trailing (spatial) dimensions, and
    finally the backbone's classifier, whose last layer is replaced by a
    Linear layer with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes of the final Linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Load pre-trained EfficientNet-B0 model.
        # NOTE(review): `pretrained=True` is the legacy torchvision argument;
        # newer releases prefer `weights=...` -- confirm against the pinned
        # torchvision version before changing it.
        backbone = models.efficientnet_b0(pretrained=True)
        # Swap the final classifier layer for one sized to our label set.
        head_in_features = backbone.classifier[1].in_features
        backbone.classifier[1] = nn.Linear(head_in_features, num_classes)
        self.model = backbone
        # Custom residual block applied to the feature extractor's output
        # (1280 is the output channel count of EfficientNet-B0).
        self.residual_block = ResidualBlock(1280, 1280)

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        feature_map = self.model.features(x)        # backbone feature extractor
        refined = self.residual_block(feature_map)  # custom residual block
        pooled = refined.mean([2, 3])               # global average pooling
        return self.model.classifier(pooled)        # final classifier layers
# Load the model checkpoint and prepare the network for inference.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10  # Number of classes as per your dataset
model = EfficientNetWithNovelty(num_classes)
# map_location remaps the saved tensors onto the device selected above, so a
# checkpoint that was written on a GPU still loads on a CPU-only host.
checkpoint = torch.load('best_model2.pth', map_location=device)
# The checkpoint is a dict; the weights live under 'model_state_dict'.
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()  # inference mode for batch-norm / dropout layers
# Preprocessing applied to every input image, in order: resize to 224x224,
# convert to a tensor, then normalize each of the three channels.
# NOTE(review): the mean/std values are presumably the ImageNet statistics
# matching the pretrained backbone -- confirm they match training.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)
# Class labels, indexed by the model's predicted class index.
# NOTE(review): the order here must match the label indices used at training
# time; it is not alphabetical -- confirm against the training setup.
class_labels = [
    "KNUCKLE",
    "LEGSPIN",
    "OFFSPIN",
    "OUTSWING",
    "STRAIGHT",
    "BACK_OF_HAND",
    "CARROM",
    "CROSSSEAM",
    "GOOGLY",
    "INSWING",
]
# Prediction function
def predict(image):
    """Classify a cricket grip image and return the predicted label.

    Args:
        image: The image as a numpy array (the Gradio input is configured
            with ``type="numpy"``), or ``None`` when no image is present.

    Returns:
        The entry of ``class_labels`` at the model's highest-scoring index,
        or an empty string when ``image`` is ``None``.
    """
    # The callback can be invoked with no image (e.g. after the input is
    # cleared in a live interface); there is nothing to classify then.
    if image is None:
        return ""

    # Convert the numpy array to a PIL image and force three channels so the
    # 3-channel normalization in `transform` always applies cleanly.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = transform(pil_image).unsqueeze(0).to(device)  # add batch dim

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(batch)

    predicted_index = logits.argmax(dim=1).item()
    return class_labels[predicted_index]
# Gradio interface: one image input, one text output, re-run on every change.
grip_image_input = gr.Image(type="numpy", label="Upload Cricket Grip Image")
grip_label_output = gr.Textbox(label="Predicted Grip Type")

iface = gr.Interface(
    fn=predict,
    inputs=grip_image_input,
    outputs=grip_label_output,
    live=True,
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()