File size: 669 Bytes
b947cf2 c5d3347 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import functools

import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the ``glopez/cifar-10`` classifier and feature extractor once.

    Cached so that repeated Gradio requests reuse the same objects instead of
    calling ``from_pretrained`` on every click.

    Returns:
        A ``(device, model, feature_extractor)`` tuple; ``model`` has already
        been moved to ``device``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AutoModelForImageClassification.from_pretrained('glopez/cifar-10').to(device)
    feature_extractor = AutoFeatureExtractor.from_pretrained('glopez/cifar-10')
    return device, model, feature_extractor


def classify_image(image):
    """Classify *image* with the ``glopez/cifar-10`` model and return its label.

    Args:
        image: The value supplied by the Gradio ``gr.Image`` input component
            (presumably a NumPy array, Gradio's default type -- confirm).

    Returns:
        The label string ``model.config.id2label[pred]`` for the top-scoring
        class.

    Raises:
        ValueError: If *image* is ``None`` (e.g. submitted without an upload).
    """
    if image is None:
        raise ValueError("No image provided.")
    device, model, feature_extractor = _load_model()
    inputs = feature_extractor(image, return_tensors='pt').to(device)
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    # argmax over dim=1 yields one class index per batch item; .item()
    # relies on the single image passed above forming a batch of one.
    pred = torch.argmax(outputs.logits, dim=1).item()
    return model.config.id2label[pred]
# Build the UI and keep a reference to the Interface object itself; chaining
# .launch() onto the constructor would instead bind `interface` to whatever
# launch() returns.
# NOTE(review): `shape=` is a Gradio 3.x `gr.Image` argument that was removed
# in Gradio 4 -- confirm the pinned Gradio version before upgrading.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(shape=(224, 224)),
    outputs="text",
)
# Launched at import time, as in the original script.
interface.launch(debug=True)
|