Dgv2 commited on
Commit
7c46844
·
verified ·
1 Parent(s): b4808e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -3,24 +3,27 @@ from PIL import Image
3
  import torch
4
  import gradio as gr
5
 
6
- image_processor = AutoImageProcessor.from_pretrained("wesleyacheng/dog-breeds-multiclass-image-classification-with-vit")
7
- model = AutoModelForImageClassification.from_pretrained("wesleyacheng/dog-breeds-multiclass-image-classification-with-vit")
 
8
 
9
- def classify_dog(image):
10
- inputs = image_processor(images=image, return_tensors="pt")
 
11
  with torch.no_grad():
12
- outputs = model(**inputs)
13
- logits = outputs.logits
14
- predicted_class_idx = logits.argmax(-1).item()
15
- predicted_breed = model.config.id2label[predicted_class_idx]
16
- return f"Predicted Dog Breed: {predicted_breed}"
17
 
18
- demo = gr.Interface(
19
- fn=classify_dog,
 
20
  inputs=gr.Image(type="pil"),
21
  outputs="text",
22
- title="Dog Breed Classifier",
23
- description="Upload an image of a dog and the model will classify its breed (120 breeds supported)."
24
  )
25
 
26
- demo.launch()
 
3
  import torch
4
  import gradio as gr
5
 
6
+ # Load the image processor and model from Hugging Face
7
+ processor = AutoImageProcessor.from_pretrained("wesleyacheng/dog-breeds-multiclass-image-classification-with-vit")
8
+ breed_model = AutoModelForImageClassification.from_pretrained("wesleyacheng/dog-breeds-multiclass-image-classification-with-vit")
9
 
10
+ # This function takes an uploaded image and returns the predicted dog breed
11
+ def detect_breed(img):
12
+ inputs = processor(images=img, return_tensors="pt")
13
  with torch.no_grad():
14
+ result = breed_model(**inputs)
15
+ predictions = result.logits
16
+ top_prediction = predictions.argmax(dim=1).item()
17
+ breed_name = breed_model.config.id2label[top_prediction]
18
+ return f"This looks like a {breed_name}!"
19
 
20
+ # Set up the Gradio web interface
21
+ app = gr.Interface(
22
+ fn=detect_breed,
23
  inputs=gr.Image(type="pil"),
24
  outputs="text",
25
+ title="Dog Breed Identifier 🐶",
26
+ description="Upload a photo of a dog and find out what breed it is! The model can recognize 120 different dog breeds."
27
  )
28
 
29
+ app.launch()