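"""Gradio app for quick audio analysis with librosa.

Accepts an uploaded file or a YouTube URL (fetched via the yt-dlp CLI) and
renders basic waveform/spectrogram features, chroma comparisons, or
overlapping mel-spectrogram patches.

Assumed runtime dependencies, inferred from the imports below rather than a
pinned requirements file: gradio, librosa, matplotlib, numpy, plus yt-dlp
and ffmpeg on PATH for the YouTube download path.
"""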
import gradio as gr
import subprocess
import tempfile
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import logging
import warnings
import shutil
from typing import Tuple, Optional
# Configure matplotlib and logging
plt.switch_backend('Agg')
warnings.filterwarnings('ignore')
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

class AudioAnalyzer:
    def __init__(self):
        # Per-session scratch directory; removed again in cleanup()
        self.temp_dir = Path(tempfile.mkdtemp())
        self.plot_files = []

    def cleanup(self):
        for plot_file in self.plot_files:
            Path(plot_file).unlink(missing_ok=True)
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def download_youtube_audio(self, video_url: str, progress=gr.Progress()) -> Tuple[Optional[str], str]:
        if not video_url:
            return None, "Please provide a valid YouTube URL"
        progress(0.1, desc="Downloading...")
        # Let yt-dlp pick the download extension; the mp3 post-processor then
        # writes audio.mp3. (A fixed ".mp3" output template can make yt-dlp
        # skip conversion when the source container is not mp3.)
        output_template = self.temp_dir / "audio.%(ext)s"
        output_file = self.temp_dir / "audio.mp3"
        try:
            subprocess.run([
                "yt-dlp", "-x", "--audio-format", "mp3",
                "-o", str(output_template), video_url
            ], check=True, capture_output=True, text=True)
            progress(1.0, desc="Complete!")
            return str(output_file), "Download successful"
        except FileNotFoundError:
            return None, "yt-dlp not found. Install with: pip install yt-dlp"
        except subprocess.CalledProcessError as e:
            # text=True above keeps e.stderr a readable str, not bytes
            return None, f"Download failed: {e.stderr}"

    def save_plot(self, fig) -> str:
        plot_path = self.temp_dir / f"plot_{len(self.plot_files)}.png"
        fig.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        self.plot_files.append(str(plot_path))
        return str(plot_path)

    def analyze_audio(self, audio_path: str, analysis_type: str = "basic",
                      patch_duration: float = 5.0, progress=gr.Progress()) -> Tuple[Optional[str], str]:
        if not audio_path or not Path(audio_path).exists():
            return None, "No audio file provided"
        try:
            progress(0.1, desc="Loading audio...")
            y, sr = librosa.load(audio_path, sr=22050)
            duration = len(y) / sr
            # Limit duration to keep processing responsive
            max_duration = 60 if analysis_type == "basic" else 30
            if duration > max_duration:
                y = y[:int(sr * max_duration)]
                duration = max_duration
            if analysis_type == "basic":
                return self._basic_analysis(y, sr, duration, progress)
            elif analysis_type == "chroma":
                return self._chroma_analysis(y, sr, progress)
            elif analysis_type == "patches":
                return self._patch_analysis(y, sr, patch_duration, progress)
            else:
                # Previously fell through to an implicit None, which broke the
                # (plot, summary) unpacking in the Gradio callback
                return None, f"Unknown analysis type: {analysis_type}"
        except Exception as e:
            logger.error(f"Analysis error: {e}")
            return None, f"Analysis failed: {str(e)}"

    def _basic_analysis(self, y, sr, duration, progress):
        progress(0.3, desc="Computing features...")
        # Extract features
        tempo = float(librosa.beat.beat_track(y=y, sr=sr)[0])
        mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
        spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)[0]
        spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)[0]
        progress(0.6, desc="Creating visualizations...")
        # Create mel spectrogram
        S_mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=80)
        S_dB = librosa.power_to_db(S_mel, ref=np.max)
        # Plot
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        # Waveform
        time = np.linspace(0, duration, len(y))
        axes[0, 0].plot(time, y, alpha=0.8)
        axes[0, 0].set_title('Waveform', fontweight='bold')
        axes[0, 0].set_xlabel('Time (s)')
        # Mel Spectrogram
        librosa.display.specshow(S_dB, sr=sr, x_axis='time', y_axis='mel', ax=axes[0, 1])
        axes[0, 1].set_title('Mel Spectrogram', fontweight='bold')
        # MFCC
        librosa.display.specshow(mfcc, sr=sr, x_axis='time', ax=axes[1, 0])
        axes[1, 0].set_title('MFCC Features', fontweight='bold')
        # Spectral features
        times = librosa.frames_to_time(range(len(spectral_centroid)), sr=sr)
        axes[1, 1].plot(times, spectral_centroid, label='Centroid', linewidth=2)
        axes[1, 1].plot(times, spectral_rolloff, label='Rolloff', linewidth=2)
        axes[1, 1].set_title('Spectral Features', fontweight='bold')
        axes[1, 1].legend()
        axes[1, 1].set_xlabel('Time (s)')
        plt.tight_layout()
        plot_path = self.save_plot(fig)
        summary = f"""**Audio Analysis Results**
- Duration: {duration:.1f}s | Sample Rate: {sr:,} Hz
- Tempo: {tempo:.1f} BPM | Samples: {len(y):,}
- MFCC shape: {mfcc.shape} | Features extracted successfully"""
        progress(1.0, desc="Complete!")
        return plot_path, summary
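
    def _beat_times(self, y, sr):
        """Sketch, not wired into the UI: recover beat timestamps.

        _basic_analysis keeps only the tempo from beat_track and discards
        the beat frames; converting them to seconds would allow, e.g.,
        overlaying beat markers on the waveform panel.
        """
        _tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr)
        # frames_to_time maps frame indices to seconds (default hop of 512)
        return librosa.frames_to_time(beat_frames, sr=sr)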

    def _chroma_analysis(self, y, sr, progress):
        progress(0.3, desc="Computing chroma features...")
        # Different chroma extraction methods
        chroma_cqt = librosa.feature.chroma_cqt(y=y, sr=sr)
        chroma_stft = librosa.feature.chroma_stft(y=y, sr=sr)
        # Harmonic separation
        y_harm = librosa.effects.harmonic(y=y)
        chroma_harm = librosa.feature.chroma_cqt(y=y_harm, sr=sr)
        progress(0.7, desc="Creating visualizations...")
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        # Plot different chroma features
        chromas = [
            (chroma_cqt, 'Chroma (CQT)'),
            (chroma_stft, 'Chroma (STFT)'),
            (chroma_harm, 'Harmonic Chroma'),
            (chroma_cqt - chroma_harm, 'Chroma Difference')
        ]
        for i, (chroma, title) in enumerate(chromas):
            ax = axes[i // 2, i % 2]
            librosa.display.specshow(chroma, y_axis='chroma', x_axis='time', ax=ax)
            ax.set_title(title, fontweight='bold')
        plt.tight_layout()
        plot_path = self.save_plot(fig)
        summary = f"""**Chroma Analysis Results**
- Multiple chroma extraction methods compared
- CQT vs STFT analysis | Harmonic separation applied
- Chroma shape: {chroma_cqt.shape}"""
        progress(1.0, desc="Complete!")
        return plot_path, summary
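
    def _smooth_chroma(self, chroma):
        """Sketch, not wired into the UI: optional chroma denoising.

        A median filter along the time axis suppresses transient noise while
        keeping sustained pitch content, as in the librosa chroma tutorial.
        Assumes scipy is installed; the filter size is an illustrative choice.
        """
        import scipy.ndimage  # local import: only this optional helper needs it
        # size=(1, 9) filters across time only, leaving pitch bins independent
        return scipy.ndimage.median_filter(chroma, size=(1, 9))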

    def _patch_analysis(self, y, sr, patch_duration, progress):
        progress(0.3, desc="Generating patches...")
        # Create mel spectrogram
        hop_length = 512
        S_mel = librosa.feature.melspectrogram(y=y, sr=sr, hop_length=hop_length, n_mels=80)
        S_dB = librosa.power_to_db(S_mel, ref=np.max)
        # Generate patches
        patch_frames = librosa.time_to_frames(patch_duration, sr=sr, hop_length=hop_length)
        if S_dB.shape[-1] < patch_frames:
            # librosa.util.frame raises if the signal is shorter than one patch
            return None, f"Audio too short for {patch_duration}s patches"
        hop_frames = patch_frames // 2  # 50% overlap
        patches = librosa.util.frame(S_dB, frame_length=patch_frames, hop_length=hop_frames)
        progress(0.7, desc="Creating visualizations...")
        # Show first 6 patches
        num_show = min(6, patches.shape[-1])
        fig, axes = plt.subplots(2, 3, figsize=(15, 8))
        axes = axes.flatten()
        for i in range(num_show):
            librosa.display.specshow(patches[..., i], y_axis='mel', x_axis='time',
                                     ax=axes[i], sr=sr, hop_length=hop_length)
            axes[i].set_title(f'Patch {i+1}', fontweight='bold')
        # Hide unused subplots
        for i in range(num_show, 6):
            axes[i].set_visible(False)
        plt.tight_layout()
        plot_path = self.save_plot(fig)
        summary = f"""**Patch Generation Results**
- Total patches: {patches.shape[-1]} | Duration: {patch_duration}s each
- Patch shape: {patches.shape} | 50% overlap between patches
- Ready for transformer input"""
        progress(1.0, desc="Complete!")
        return plot_path, summary
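
# Sketch, not used by the UI: the patches produced in _patch_analysis have
# shape (n_mels, patch_frames, n_patches) because librosa.util.frame frames
# the last axis, while a transformer-style model typically wants the patch
# axis first. The function name is illustrative, not part of the original app.
def patches_to_batch(patches: np.ndarray) -> np.ndarray:
    # Move the patch axis to the front: (n_patches, n_mels, patch_frames)
    return np.moveaxis(patches, -1, 0)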

def create_interface():
    analyzer = AudioAnalyzer()
    with gr.Blocks(title="Audio Analysis Suite") as demo:
        gr.Markdown("# 🎵 Audio Analysis Suite")
        with gr.Row():
            with gr.Column():
                # Input section
                gr.Markdown("### Input")
                youtube_url = gr.Textbox(label="YouTube URL", placeholder="https://youtube.com/watch?v=...")
                download_btn = gr.Button("Download Audio")
                audio_file = gr.Audio(label="Or upload audio file", type="filepath")
                # Analysis options
                gr.Markdown("### Analysis Options")
                analysis_type = gr.Radio(
                    choices=["basic", "chroma", "patches"],
                    value="basic",
                    label="Analysis Type"
                )
                patch_duration = gr.Slider(1, 10, 5, step=0.5, label="Patch Duration (s)",
                                           visible=False)
                analyze_btn = gr.Button("Analyze Audio", variant="primary")
            with gr.Column():
                # Results
                gr.Markdown("### Results")
                plot_output = gr.Image(label="Visualizations")
                summary_output = gr.Markdown()
                status_output = gr.Textbox(label="Status", interactive=False)
        # Event handlers
        download_btn.click(
            analyzer.download_youtube_audio,
            inputs=[youtube_url],
            outputs=[audio_file, status_output]
        )
        analyze_btn.click(
            analyzer.analyze_audio,
            inputs=[audio_file, analysis_type, patch_duration],
            outputs=[plot_output, summary_output]
        )
        # Show patch duration slider only for patches analysis
        analysis_type.change(
            lambda x: gr.update(visible=(x == "patches")),
            inputs=[analysis_type],
            outputs=[patch_duration]
        )
        demo.unload(analyzer.cleanup)
    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.launch()