Upload folder using huggingface_hub
- .gitattributes +1 -0
- README.md +22 -7
- app.py +152 -0
- examples/apple_pie.jpg +3 -0
- examples/pizza.jpg +0 -0
- examples/sushi.jpg +0 -0
- requirements.txt +8 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/apple_pie.jpg filter=lfs diff=lfs merge=lfs -text
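The added line tracks the new example image with Git LFS, using the same `filter=lfs diff=lfs merge=lfs -text` attributes as the existing patterns.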
README.md
CHANGED
@@ -1,14 +1,29 @@
 ---
-title: Food Classifier Comparison
-emoji:
+title: Food Classifier with Model Comparison
+emoji: π
 colorFrom: green
-colorTo:
+colorTo: blue
 sdk: gradio
-sdk_version:
+sdk_version: 4.19.2
 app_file: app.py
 pinned: false
-license: mit
-short_description: Comparing two different model speed and accuracy
 ---
 
-
+# π Food Classifier: Accuracy vs. Speed
+
+This Gradio demo allows you to classify food images using two different transformer-based models and visually compare their performance.
+
+## How to Use
+
+1. **Upload an Image**: Drag and drop a food image or click to upload one. You can also use one of the examples below.
+2. **Choose a Model**: Select either the ViT or Swin model from the dropdown.
+3. **Click Classify**: The model will predict the food item.
+
+## The Comparison Feature
+
+The key feature of this demo is the **performance comparison chart**:
+
+- **Benchmark Accuracy**: This chart shows the reported accuracy of each model on the Food101 test set. The Swin model is generally more accurate.
+- **Inference Time**: This chart shows the *actual time* it took for the selected model to process *your* uploaded image, so you can see the speed trade-off firsthand. The ViT model is often faster.
+
+This illustrates the classic machine learning trade-off between a model's accuracy and its computational cost (speed).
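As a concrete illustration of the trade-off described above, here is a minimal standalone sketch (not part of this commit) that times both checkpoints on one of the bundled example images; the model ids mirror `MODEL_INFO` in `app.py` below:

```python
# Sketch: classify one example image with both Food101 checkpoints and time each.
# Assumes the examples/ folder from this repo and the model ids from app.py.
import time

from PIL import Image
from transformers import pipeline

image = Image.open("examples/pizza.jpg")
for model_id in ("eslamxm/vit-base-food101", "aspis/swin-finetuned-food101"):
    pipe = pipeline(task="image-classification", model=model_id)
    start = time.time()
    top = pipe(image)[0]  # highest-scoring prediction
    print(f"{model_id}: {top['label']} ({top['score']:.3f}) in {time.time() - start:.2f}s")
```

Note that constructing the pipeline downloads and loads the weights, while the timed region covers only inference, which matches how `classify_image` in `app.py` measures it; the app additionally caches each pipeline so that cost is paid once.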
app.py
ADDED
@@ -0,0 +1,152 @@
import gradio as gr
import time
import torch
from transformers import pipeline
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
import io

# --- 1. Model Configuration & Metadata ---
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

MODEL_INFO = {
    "ViT (eslamxm/vit-base-food101)": {
        "model_id": "eslamxm/vit-base-food101",
        "benchmark_accuracy": 90.68,
        "pipeline": None
    },
    "Swin (aspis/swin-finetuned-food101)": {
        "model_id": "aspis/swin-finetuned-food101",
        "benchmark_accuracy": 93.81,
        "pipeline": None
    }
}

# --- 2. Lazy Loading of Models ---
def load_pipeline(model_name):
    """Loads a model pipeline only when it's first needed."""
    if MODEL_INFO[model_name]["pipeline"] is None:
        print(f"Loading model: {model_name}...")
        model_id = MODEL_INFO[model_name]["model_id"]
        MODEL_INFO[model_name]["pipeline"] = pipeline(task="image-classification", model=model_id, device=DEVICE)
        print(f"Model '{model_name}' loaded on {DEVICE}.")
    return MODEL_INFO[model_name]["pipeline"]

# --- 3. Function to Generate Comparison Chart ---
def create_comparison_chart(selected_model_name, current_inference_time):
    """Generates a bar chart comparing model accuracy and inference time."""
    data = {'Model': [], 'Metric': [], 'Value': []}
    for name, info in MODEL_INFO.items():
        data['Model'].append(name)
        data['Metric'].append('Benchmark Accuracy (%)')
        data['Value'].append(info['benchmark_accuracy'])

    data['Model'].append(selected_model_name)
    data['Metric'].append('Current Inference Time (s)')
    data['Value'].append(current_inference_time)

    df = pd.DataFrame(data)

    fig, ax = plt.subplots(1, 2, figsize=(12, 5))
    fig.suptitle('Model Performance Comparison', fontsize=16)

    # Left panel: benchmark accuracy for all models, selected model highlighted in green.
    acc_df = df[df['Metric'] == 'Benchmark Accuracy (%)']
    colors_acc = ['#4c72b0' if model != selected_model_name else '#2ca02c' for model in acc_df['Model']]
    acc_plot = acc_df.plot(kind='bar', x='Model', y='Value', ax=ax[0], color=colors_acc, legend=None)
    ax[0].set_title('Benchmark Accuracy')
    ax[0].set_ylabel('Accuracy (%)')
    ax[0].set_xlabel('')
    ax[0].set_ylim(0, 100)
    ax[0].tick_params(axis='x', rotation=10)
    for p in acc_plot.patches:
        ax[0].annotate(f"{p.get_height():.2f}%", (p.get_x() + p.get_width() / 2., p.get_height()),
                       ha='center', va='center', xytext=(0, 9), textcoords='offset points')

    # Right panel: measured inference time for the selected model only.
    time_df = df[df['Metric'] == 'Current Inference Time (s)']
    time_plot = time_df.plot(kind='bar', x='Model', y='Value', ax=ax[1], color=['#d62728'])
    ax[1].set_title('Inference Time for This Image')
    ax[1].set_ylabel('Time (seconds)')
    ax[1].set_xlabel('')
    ax[1].tick_params(axis='x', rotation=0)
    for p in time_plot.patches:
        ax[1].annotate(f"{p.get_height():.4f}s", (p.get_x() + p.get_width() / 2., p.get_height()),
                       ha='center', va='center', xytext=(0, 9), textcoords='offset points')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    return fig

# --- 4. The Core Classification Function ---
def classify_image(image, model_name):
    """
    Takes an image and model name; returns predictions, inference time,
    and a comparison chart.
    """
    if image is None:
        return {}, "Please upload an image first.", None, "Please upload an image to see a comparison."

    pipe = load_pipeline(model_name)
    start_time = time.time()
    predictions = pipe(Image.fromarray(image))
    end_time = time.time()

    inference_time = end_time - start_time

    top_5_preds = {p['label'].replace("_", " ").title(): p['score'] for p in predictions[:5]}
    comparison_fig = create_comparison_chart(model_name, inference_time)

    # Render the Matplotlib figure to a PIL image so Gradio can display it.
    buf = io.BytesIO()
    comparison_fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    comparison_img = Image.open(buf)
    plt.close(comparison_fig)

    return (
        top_5_preds,
        f"Inference Time: {inference_time:.4f} seconds",
        comparison_img,
        f"Chart shows accuracy for all models and the inference time for the **{model_name}** model on this specific image."
    )

# --- 5. Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown("# π Food Classifier: Accuracy vs. Speed")
    gr.Markdown(
        "Compare two different models for classifying food images from the Food101 dataset. "
        "Notice the trade-off: the **Swin** model is more accurate but might be slower, while the **ViT** model is faster but slightly less accurate."
    )

    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            image_input = gr.Image(type="numpy", label="Upload a food picture")
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_INFO.keys()),
                value=list(MODEL_INFO.keys())[0],
                label="Choose a Model"
            )
            classify_button = gr.Button("Classify Image", variant="primary")

            gr.Examples(
                examples=[
                    ["examples/sushi.jpg", list(MODEL_INFO.keys())[1]],
                    ["examples/pizza.jpg", list(MODEL_INFO.keys())[0]],
                    ["examples/apple_pie.jpg", list(MODEL_INFO.keys())[1]],
                ],
                inputs=[image_input, model_dropdown],
            )

        with gr.Column(scale=2):
            output_label = gr.Label(num_top_classes=5, label="Top 5 Predictions")
            output_time = gr.Textbox(label="Performance")
            output_chart = gr.Image(type="pil", label="Model Comparison Chart")
            chart_info = gr.Markdown()

    classify_button.click(
        fn=classify_image,
        inputs=[image_input, model_dropdown],
        outputs=[output_label, output_time, output_chart, chart_info]
    )

if __name__ == "__main__":
    demo.launch()
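To try the app locally (assuming the repository is cloned with its example images), `pip install -r requirements.txt` followed by `python app.py` should start the interface via `demo.launch()`.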
examples/apple_pie.jpg
ADDED
(binary image file, stored with Git LFS)
examples/pizza.jpg
ADDED
examples/sushi.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,8 @@
gradio==4.19.2
transformers==4.38.1
torch==2.1.2
torchvision==0.16.2
pandas==2.1.4
matplotlib==3.8.0
accelerate==0.27.2
Pillow==10.2.0