|
|
|
import io
import time

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import torch
from matplotlib.figure import Figure
from PIL import Image
from transformers import pipeline
|
|
|
|
|
# Device string handed to the Hugging Face pipeline: the first CUDA GPU when
# PyTorch reports one as available, otherwise the CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
# Registry of the selectable classifiers, keyed by the label shown in the UI.
# Each entry holds:
#   "model_id"           -- checkpoint identifier passed to `pipeline(model=...)`.
#   "benchmark_accuracy" -- accuracy in percent, plotted in the comparison chart.
#   "pipeline"           -- cache slot for the loaded pipeline; None until the
#                           model is first requested via load_pipeline().
MODEL_INFO = {
    "ViT (eslamxm/vit-base-food101)": {
        "model_id": "eslamxm/vit-base-food101",
        "benchmark_accuracy": 90.68,
        "pipeline": None,
    },
    "Swin (aspis/swin-finetuned-food101)": {
        "model_id": "aspis/swin-finetuned-food101",
        "benchmark_accuracy": 93.81,
        "pipeline": None,
    },
}
|
|
|
|
|
def load_pipeline(model_name):
    """Return the image-classification pipeline for ``model_name``, loading it on first use.

    The loaded pipeline is stored in ``MODEL_INFO[model_name]["pipeline"]`` so
    that later calls for the same model reuse it instead of loading it again.

    Args:
        model_name: A key of ``MODEL_INFO``.

    Returns:
        The pipeline object cached for ``model_name``.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_INFO``.
    """
    entry = MODEL_INFO[model_name]
    if entry["pipeline"] is None:
        print(f"Loading model: {model_name}...")
        entry["pipeline"] = pipeline(
            task="image-classification",
            model=entry["model_id"],
            device=DEVICE,
        )
        print(f"Model '{model_name}' loaded on {DEVICE}.")
    return entry["pipeline"]
|
|
|
|
|
def create_comparison_chart(selected_model_name, current_inference_time):
    """Build a two-panel bar chart comparing the models.

    The left panel shows the benchmark accuracy of every entry in
    ``MODEL_INFO``; the bar whose label equals ``selected_model_name`` is drawn
    in a different colour from the others. The right panel shows a single bar
    with the inference time measured for the current request.

    Args:
        selected_model_name: Label of the model that produced the current
            prediction. It is compared against the keys of ``MODEL_INFO`` to
            pick the highlighted bar and is used as the x label of the timing
            bar.
        current_inference_time: Inference time, in seconds, for the current
            image.

    Returns:
        matplotlib.figure.Figure: The assembled figure. It is not registered
        with pyplot, so it is released once the caller drops its reference;
        calling ``plt.close`` on it is harmless.
    """
    # Create the figure directly rather than via plt.subplots(): pyplot keeps a
    # process-wide "current figure" and a registry of open figures, which
    # matplotlib advises against relying on inside web request handlers.
    fig = Figure(figsize=(12, 5))
    acc_ax, time_ax = fig.subplots(1, 2)
    fig.suptitle('Model Performance Comparison', fontsize=16)

    # Left panel: one accuracy bar per registered model.
    acc_df = pd.DataFrame({
        'Model': list(MODEL_INFO),
        'Value': [info['benchmark_accuracy'] for info in MODEL_INFO.values()],
    })
    bar_colors = [
        '#2ca02c' if name == selected_model_name else '#4c72b0'
        for name in acc_df['Model']
    ]
    acc_df.plot(kind='bar', x='Model', y='Value', ax=acc_ax, color=bar_colors, legend=False)
    acc_ax.set_title('Benchmark Accuracy')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.set_xlabel('')
    acc_ax.set_ylim(0, 100)
    acc_ax.tick_params(axis='x', rotation=10)
    _annotate_bar_heights(acc_ax, "{:.2f}%")

    # Right panel: the timing of the model that was just run.
    time_df = pd.DataFrame({
        'Model': [selected_model_name],
        'Value': [current_inference_time],
    })
    # legend=False matches the accuracy panel; a lone "Value" legend entry
    # carries no information for a single bar.
    time_df.plot(kind='bar', x='Model', y='Value', ax=time_ax, color=['#d62728'], legend=False)
    time_ax.set_title('Inference Time for This Image')
    time_ax.set_ylabel('Time (seconds)')
    time_ax.set_xlabel('')
    time_ax.tick_params(axis='x', rotation=0)
    _annotate_bar_heights(time_ax, "{:.4f}s")

    # Lay out this figure explicitly (not pyplot's current figure), leaving
    # room at the top for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    return fig


