paddy / app.py
masapasa's picture
Create new file
78e7c40
raw
history blame
1.76 kB
import functools
import os
import random

import albumentations
import cv2
import gradio as gr
import numpy as np
import timm
import torch
# All tensors and the model are placed on the CPU.
device = torch.device('cpu')

# Class index -> disease name. The index is the position that predict()
# reads from the model's score vector, so the order here must not change.
labels = dict(enumerate((
    'bacterial_leaf_blight',
    'bacterial_leaf_streak',
    'bacterial_panicle_blight',
    'blast',
    'brown_spot',
    'dead_heart',
    'downy_mildew',
    'hispa',
    'normal',
    'tungro',
)))
def inference_fn(model, image=None):
    """Run one forward pass and return the per-class sigmoid scores.

    Args:
        model: A torch module; it is switched to eval mode before the pass.
        image: A single un-batched tensor. A leading batch axis of size 1 is
            added here before the model is called.

    Returns:
        A flat NumPy array holding the sigmoid of every model output value.
    """
    model.eval()
    batch = image.to(device).unsqueeze(0)
    # Gradients are not needed for inference, so skip building the graph.
    with torch.no_grad():
        logits = model(batch)
        scores = torch.sigmoid(logits)
    return scores.detach().cpu().numpy().flatten()
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the EfficientNet-B0 classifier and load its weights exactly once.

    The checkpoint location defaults to the original hard-coded path and can
    be overridden with the ``PADDY_MODEL_PATH`` environment variable, so the
    app is not tied to one machine's home directory. The result is cached:
    the network is not rebuilt and the file is not re-read on every request.
    """
    checkpoint_path = os.environ.get(
        "PADDY_MODEL_PATH", "/home/aswin/Downloads/paddy_model.pth"
    )
    model = timm.create_model(
        'efficientnet_b0', pretrained=False, num_classes=len(labels)
    )
    # NOTE: torch.load unpickles the file, so the checkpoint must come from
    # a trusted source.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    return model


def _preprocess(image):
    """Resize and normalise an image array into a channel-first float tensor.

    Only deterministic transforms are applied. Random flips are a training
    augmentation; at inference they make the same picture yield different
    scores from one call to the next.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    transform = albumentations.Compose(
        [
            albumentations.Resize(256, 256),
            # Normalize defaults to p=1.0, so it always runs without the
            # deprecated always_apply flag.
            albumentations.Normalize(mean, std, max_pixel_value=255.0),
        ]
    )
    array = transform(image=image)["image"]
    # Move the last (channel) axis first, as the model expects.
    array = np.transpose(array, (2, 0, 1))
    return torch.tensor(array, dtype=torch.float32)


def predict(image=None) -> dict:
    """Classify a paddy-leaf image and return a score for every label.

    Args:
        image: Image array as delivered by the Gradio image input
            (presumably height x width x 3, channel-last — the three-value
            mean/std and the axis transpose rely on that).

    Returns:
        Mapping of each name in ``labels`` to the float score that
        ``inference_fn`` produced at that label's index.

    Raises:
        ValueError: If no image was supplied.
    """
    if image is None:
        # Fail with a clear message instead of an opaque error from deep
        # inside the transform pipeline.
        raise ValueError("predict() requires an image, but none was provided")

    scores = inference_fn(_load_model(), _preprocess(image))
    return {name: float(scores[index]) for index, name in labels.items()}
# Gradio UI: one image input, the predict() scores rendered as a ranked label
# list showing up to 10 classes.
interface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(),
    outputs=gr.outputs.Label(num_top_classes=10),
    interpretation='default',
)

# Launch exactly once, on the Interface itself. Binding the name to the return
# value of launch() and launching that again fails, because launch() does not
# return an Interface.
interface.launch()