fosters's picture
Update app.py
5721eb7 verified
# 1. Install the necessary libraries first:
# pip install gradio "optimum[onnxruntime]" transformers
# (quote the extras so shells like zsh do not treat the brackets as a glob)
import gradio as gr
from optimum.pipelines import pipeline # Use the pipeline from 'optimum'
import os
# --- Performance Improvement ---
# Ask OpenMP-based native libraries to use every CPU core available to this
# process. An OMP_NUM_THREADS value already set by the user is left untouched.
# NOTE(review): an OpenMP runtime reads this variable when it initialises, and
# `optimum.pipelines` is imported above this point, so libraries it already
# loaded may ignore the new value. ONNX Runtime may also size its own thread
# pool independently of this variable — TODO confirm this has a measurable effect.
def _available_cpu_count():
    """Return how many CPUs this process may run on, or None if unknown.

    Prefers the process's CPU affinity mask where the platform exposes it, so a
    process pinned to a subset of cores does not oversubscribe; otherwise falls
    back to the machine-wide ``os.cpu_count()`` (which may itself be None).
    """
    if hasattr(os, "sched_getaffinity"):
        try:
            return len(os.sched_getaffinity(0))
        except OSError:
            pass  # Affinity query unsupported here; use the machine-wide count.
    return os.cpu_count()


num_cpu_cores = _available_cpu_count()
if num_cpu_cores is not None:
    # setdefault: honour an explicit user setting instead of overwriting it,
    # and report whichever value is actually in effect.
    os.environ.setdefault("OMP_NUM_THREADS", str(num_cpu_cores))
    print(
        f"✅ ONNX Runtime configured to use "
        f"{os.environ['OMP_NUM_THREADS']} CPU cores."
    )
else:
    print("Could not determine the number of CPU cores. Using default settings.")
# 2. Build the audio-classification pipeline from an ONNX export on the Hub.
# `optimum` fetches the model and executes it with the selected accelerator.
_PIPELINE_OPTIONS = {
    "task": "audio-classification",
    "model": "onnx-community/ast-finetuned-audioset-10-10-0.4593-ONNX",
    "accelerator": "ort",  # run the model through ONNX Runtime
    "device": "cpu",  # keep inference on the CPU
    # Forwarded for the feature extractor; intended to silence the
    # "slow processor" warning — TODO confirm optimum's pipeline() honours it.
    "feature_extractor_kwargs": {"use_fast": True},
}
pipe = pipeline(**_PIPELINE_OPTIONS)
# Classification callback wired into the Gradio interface.
def classify_audio(audio_filepath):
    """Classify one audio file and map each predicted label to its score.

    Args:
        audio_filepath: Path to the audio file to classify, or None when no
            file was provided.

    Returns:
        A prompt string when ``audio_filepath`` is None; otherwise a dict that
        maps each label produced by the module-level ``pipe`` to its score.
    """
    # Guard clause: nothing to classify yet.
    if audio_filepath is None:
        return "Please upload an audio file first."

    predictions = pipe(audio_filepath)
    scores = {}
    for prediction in predictions:
        scores[prediction["label"]] = prediction["score"]
    return scores
# Assemble the Gradio UI: an audio upload goes in, the top-3 labels come out.
audio_input = gr.Audio(type="filepath", label="Upload Audio")
label_output = gr.Label(num_top_classes=3, label="Top 3 Predictions")

app = gr.Interface(
    fn=classify_audio,
    inputs=audio_input,
    outputs=label_output,
    title="High-Performance Audio Classification with ONNX",
    description="Upload an audio file to classify it. This app uses a pre-optimized ONNX model and runs on all available CPU cores for maximum speed.",
    # No bundled examples yet; add local clips as one-item rows, e.g.
    # [["path/to/example_cat_purr.wav"]].
    examples=[],
)
def main():
    """Launch the Gradio app defined above."""
    app.launch()


# Start the server only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()