def _annotate_bar_heights(ax, label_format):
    """Write each bar's height just above the bar, formatted with ``label_format``."""
    for bar in ax.patches:
        height = bar.get_height()
        ax.annotate(
            label_format.format(height),
            (bar.get_x() + bar.get_width() / 2., height),
            ha='center', va='center', xytext=(0, 9), textcoords='offset points',
        )
|
|
|
|
|
def classify_image(image, model_name):
    """Classify an image with the chosen model and build the values shown in the UI.

    Args:
        image: The uploaded image as an array accepted by ``Image.fromarray``,
            or ``None`` when nothing has been uploaded.
        model_name: Key of ``MODEL_INFO`` selecting the model to run.

    Returns:
        A 4-tuple ``(predictions, timing_text, chart_image, chart_caption)``:

        * ``predictions`` -- dict mapping a prettified label (underscores
          replaced by spaces, title-cased) to its score, for at most the first
          five results of the pipeline; ``{}`` when ``image`` is ``None``.
        * ``timing_text`` -- human-readable inference time, or a prompt to
          upload an image.
        * ``chart_image`` -- the comparison chart rendered to a PIL image, or
          ``None`` when ``image`` is ``None``.
        * ``chart_caption`` -- Markdown caption describing the chart.
    """
    if image is None:
        return {}, "Please upload an image first.", None, "Please upload an image to see a comparison."

    # Fetch the pipeline before starting the clock so that a first-use model
    # load is not counted as inference time.
    pipe = load_pipeline(model_name)

    # perf_counter() is the clock intended for measuring short durations;
    # time.time() is wall-clock time and can jump if the system clock is
    # adjusted.
    started = time.perf_counter()
    predictions = pipe(Image.fromarray(image))
    inference_time = time.perf_counter() - started

    top_5_preds = {p['label'].replace("_", " ").title(): p['score'] for p in predictions[:5]}

    comparison_fig = create_comparison_chart(model_name, inference_time)
    buf = io.BytesIO()
    try:
        comparison_fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Release the figure even when rendering fails, so a failed request
        # does not leave it registered with pyplot.
        plt.close(comparison_fig)
    buf.seek(0)
    comparison_img = Image.open(buf)

    return (
        top_5_preds,
        f"Inference Time: {inference_time:.4f} seconds",
        comparison_img,
        f"Chart shows accuracy for all models and the inference time for the **{model_name}** model on this specific image."
    )
|
|
|
|
|
# Gradio interface: inputs (image, model choice, examples) in the left column,
# results (labels, timing, chart, caption) in the right column.
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    # Dropdown labels, computed once instead of rebuilding the key list at
    # every use below.
    model_choices = list(MODEL_INFO)

    # NOTE(review): the leading "π" looks like a mis-decoded emoji; confirm the
    # intended character before changing this user-facing string.
    gr.Markdown("# π Food Classifier: Accuracy vs. Speed")
    gr.Markdown(
        "Compare two different models for classifying food images from the Food101 dataset. "
        "Notice the trade-off: the **Swin** model is more accurate but might be slower, while the **ViT** model is faster but slightly less accurate."
    )

    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            image_input = gr.Image(type="numpy", label="Upload a food picture")
            model_dropdown = gr.Dropdown(
                choices=model_choices,
                value=model_choices[0],
                label="Choose a Model"
            )
            classify_button = gr.Button("Classify Image", variant="primary")

            # NOTE(review): assumed to belong to the left column, below the
            # button -- confirm against the intended layout.
            gr.Examples(
                examples=[
                    ["examples/sushi.jpg", model_choices[1]],
                    ["examples/pizza.jpg", model_choices[0]],
                    ["examples/apple_pie.jpg", model_choices[1]],
                ],
                inputs=[image_input, model_dropdown],
            )

        with gr.Column(scale=2):
            output_label = gr.Label(num_top_classes=5, label="Top 5 Predictions")
            output_time = gr.Textbox(label="Performance")
            output_chart = gr.Image(type="pil", label="Model Comparison Chart")
            chart_info = gr.Markdown()

    # The four outputs line up with the 4-tuple returned by classify_image.
    classify_button.click(
        fn=classify_image,
        inputs=[image_input, model_dropdown],
        outputs=[output_label, output_time, output_chart, chart_info]
    )


if __name__ == "__main__":
    demo.launch()
|
|