MoinulwithAI commited on
Commit
37b7c68
·
verified ·
1 Parent(s): 37f6d64

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ # Load the trained model
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ model = YourModel() # Replace 'YourModel' with your actual model class
10
+ model.load_state_dict(torch.load('D:/Dataset/Cricket Bowl Grip/final_model.pth'))
11
+ model.to(device)
12
+ model.eval()
13
+
14
+ # Define the transformation to be applied to the input image
15
+ transform = transforms.Compose([
16
+ transforms.Resize((224, 224)), # Resize image to fit your model's input size
17
+ transforms.ToTensor(), # Convert image to tensor
18
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Example normalize values for ImageNet
19
+ ])
20
+
21
+ # Define a function for making predictions
22
+ def predict(image):
23
+ image = Image.fromarray(image) # Convert numpy array to PIL Image
24
+ image = transform(image).unsqueeze(0) # Apply transformations and add batch dimension
25
+ image = image.to(device)
26
+
27
+ with torch.no_grad():
28
+ outputs = model(image)
29
+ _, predicted = torch.max(outputs, 1)
30
+
31
+ # Map predicted label to class name
32
+ class_names = ['OUTSWING', 'STRAIGHT', 'BACK_OF_HAND', 'CARROM', 'CROSSSEAM',
33
+ 'GOOGLY', 'INSWING', 'KNUCKLE', 'LEGSPIN', 'OFFSPIN']
34
+ predicted_label = class_names[predicted.item()]
35
+
36
+ return predicted_label
37
+
38
+ # Create the Gradio Interface
39
+ iface = gr.Interface(fn=predict,
40
+ inputs=gr.Image(type="numpy"), # Accepts image input
41
+ outputs=gr.Text(), # Output the predicted class label
42
+ live=True) # live=True enables prediction while image is being uploaded
43
+
44
+ # Launch the interface
45
+ iface.launch()