# (Residue from the hosting page's file viewer; commented out so the module parses.)
# model-deployer's picture
# Upload folder using huggingface_hub
# 8e8eb27 verified
import gradio as gr
import time
import torch
from transformers import pipeline
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
import io
# --- 1. Model Configuration & Metadata ---
# Use the first CUDA device when torch reports one, otherwise fall back to CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda:0"

# Registry of the selectable models, keyed by the label shown in the UI.
# Each entry holds:
#   model_id           - identifier handed to transformers.pipeline(model=...)
#   benchmark_accuracy - hard-coded accuracy figure (percent) plotted in the chart
#   pipeline           - starts as None; load_pipeline() fills it on first use
# Insertion order matters: the first key is the dropdown's default value.
MODEL_INFO = {
    label: {
        "model_id": model_id,
        "benchmark_accuracy": accuracy,
        "pipeline": None,
    }
    for label, model_id, accuracy in (
        ("ViT (eslamxm/vit-base-food101)", "eslamxm/vit-base-food101", 90.68),
        ("Swin (aspis/swin-finetuned-food101)", "aspis/swin-finetuned-food101", 93.81),
    )
}
# --- 2. Lazy Loading of Models ---
def load_pipeline(model_name):
    """Return the image-classification pipeline for *model_name*, building it once.

    The pipeline is stored back into ``MODEL_INFO[model_name]["pipeline"]`` so
    later calls for the same model reuse it instead of constructing a new one.

    Args:
        model_name: A key of ``MODEL_INFO``.

    Returns:
        The pipeline object cached for that model.

    Raises:
        KeyError: If *model_name* is not a key of ``MODEL_INFO``.
    """
    entry = MODEL_INFO[model_name]

    # Fast path: this model has already been instantiated.
    if entry["pipeline"] is not None:
        return entry["pipeline"]

    print(f"Loading model: {model_name}...")
    entry["pipeline"] = pipeline(
        task="image-classification",
        model=entry["model_id"],
        device=DEVICE,
    )
    print(f"Model '{model_name}' loaded on {DEVICE}.")
    return entry["pipeline"]
# --- 3. Function to Generate Comparison Chart ---
def _annotate_bars(axis, value_format):
    """Label every bar on *axis* with its height, rendered via *value_format*."""
    for bar in axis.patches:
        height = bar.get_height()
        axis.annotate(
            value_format.format(height),
            (bar.get_x() + bar.get_width() / 2.0, height),
            ha='center', va='center', xytext=(0, 9), textcoords='offset points',
        )


def create_comparison_chart(selected_model_name, current_inference_time):
    """Build a two-panel bar chart for the model comparison view.

    The left panel plots the ``benchmark_accuracy`` of every entry in
    ``MODEL_INFO`` and draws the selected model's bar in green. The right
    panel plots a single bar: the inference time measured for the current
    image, labelled with the selected model's name.

    Args:
        selected_model_name: Key of the model that was just run.
        current_inference_time: Duration of the current inference, in seconds.

    Returns:
        The matplotlib ``Figure``. The caller owns it and must close it
        (e.g. ``plt.close(fig)``) once it has been rendered.
    """
    accuracy_df = pd.DataFrame({
        'Model': list(MODEL_INFO),
        'Value': [info['benchmark_accuracy'] for info in MODEL_INFO.values()],
    })
    timing_df = pd.DataFrame({
        'Model': [selected_model_name],
        'Value': [current_inference_time],
    })
    # Green for the model that was just run, blue for the others.
    bar_colors = [
        '#2ca02c' if name == selected_model_name else '#4c72b0'
        for name in accuracy_df['Model']
    ]

    fig, (acc_ax, time_ax) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        fig.suptitle('Model Performance Comparison', fontsize=16)

        accuracy_df.plot(kind='bar', x='Model', y='Value', ax=acc_ax,
                         color=bar_colors, legend=False)
        acc_ax.set_title('Benchmark Accuracy')
        acc_ax.set_ylabel('Accuracy (%)')
        acc_ax.set_xlabel('')
        acc_ax.set_ylim(0, 100)
        acc_ax.tick_params(axis='x', rotation=10)
        _annotate_bars(acc_ax, "{:.2f}%")

        # legend=False: the single series would otherwise get a legend that
        # just says "Value", which the accuracy panel already suppresses.
        timing_df.plot(kind='bar', x='Model', y='Value', ax=time_ax,
                       color=['#d62728'], legend=False)
        time_ax.set_title('Inference Time for This Image')
        time_ax.set_ylabel('Time (seconds)')
        time_ax.set_xlabel('')
        time_ax.tick_params(axis='x', rotation=0)
        _annotate_bars(time_ax, "{:.4f}s")

        # Lay out this figure explicitly; plt.tight_layout() would act on
        # pyplot's global "current figure", which may not be this one.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    except Exception:
        # Do not leave a half-built figure registered with pyplot.
        plt.close(fig)
        raise
    return fig
