# Spaces: Sleeping  (hosting-page status text captured with the source; not part of the program)
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
from torchvision.transforms import transforms
# "model.pth" is loaded below as a complete pickled model object rather than
# a state_dict. The commented-out alternative builds the network in code:
#     model = models.resnet18(pretrained=True)
#     model.fc = nn.Linear(model.fc.in_features, 10)
# so the checkpoint presumably holds a ResNet-18 with a 10-way head
# (TODO confirm against the training script).

# Inference-time preprocessing. Only deterministic steps belong here: random
# flips/rotations are training-time augmentation and would make the same
# upload produce different predictions from one request to the next.
# NOTE(review): no Resize is applied, so the model receives the image at its
# uploaded resolution -- confirm that matches how the model was trained.
t = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# NOTE(review): range(1, 10) yields nine labels ("c1".."c9"), one fewer than
# the 10 outputs in the commented-out head above, and the list is not used
# anywhere in this file. Left unchanged until the intended labels are known.
class_name = [f"c{i}" for i in range(1, 10)]

# SECURITY: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
# map_location="cpu" keeps the weights on the same device as the CPU tensors
# produced by `t`, and lets a checkpoint saved on a GPU machine load on a
# CPU-only host.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled models; pass weights_only=False here if the pinned
# torch version requires it.
model = torch.load("model.pth", map_location="cpu")

# Switch to evaluation mode so layers such as BatchNorm and Dropout use their
# inference behaviour instead of training behaviour.
model.eval()
print(model)
def predict(image):
    """Classify one image and return the index of the top-scoring class.

    Args:
        image: A PIL image (the Gradio input is configured with type="pil").

    Returns:
        int: Index of the largest value along dim 1 of the model output.

    Raises:
        ValueError: If ``image`` is None (for example, nothing was uploaded).
    """
    if image is None:
        # Fail with a readable message instead of an opaque error raised
        # from inside the transform pipeline.
        raise ValueError("No image provided; upload an image before submitting.")

    # The Normalize step in `t` is configured for three channels, so force
    # RGB in case a grayscale or RGBA image reaches this function.
    batch = t(image.convert("RGB")).unsqueeze(0)  # add the batch dimension

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output = model(batch)

    _, predicted = torch.max(output, 1)
    return predicted.item()
# Web UI: a single image upload wired to `predict`, with the returned value
# shown as text.
interface = gr.Interface(
    predict,
    gr.Image(type="pil"),
    "text",
    title="cifar dataset prediction",
    description="upload an image to get its class",
)

# Start the app; share=True asks Gradio for a public share link.
interface.launch(share=True)