import gradio as gr
import numpy as np
import librosa
import tensorflow as tf
from scipy.fftpack import dct
import os
import tempfile
import shutil
import subprocess
import re
import requests
from io import BytesIO
# DSCNN model configuration
MODEL_PATH = "ds_cnn_l_quantized.tflite"
DEFAULT_CONFIG = "u55_eval_with_TA_config_400_and_200_MHz.ini"

# Keywords based on Speech Commands dataset (12 classes)
KEYWORDS = [
    "silence", "unknown", "yes", "no", "up", "down",
    "left", "right", "on", "off", "stop", "go"
]
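# NOTE: this label order is assumed to match the training-time class order of the
# quantized DS-CNN model above; if your model was exported with a different label
# mapping, reorder KEYWORDS accordingly.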
print("Loading DSCNN TensorFlow Lite model...")
try:
    # Load the TFLite model
    interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
    interpreter.allocate_tensors()

    # Get input and output details
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print("✅ DSCNN model loaded successfully!")
    print(f"Input shape: {input_details[0]['shape']}")
    print(f"Output shape: {output_details[0]['shape']}")
    print(f"Input dtype: {input_details[0]['dtype']}")
    print(f"Output dtype: {output_details[0]['dtype']}")
except Exception as e:
    print(f"❌ Error loading DSCNN model: {e}")
    interpreter = None
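# For the stock ds_cnn_l_quantized.tflite the interpreter is expected to report an
# int8 input tensor of shape (1, 490) (49 frames x 10 MFCCs); the prints above show
# the actual values at startup, so treat this as a sanity-check hint only.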
# Vela config file is copied from SR app
def extract_summary_from_log(log_text):
    summary_keys = [
        "Accelerator configuration",
        "Accelerator clock",
        "Total SRAM used",
        "Total On-chip Flash used",
        "CPU operators",
        "NPU operators",
        "Batch Inference time"
    ]
    summary = []
    for key in summary_keys:
        match = re.search(rf"{re.escape(key)}\s+(.+)", log_text)
        if match:
            value = match.group(1).strip()
            if key == "Batch Inference time":
                value = value.split(",")[0].strip()
                key = "Inference time"
            summary.append((key, value))
    return summary
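# Illustrative example only (exact Vela log formatting can differ between versions):
# a summary line such as
#   "Total SRAM used                       152.00 KiB"
# would be captured above as the pair ("Total SRAM used", "152.00 KiB").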
def run_vela(model_file):
    accel = "ethos-u55-128"
    optimise = "Size"
    mem_mode = "Sram_Only"
    sys_config = "Ethos_U55_400MHz_SRAM_3.2_GBs_Flash_0.05_GBs"
    tmpdir = tempfile.mkdtemp()
    try:
        # Use the original uploaded model filename
        original_model_name = os.path.basename(model_file)
        model_path = os.path.join(tmpdir, original_model_name)
        shutil.copy(model_file, model_path)

        config_path = os.path.join(tmpdir, DEFAULT_CONFIG)
        shutil.copy(DEFAULT_CONFIG, config_path)

        output_dir = os.path.join(tmpdir, "vela_out")
        os.makedirs(output_dir, exist_ok=True)

        cmd = [
            "vela",
            f"--accelerator-config={accel}",
            f"--optimise={optimise}",
            f"--config={config_path}",
            f"--memory-mode={mem_mode}",
            f"--system-config={sys_config}",
            model_path,
            "--verbose-cycle-estimate",
            "--verbose-performance",
            f"--output-dir={output_dir}"
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        vela_stdout = result.stdout

        # Check for unsupported model patterns in logs
        unsupported_patterns = [
            "Warning: Unsupported TensorFlow Lite semantics",
            "Network Tops/s nan Tops/s",
            "Neural network macs 0 MACs/batch"
        ]
        if any(pat in vela_stdout for pat in unsupported_patterns):
            summary_html = (
                "<div class='sr110-results card' style='background:#fafafa;border-radius:12px;box-shadow:0 4px 6px -1px rgba(0,0,0,0.1),0 2px 4px -1px rgba(0,0,0,0.06);border:1px solid #e5e7eb;margin-bottom:1.5rem;max-width:500px;width:100%;margin:auto;overflow:hidden;'>"
                "<div class='card-header' style='background:linear-gradient(135deg,#dc2626 0%,#b91c1c 100%);color:white;padding:1rem 1.5rem;border-radius:12px 12px 0 0;font-weight:600;font-size:1.1rem;'>"
                "<span style='color:white;font-weight:600;'>Unsupported Model</span>"
                "</div>"
                "<div class='card-content' style='padding:1.5rem;color:#4b5563;line-height:1.6;background:#fafafa;text-align:center;'>"
                "This model contains unsupported layers and needs per-layer investigation.<br><br>"
                "Please run the Vela compiler on your host machine for further analysis."
                "</div></div>"
            )
            # Try to provide per-layer.csv if available for download
            per_layer_csv = None
            for log_fname in os.listdir(output_dir):
                if log_fname.endswith("per-layer.csv"):
                    per_layer_csv = os.path.join("/tmp", log_fname)
                    shutil.copy(os.path.join(output_dir, log_fname), per_layer_csv)
                    break
            return summary_html, None, per_layer_csv

        model_filename = os.path.basename(model_file)
        if model_filename:
            vela_stdout = vela_stdout.replace(
                "Network summary for",
                f"Network summary for {model_filename} ("
            )
        summary_items = extract_summary_from_log(vela_stdout)
        # Convert summary_items to dict for easy access
        summary_dict = dict(summary_items) if summary_items else {}

        # Build 4 cards for results
        def clean_ops(val):
            # Remove '=' and leading/trailing spaces
            return val.lstrip("= ").strip() if isinstance(val, str) else val

        summary_html = (
            "<div class='sr110-results card' style='background:#fafafa;border-radius:12px;box-shadow:0 4px 6px -1px rgba(0,0,0,0.1),0 2px 4px -1px rgba(0,0,0,0.06);border:1px solid #e5e7eb;margin-bottom:1.5rem;max-width:500px;width:100%;margin:auto;overflow:hidden;'>"
            "<div class='card-header' style='background:linear-gradient(135deg,#1975cf 0%,#1557b0 100%);color:white;padding:1rem 1.5rem;border-radius:12px 12px 0 0;font-weight:600;font-size:1.1rem;'>"
            "<span style='color:white;font-weight:600;'>Estimated Results on SR110</span>"
            "</div>"
            "<div class='card-content' style='padding:1.5rem;color:#4b5563;line-height:1.6;background:#fafafa;'>"
            "<div style='display:grid;grid-template-columns:1fr 1fr;gap:1.5rem;'>"
            # Card 1: Accelerator
            "<div class='stat-item' style='background:#f8fafc;padding:1rem;border-radius:8px;border-left:4px solid #1975cf;'>"
            "<div class='stat-label' style='font-weight:600;color:#1975cf;font-size:0.9rem;margin-bottom:0.5rem;'>Accelerator</div>"
            f"<div class='stat-value' style='color:#4b5563;font-size:0.85rem;'><strong>Configuration:</strong> {summary_dict.get('Accelerator configuration','-')}<br><strong>Clock:</strong> {summary_dict.get('Accelerator clock','-')}</div>"
            "</div>"
            # Card 2: Memory Usage
            "<div class='stat-item' style='background:#f8fafc;padding:1rem;border-radius:8px;border-left:4px solid #1975cf;'>"
            "<div class='stat-label' style='font-weight:600;color:#1975cf;font-size:0.9rem;margin-bottom:0.5rem;'>Memory Usage</div>"
            f"<div class='stat-value' style='color:#4b5563;font-size:0.85rem;'><strong>Total SRAM:</strong> {summary_dict.get('Total SRAM used','-')}<br><strong>Total Flash:</strong> {summary_dict.get('Total On-chip Flash used','-')}</div>"
            "</div>"
            # Card 3: Operator Distribution
            "<div class='stat-item' style='background:#f8fafc;padding:1rem;border-radius:8px;border-left:4px solid #1975cf;'>"
            "<div class='stat-label' style='font-weight:600;color:#1975cf;font-size:0.9rem;margin-bottom:0.5rem;'>Operator Distribution</div>"
            f"<div class='stat-value' style='color:#4b5563;font-size:0.85rem;'><strong>CPU Operators:</strong> {clean_ops(summary_dict.get('CPU operators','-'))}<br><strong>NPU Operators:</strong> {clean_ops(summary_dict.get('NPU operators','-'))}</div>"
            "</div>"
            # Card 4: Performance
            "<div class='stat-item' style='background:#f8fafc;padding:1rem;border-radius:8px;border-left:4px solid #1975cf;'>"
            "<div class='stat-label' style='font-weight:600;color:#1975cf;font-size:0.9rem;margin-bottom:0.5rem;'>Performance</div>"
            f"<div class='stat-value' style='color:#4b5563;font-size:0.85rem;'><strong>Inference time:</strong> {summary_dict.get('Inference time','-')}</div>"
            "</div>"
            "</div></div></div>"
        ) if summary_items else "<div style='color:red'>Summary info not found in log.</div>"

        for fname in os.listdir(output_dir):
            if fname.endswith("vela.tflite"):
                final_path = os.path.join("/tmp", fname)
                shutil.copy(os.path.join(output_dir, fname), final_path)
                # Find per-layer.csv file for logs
                per_layer_csv = None
                for log_fname in os.listdir(output_dir):
                    if log_fname.endswith("per-layer.csv"):
                        per_layer_csv = os.path.join("/tmp", log_fname)
                        shutil.copy(os.path.join(output_dir, log_fname), per_layer_csv)
                        break
                return summary_html, final_path, per_layer_csv

        # If no tflite, still try to return per-layer.csv if present
        per_layer_csv = None
        for log_fname in os.listdir(output_dir):
            if log_fname.endswith("per-layer.csv"):
                per_layer_csv = os.path.join("/tmp", log_fname)
                shutil.copy(os.path.join(output_dir, log_fname), per_layer_csv)
                break
        return summary_html, None, per_layer_csv
    finally:
        shutil.rmtree(tmpdir)
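# Usage sketch (illustrative; "my_model_int8.tflite" is a placeholder path):
#   summary_html, compiled_path, per_layer_csv = run_vela("my_model_int8.tflite")
# compiled_path is None when Vela flags the model as unsupported or writes no
# optimised .tflite; per_layer_csv is None when no per-layer report was produced.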
# Run Vela analysis on startup and cache results
print("Running Vela analysis on DSCNN model...")
try:
    vela_html, compiled_model, per_layer_csv = run_vela(MODEL_PATH)
except Exception as e:
    vela_html = f"<div style='color:red'>Vela analysis failed: {str(e)}</div>"
    # Keep the cached results defined even if the startup analysis fails
    compiled_model, per_layer_csv = None, None
def extract_mfcc_features(audio_path, target_length=490):
    """
    Extract MFCC features exactly as specified in the original DSCNN paper.
    Based on "Hello Edge: Keyword Spotting on Microcontrollers"

    Parameters from paper:
    - 40ms frame length (640 samples at 16kHz)
    - 20ms stride (320 samples at 16kHz)
    - 10 MFCC features per frame
    - 49 frames total for 1 second → 49×10 = 490 features
    """
    try:
        # Load audio and resample to 16kHz (standard for speech commands)
        audio, sr = librosa.load(audio_path, sr=16000, mono=True)

        # Ensure audio is exactly 1 second (16000 samples)
        if len(audio) < 16000:
            # Pad with zeros
            audio = np.pad(audio, (0, 16000 - len(audio)), 'constant')
        else:
            # Truncate to 1 second
            audio = audio[:16000]

        # DSCNN paper parameters
        frame_length = 640  # 40ms at 16kHz
        hop_length = 320    # 20ms at 16kHz (50% overlap)
        n_mfcc = 10         # 10 MFCC features as in paper
        n_fft = 1024        # FFT size
        n_mels = 40         # Mel filter bank size (before DCT)

        # Extract mel spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=frame_length,
            n_mels=n_mels,
            fmin=20,
            fmax=4000
        )

        # Convert to log scale
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        # Apply DCT to get MFCC features (only take first 10 coefficients)
        mfcc_features = dct(log_mel_spec, axis=0, norm='ortho')[:n_mfcc, :]

        # Should be shape (10, 49) for 1 second of audio
        print(f"MFCC shape before flattening: {mfcc_features.shape}")

        # Flatten to 1D array (10 × 49 = 490 features)
        features_flat = mfcc_features.flatten()

        # Ensure exactly 490 features
        if len(features_flat) > target_length:
            features_flat = features_flat[:target_length]
        elif len(features_flat) < target_length:
            features_flat = np.pad(features_flat, (0, target_length - len(features_flat)), 'constant')

        print(f"Features length after processing: {len(features_flat)}")

        # Normalize features (zero mean, unit variance)
        features_flat = (features_flat - np.mean(features_flat)) / (np.std(features_flat) + 1e-8)

        # Quantize to INT8 range for DSCNN model
        # Scale to approximately match training distribution
        features_int8 = np.clip(features_flat * 127.0, -128, 127).astype(np.int8)

        return features_int8.reshape(1, -1)  # Shape: (1, 490)
    except Exception as e:
        raise Exception(f"Error extracting MFCC features: {str(e)}")
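# Usage sketch (illustrative; "yes.wav" is a placeholder path):
#   features = extract_mfcc_features("yes.wav")
#   features.shape, features.dtype  ->  (1, 490), int8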
def classify_audio(audio_input):
    """
    Classify the input audio using the DSCNN model and return keyword predictions.
    """
    if audio_input is None:
        return "Please upload an audio file or record audio."

    if interpreter is None:
        return "❌ DSCNN model not loaded. Please refresh the page and try again."

    try:
        # Extract MFCC features
        features = extract_mfcc_features(audio_input)
        print(f"Input features shape: {features.shape}")
        print(f"Input features dtype: {features.dtype}")
        print(f"Input features range: [{features.min()}, {features.max()}]")

        # Set input tensor
        interpreter.set_tensor(input_details[0]['index'], features)

        # Run inference
        interpreter.invoke()

        # Get output
        output_data = interpreter.get_tensor(output_details[0]['index'])
        print(f"Raw output shape: {output_data.shape}")
        print(f"Raw output dtype: {output_data.dtype}")
        print(f"Raw output range: [{output_data.min()}, {output_data.max()}]")

        # Handle quantized INT8 output
        if output_data.dtype == np.int8:
            # Dequantize INT8 to float (assuming symmetric quantization)
            # Scale factor is typically around 1/128 for INT8
            logits = output_data.astype(np.float32) / 128.0
        else:
            logits = output_data.astype(np.float32)

        # Apply softmax to get probabilities
        exp_logits = np.exp(logits - np.max(logits))
        probabilities = exp_logits / np.sum(exp_logits)

        # Get predictions with confidence scores
        predictions = []
        for i, prob in enumerate(probabilities[0]):
            predictions.append({
                'label': KEYWORDS[i],
                'score': float(prob)
            })

        # Sort by confidence score
        predictions = sorted(predictions, key=lambda x: x['score'], reverse=True)

        # Format results
        results = []
        for i, pred in enumerate(predictions[:5]):
            confidence = pred['score'] * 100
            label = pred['label']
            indicator = "🎯" if i == 0 else " "
            results.append(f"{indicator} {i+1}. **{label}**: {confidence:.1f}%")

        return "\n".join(results)
    except Exception as e:
        error_msg = str(e)
        if "mfcc" in error_msg.lower() or "librosa" in error_msg.lower():
            return "❌ Audio processing error. Please ensure your audio file is in a supported format (WAV, MP3, etc.)."
        elif "model" in error_msg.lower() or "tensor" in error_msg.lower():
            return "❌ Model inference error. Please try recording a clear 1-second audio clip."
        else:
            return f"❌ Error processing audio: {error_msg}\n\nTip: Try recording a clear 1-second word like 'yes' or 'stop'."
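# The Markdown returned above lists the top-5 keywords; an illustrative result
# (scores are made up) would look like:
#   🎯 1. **yes**: 87.3%
#     2. **unknown**: 6.1%
#     3. **no**: 2.4%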
def load_example_audio(example_name):
    """Load example audio for demonstration."""
    # This would load pre-recorded examples if available
    return None
def compile_uploaded_model(model_file):
    """Compile uploaded model with Vela and return results."""
    if model_file is None:
        error_html = (
            "<div class='sr110-results card' style='background:#fafafa;border-radius:12px;box-shadow:0 4px 6px -1px rgba(0,0,0,0.1),0 2px 4px -1px rgba(0,0,0,0.06);border:1px solid #e5e7eb;margin-bottom:1.5rem;max-width:500px;width:100%;margin:auto;overflow:hidden;'>"
            "<div class='card-header' style='background:linear-gradient(135deg,#dc2626 0%,#b91c1c 100%);color:white;padding:1rem 1.5rem;border-radius:12px 12px 0 0;font-weight:600;font-size:1.1rem;'>"
            "<span style='color:white;font-weight:600;'>No Model</span>"
            "</div>"
            "<div class='card-content' style='padding:1.5rem;color:#4b5563;line-height:1.6;background:#fafafa;text-align:center;'>"
            "No model file uploaded."
            "</div></div>"
        )
        return (
            error_html,
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None)
        )
    try:
        # Run Vela analysis on uploaded model
        results_html, compiled_model_path, per_layer_csv = run_vela(model_file)
        return (
            results_html,
            gr.update(visible=compiled_model_path is not None, value=compiled_model_path),
            gr.update(visible=per_layer_csv is not None, value=per_layer_csv)
        )
    except Exception as e:
        error_html = (
            "<div class='sr110-results card' style='background:#fafafa;border-radius:12px;box-shadow:0 4px 6px -1px rgba(0,0,0,0.1),0 2px 4px -1px rgba(0,0,0,0.06);border:1px solid #e5e7eb;margin-bottom:1.5rem;max-width:500px;width:100%;margin:auto;overflow:hidden;'>"
            "<div class='card-header' style='background:linear-gradient(135deg,#dc2626 0%,#b91c1c 100%);color:white;padding:1rem 1.5rem;border-radius:12px 12px 0 0;font-weight:600;font-size:1.1rem;'>"
            "<span style='color:white;font-weight:600;'>Compilation Failed</span>"
            "</div>"
            "<div class='card-content' style='padding:1.5rem;color:#4b5563;line-height:1.6;background:#fafafa;text-align:center;'>"
            f"Vela compilation failed: {str(e)}"
            "</div></div>"
        )
        return (
            error_html,
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None)
        )
# Create Gradio interface
with gr.Blocks(
    theme=gr.themes.Default(primary_hue="blue", neutral_hue="gray"),
    title="DSCNN Wake Word Detection",
    css="""
    body {
        background: #fafafa !important;
    }
    .gradio-container {
        max-width: none !important;
        margin: 0 !important;
        background-color: #fafafa !important;
        font-family: 'Inter', 'Segoe UI', -apple-system, sans-serif !important;
        width: 100vw !important;
    }
    .gr-row {
        display: flex !important;
        justify-content: center !important;
        align-items: flex-start !important;
        gap: 48px !important;
    }
    .gr-column {
        align-items: flex-start !important;
        justify-content: flex-start !important;
    }
    .fixed-upload-box {
        width: 100% !important;
        max-width: 420px !important;
        margin-bottom: 18px !important;
    }
    .download-btn-custom, .compile-btn-custom {
        width: 100% !important;
        margin-bottom: 18px !important;
    }
    .upload-file-box .w-full, .download-file-box .w-full {
        height: 120px !important;
        background: #232b36 !important;
        border-radius: 8px !important;
        color: #fff !important;
        font-weight: 600 !important;
        font-size: 1.1em !important;
        box-shadow: none !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
    }
    .upload-file-box .w-full .file-preview {
        margin: 0 auto !important;
        text-align: center !important;
        width: 100%;
    }
    #run-vela-btn, .compile-btn, .compile-btn-custom {
        background-color: #007dc3 !important;
        color: white !important;
        font-size: 1.1em;
        border-radius: 8px;
        margin-top: 12px;
        margin-bottom: 18px;
        text-align: center;
        height: 40px !important;
    }
    .results-summary-box, #results-summary {
        margin-left: 0 !important;
    }
    h1, h3, .gr-markdown h1, .gr-markdown h3 { color: #1976d2 !important; }
    p, .gr-markdown p, .gr-markdown span, .gr-markdown { color: #222 !important; }
    .custom-footer {
        display: block !important;
        margin: 40px auto 0 auto !important;
        max-width: 600px !important;
        width: 100% !important;
        background: #e6f4ff !important;
        border-radius: 10px !important;
        box-shadow: 0 2px 2px #0001 !important;
        padding: 24px 32px 24px 32px !important;
        font-size: 1.13em !important;
        color: #0a2540 !important;
        font-family: sans-serif !important;
        text-align: center !important;
        position: relative !important;
        z-index: 10 !important;
    }
    .custom-footer a {
        color: #0074d9 !important;
        text-decoration: underline !important;
        font-weight: 700 !important;
    }
    .card {
        background: #fafafa !important;
        border-radius: 12px !important;
        box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
        border: 1px solid #e5e7eb !important;
        margin-bottom: 1.5rem !important;
        transition: all 0.2s ease-in-out !important;
        overflow: hidden !important;
    }
    .card > * {
        padding: 0 !important;
        margin: 0 !important;
    }
    .card:hover {
        box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05) !important;
        transform: translateY(-1px) !important;
    }
    .card-header {
        background: linear-gradient(135deg, #1975cf 0%, #1557b0 100%) !important;
        color: white !important;
        padding: 1rem 1.5rem !important;
        border-radius: 12px 12px 0 0 !important;
        font-weight: 600 !important;
        font-size: 1.1rem !important;
    }
    .card-header,
    div.card-header,
    div.card-header span,
    div.card-header * {
        color: white !important;
    }
    .card-content {
        padding: 1.5rem !important;
        color: #4b5563 !important;
        line-height: 1.6 !important;
        background: #fafafa !important;
    }
    .stats-grid {
        display: grid !important;
        grid-template-columns: 1fr 1fr !important;
        gap: 1.5rem !important;
        margin-top: 1.5rem !important;
    }
    .stat-item {
        background: #f8fafc !important;
        padding: 1rem !important;
        border-radius: 8px !important;
        border-left: 4px solid #1975cf !important;
    }
    .stat-label {
        font-weight: 600 !important;
        color: #4b5563 !important;
        font-size: 0.9rem !important;
        margin-bottom: 0.5rem !important;
    }
    .stat-value {
        color: #4b5563 !important;
        font-size: 0.85rem !important;
    }
    .btn-example {
        background: #f1f5f9 !important;
        border: 1px solid #cbd5e1 !important;
        color: #4b5563 !important;
        border-radius: 6px !important;
        transition: all 0.2s ease !important;
        margin: 0.35rem !important;
        padding: 0.5rem 1rem !important;
    }
    .btn-example:hover {
        background: #1975cf !important;
        border-color: #1975cf !important;
        color: white !important;
    }
    .btn-primary {
        background: #1975cf !important;
        border-color: #1975cf !important;
        color: white !important;
    }
    .btn-primary:hover {
        background: #1557b0 !important;
        border-color: #1557b0 !important;
    }
    .markdown {
        color: #374151 !important;
    }
    .results-text {
        color: #4b5563 !important;
        font-weight: 500 !important;
        padding: 0 !important;
        margin: 0 !important;
    }
    .results-text p {
        color: #4b5563 !important;
        margin: 0.5rem 0 !important;
    }
    .results-text * {
        color: #4b5563 !important;
    }
    div[data-testid="markdown"] p {
        color: #4b5563 !important;
    }
    .prose {
        color: #4b5563 !important;
    }
    .prose * {
        color: #4b5563 !important;
    }
    .card-header,
    .card-header * {
        color: white !important;
    }
    /* Override grey colors for SR110 Vela results section - MUST be after prose rules */
    .prose .sr110-results,
    .prose .sr110-results *,
    .prose .sr110-results h3,
    .prose .sr110-results div,
    .prose .sr110-results span,
    .sr110-results,
    .sr110-results *,
    .sr110-results h3,
    .sr110-results div,
    .sr110-results span {
        color: inherit !important;
    }
    /* Preserve original colors for dark theme cards with higher specificity */
    .prose .sr110-results .sr110-card,
    .sr110-results .sr110-card {
        background: #23233a !important;
    }
    .prose .sr110-results .sr110-title,
    .sr110-results .sr110-title {
        color: #00b0ff !important;
    }
    .prose .sr110-results .sr110-label,
    .sr110-results .sr110-label {
        color: #ccc !important;
    }
    .prose .sr110-results .sr110-value,
    .sr110-results .sr110-value {
        color: #fff !important;
    }
    """
) as demo:
    gr.HTML("""
    <div class="main-header">
        <h1>DSCNN Wake Word Detection</h1>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_audio = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Record or Upload Audio",
                value=None
            )
            classify_btn = gr.Button(
                "Detect Wake Word",
                variant="primary",
                size="lg",
                elem_classes=["btn-primary"]
            )
            with gr.Group(elem_classes=["card"]):
                gr.HTML('<div class="card-header"><span style="color: white; font-weight: 600;">Supported Keywords</span></div>')
                with gr.Column(elem_classes=["card-content"]):
                    gr.HTML("""
                    <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 0.5rem; text-align: center;">
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">yes</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">no</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">up</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">down</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">left</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">right</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">on</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">off</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">stop</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">go</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">silence</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">unknown</div>
                    </div>
                    """)

        with gr.Column(scale=1):
            # Display Vela analysis results (dynamic)
            vela_results_html = gr.HTML(vela_html)
            with gr.Group(elem_classes=["card"]):
                gr.HTML('<div class="card-header"><span style="color: white; font-weight: 600;">Wake Word Detection Results</span></div>')
                with gr.Column(elem_classes=["card-content"]):
                    output_text = gr.Markdown(
                        value="Record or upload audio to see wake word predictions...",
                        label="",
                        elem_classes=["results-text"]
                    )

    # Set up event handlers
    classify_btn.click(
        fn=classify_audio,
        inputs=input_audio,
        outputs=output_text
    )

    # Auto-classify when audio is uploaded
    input_audio.change(
        fn=classify_audio,
        inputs=input_audio,
        outputs=output_text
    )
# Launch the demo
if __name__ == "__main__":
    demo.launch()