# --- 4. The Core Classification Function ---
def classify_image(image, model_name):
    """Classify an uploaded image and render the model comparison chart.

    Args:
        image: The uploaded picture as a numpy array (it is converted with
            ``Image.fromarray``), or ``None`` when nothing was uploaded.
        model_name: Key of ``MODEL_INFO`` selecting which model to run.

    Returns:
        A 4-tuple matching the UI outputs:
        ``(label -> score dict, timing text, chart as a PIL image, chart caption)``.
        When *image* is ``None`` the dict is empty, the chart is ``None`` and
        both texts ask the user to upload an image.
    """
    if image is None:
        return {}, "Please upload an image first.", None, "Please upload an image to see a comparison."

    pipe = load_pipeline(model_name)

    # perf_counter() is monotonic and high-resolution, so the measured interval
    # cannot be skewed by wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    predictions = pipe(Image.fromarray(image))
    inference_time = time.perf_counter() - start_time

    # "apple_pie" -> "Apple Pie"; keep at most the five leading predictions.
    top_5_preds = {p['label'].replace("_", " ").title(): p['score'] for p in predictions[:5]}

    comparison_fig = create_comparison_chart(model_name, inference_time)
    try:
        with io.BytesIO() as buf:
            comparison_fig.savefig(buf, format='png', bbox_inches='tight')
            buf.seek(0)
            comparison_img = Image.open(buf)
            # Image.open is lazy; decode now so the image no longer depends
            # on the buffer once it is closed.
            comparison_img.load()
    finally:
        # Always release the figure, even if rendering fails, so pyplot does
        # not accumulate open figures across requests.
        plt.close(comparison_fig)

    return (
        top_5_preds,
        f"Inference Time: {inference_time:.4f} seconds",
        comparison_img,
        f"Chart shows accuracy for all models and the inference time for the **{model_name}** model on this specific image."
    )
# --- 5. Gradio Interface Definition ---
# Model labels in registry order; shared by the dropdown and the example rows.
_MODEL_CHOICES = list(MODEL_INFO)

with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    # The heading emoji is written as an escape sequence so it survives
    # re-encoding of this file (it had been corrupted into mojibake).
    gr.Markdown("# \U0001F354 Food Classifier: Accuracy vs. Speed")
    gr.Markdown(
        "Compare two different models for classifying food images from the Food101 dataset. "
        "Notice the trade-off: the **Swin** model is more accurate but might be slower, while the **ViT** model is faster but slightly less accurate."
    )
    with gr.Row(variant="panel"):
        # Left column: inputs and ready-made examples.
        with gr.Column(scale=1):
            image_input = gr.Image(type="numpy", label="Upload a food picture")
            model_dropdown = gr.Dropdown(
                choices=_MODEL_CHOICES,
                value=_MODEL_CHOICES[0],
                label="Choose a Model"
            )
            classify_button = gr.Button("Classify Image", variant="primary")
            gr.Examples(
                examples=[
                    ["examples/sushi.jpg", _MODEL_CHOICES[1]],
                    ["examples/pizza.jpg", _MODEL_CHOICES[0]],
                    ["examples/apple_pie.jpg", _MODEL_CHOICES[1]],
                ],
                inputs=[image_input, model_dropdown],
            )
        # Right column: prediction results and the comparison chart.
        with gr.Column(scale=2):
            output_label = gr.Label(num_top_classes=5, label="Top 5 Predictions")
            output_time = gr.Textbox(label="Performance")
            output_chart = gr.Image(type="pil", label="Model Comparison Chart")
            chart_info = gr.Markdown()
    # The outputs list follows the order of classify_image's returned tuple.
    classify_button.click(
        fn=classify_image,
        inputs=[image_input, model_dropdown],
        outputs=[output_label, output_time, output_chart, chart_info]
    )

if __name__ == "__main__":
    demo.launch()