FirstDeploy / app.py
SkullFaceFire's picture
Create app.py
65602f1 verified
raw
history blame
1.04 kB
import torch
import torch.nn as nn
from torchvision import models,transforms
from PIL import Image
import gradio as gr
from torchvision.transforms import transforms
# The checkpoint is unpickled as a complete module further down, so no
# architecture needs to be constructed here. Original sketch, kept for reference:
# model=models.resnet18(pretrained=True)
# model.fc=nn.Linear(model.fc.in_features,10)

# Inference-time preprocessing. Random augmentations (horizontal flip,
# rotation) belong only in the training pipeline: applied here they make
# repeated predictions on the same image non-deterministic.
# NOTE(review): there is no Resize step; if the model was trained on a fixed
# input size (e.g. 32x32 CIFAR images), add transforms.Resize to match.
t = transforms.Compose([
    transforms.ToTensor(),
    # Shift/scale each of the 3 channels with mean 0.5 and std 0.5.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Labels indexed by the model's output position: c0 .. c9 (ten entries).
# NOTE(review): assumes a 10-class model (CIFAR-10 per the UI title); keep this
# list the same length as the model's output dimension.
class_name = [f"c{i}" for i in range(10)]
# Load the complete pickled nn.Module (not a state_dict); unpickling requires
# the module's class to be importable in this process.
# NOTE(review): torch.load unpickles arbitrary objects -- only load checkpoints
# you produced yourself. weights_only=False is needed for full-module
# checkpoints because newer PyTorch releases default to weights_only=True.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
model = torch.load("model.pth", map_location="cpu", weights_only=False)
# Switch Dropout/BatchNorm to inference behavior; the module's training flag
# is preserved through pickling, so it may still be in training mode.
model.eval()
print(model)
def predict(image):
    """Classify one uploaded image and return the predicted class index.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the form is submitted without an upload.

    Returns:
        The arg-max index (``int``) over the model's output for the image,
        or a short message string when no image was provided.
    """
    if image is None:
        return "Please upload an image."
    # Uploaded PNGs may be RGBA and some images are grayscale; the transform
    # normalizes exactly three channels, so force RGB first.
    batch = t(image.convert("RGB")).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        logits = model(batch)
    return logits.argmax(dim=1).item()
# Web UI: one image upload in, the predicted class index shown as text.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="cifar dataset prediction",
    description="upload an image to get its class",
)

# Only start the server when run as a script, not when this module is imported.
if __name__ == "__main__":
    # share=True additionally requests a public share link from Gradio.
    interface.launch(share=True)