|
import functools

import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
|
|
|
# Hugging Face Hub identifier of the classifier used by this app.
_MODEL_ID = "glopez/cifar-10"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the classifier, its feature extractor and the target device once.

    Cached so repeated Gradio requests reuse the same weights instead of
    calling ``from_pretrained`` (disk / network) on every prediction.

    Returns:
        A ``(model, feature_extractor, device)`` tuple.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForImageClassification.from_pretrained(_MODEL_ID).to(device)
    model.eval()  # inference only: make sure dropout etc. are disabled
    feature_extractor = AutoFeatureExtractor.from_pretrained(_MODEL_ID)
    return model, feature_extractor, device


def classify_image(image):
    """Return the predicted class label for *image*.

    Args:
        image: The value delivered by the Gradio ``gr.Image`` input
            (presumably a NumPy array with Gradio's default ``type`` --
            confirm), or ``None`` when the user submits without an image.

    Returns:
        The label string from ``model.config.id2label`` for the highest-scoring
        class, or an empty string when no image was provided.
    """
    # Gradio passes None when the image box is empty; the feature extractor
    # would raise on it, so answer with an empty result instead.
    if image is None:
        return ""

    model, feature_extractor, device = _load_model()
    inputs = feature_extractor(image, return_tensors="pt").to(device)

    # No gradients are needed for prediction; skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits

    pred = torch.argmax(logits, dim=1).item()
    return model.config.id2label[pred]
|
|
|
# Build the UI separately from launching it so that `interface` refers to the
# gr.Interface object itself; launch()'s return value is not the Interface.
# NOTE(review): `shape=` is a Gradio 3.x gr.Image argument and appears to have
# been removed in later major versions -- confirm against the pinned gradio.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(shape=(224, 224)),
    outputs="text",
)
interface.launch(debug=True)
|
|