# Final_Project / app.py
# Committed by Dgv2 in 6e488f9 ("Create app.py"), 977 bytes.
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import gradio as gr
# Hugging Face Hub checkpoint id (per its name, a ViT-based dog-breed
# classifier). The processor and the model are loaded from the same id so the
# preprocessing config and the label map come from one checkpoint.
MODEL_ID = "wesleyacheng/dog-breeds-multiclass-image-classification-with-vit"

image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
def classify_dog(image):
    """Return a text label naming the dog breed predicted for *image*.

    Args:
        image: The picture to classify. The Gradio input is configured with
            ``type="pil"``, so this is normally a PIL image; Gradio passes
            ``None`` when the user submits without uploading anything.

    Returns:
        ``"Predicted Dog Breed: <label>"`` for the top-1 class, or a prompt
        asking for an image when *image* is ``None``.
    """
    # Submitting an empty image box yields None; the image processor would
    # raise on it, so answer with a prompt instead of an error.
    if image is None:
        return "Please upload an image of a dog."

    # NOTE(review): the checkpoint presumably expects 3-channel RGB input;
    # RGBA / greyscale / palette PIL images can make preprocessing fail, so
    # normalise them. Duck-typed on `.mode` so non-PIL inputs pass unchanged.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = image_processor(images=image, return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Index of the highest-scoring class for the single image in the batch.
    predicted_class_idx = logits.argmax(-1).item()
    predicted_breed = model.config.id2label[predicted_class_idx]
    return f"Predicted Dog Breed: {predicted_breed}"
# Minimal Gradio UI around the classifier: one PIL image in, one text box out.
demo = gr.Interface(
    fn=classify_dog,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Dog Breed Classifier",
    description="Upload an image of a dog and the model will classify its breed (120 breeds supported).",
)

# Start the web server only when executed as a script (e.g. `python app.py`);
# merely importing this module should not launch a server.
if __name__ == "__main__":
    demo.launch()