# app.py by DimaML — "Update app.py" (commit a7d5366, 1.08 kB)
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models
# All inference runs on the CPU.
device = torch.device("cpu")

# Preprocessing pipeline that ships with torchvision's ResNet-18 ImageNet
# weights; applied to each input image before it reaches the model.
transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

# Output labels; position i names the classifier's i-th output unit.
class_names = [
    'Anger',
    'Disgust',
    'Fear',
    'Happy',
    'Pain',
    'Sad',
]
classes_count = len(class_names)
# ResNet-18 backbone with torchvision's default pretrained weights; the
# fine-tuned checkpoint loaded below overwrites whichever parameters it holds.
model = models.resnet18(weights='DEFAULT')
# Swap the ImageNet head for one output per class. The Linear layer stays
# wrapped in nn.Sequential because that fixes its state-dict keys as
# `fc.0.weight` / `fc.0.bias`; presumably the checkpoint was saved with the
# same wrapper — confirm before changing it.
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, classes_count)
)
# Move the whole model, including the freshly created head, to the device.
model = model.to(device)
# NOTE(review): torch.load unpickles the file, so only point it at a trusted
# checkpoint; on torch versions that support it, weights_only=True is safer.
_checkpoint = torch.load('./model_param.pt', map_location=device)
# strict=False tolerates key mismatches, but any mismatch leaves those
# parameters at their pretrained (backbone) or random (new head) values, so
# surface it rather than serving wrong predictions silently.
_load_result = model.load_state_dict(_checkpoint, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "model_param.pt does not match the model exactly; "
        f"missing keys: {_load_result.missing_keys}, "
        f"unexpected keys: {_load_result.unexpected_keys}"
    )
def predict(image):
    """Classify the facial expression in *image*.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the input is empty (presumably what Gradio sends after the
            user clears the image in ``live`` mode — confirm for the Gradio
            version in use).

    Returns:
        Dict mapping each name in ``class_names`` to its softmax probability
        (a Python float), or ``None`` when there is no image to classify.
    """
    if image is None:
        return None
    # Preprocess and add a leading batch dimension of size 1.
    batch = transformer(image).unsqueeze(0).to(device)
    model.eval()
    with torch.inference_mode():
        probs = torch.softmax(model(batch), dim=1)[0]
    # Output unit i corresponds to class_names[i].
    return dict(zip(class_names, probs.tolist()))
# Gradio UI: one image in, a label listing the probability of every class out.
# live=True re-runs predict whenever the input changes.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(label='Predictions', num_top_classes=classes_count),
    live=True,
    # Example inputs, currently disabled:
    # examples=[
    #     './example1.jpg',
    #     './example2.jpg',
    #     './example3.jpg',
    # ],
)
app.launch()