File size: 669 Bytes
b947cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
c5d3347
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import gradio as gr

import torch

_MODEL_ID = 'glopez/cifar-10'

# (model, feature_extractor) pairs keyed by device string. Loading the
# checkpoint is invariant across requests, so it is done once per process
# rather than on every call.
_LOADED = {}


def _get_model_and_extractor(device):
  """Return the (model, feature_extractor) pair for *device*, loading it on first use."""
  key = str(device)
  if key not in _LOADED:
    model = AutoModelForImageClassification.from_pretrained(_MODEL_ID).to(device)
    feature_extractor = AutoFeatureExtractor.from_pretrained(_MODEL_ID)
    _LOADED[key] = (model, feature_extractor)
  return _LOADED[key]


def classify_image(image):
  """Classify *image* with the 'glopez/cifar-10' checkpoint and return its label.

  Args:
    image: The image to classify, in whatever form the feature extractor
      accepts (presumably the array Gradio's image input delivers --
      confirm against the interface definition).

  Returns:
    The entry of ``model.config.id2label`` for the highest-scoring logit.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model, feature_extractor = _get_model_and_extractor(device)
  inp = feature_extractor(image, return_tensors='pt').to(device)
  # Inference only: skip autograd bookkeeping, no gradients are needed.
  with torch.no_grad():
    outp = model(**inp)
  pred = torch.argmax(outp.logits, dim=1).item()
  return model.config.id2label[pred]

# Build the UI first so `interface` names the gr.Interface object itself;
# chaining .launch() onto the constructor would bind the name to launch()'s
# return value instead.
# NOTE(review): the `shape` argument of gr.Image appears to be specific to
# older Gradio releases -- confirm against the pinned Gradio version.
interface = gr.Interface(
  fn=classify_image,
  inputs=gr.Image(shape=(224, 224)),
  outputs="text",
)
interface.launch(debug=True)