import gradio as gr import torch import librosa import numpy as np import json import os import tempfile import time from datetime import datetime from transformers import WhisperProcessor, WhisperForConditionalGeneration import warnings warnings.filterwarnings("ignore") # ============================================================================= # MODEL LOADING AND CONFIGURATION # ============================================================================= # Configure your model path - UPDATE THIS with your actual model name MODEL_NAME = "AfroLogicInsect/whisper-finetuned-float32" # Replace with your HF model # Global variables for model and processor model = None processor = None model_dtype = None def load_model(): """Load the Whisper model and processor""" global model, processor, model_dtype try: print(f"🔄 Loading model: {MODEL_NAME}") # Load processor processor = WhisperProcessor.from_pretrained(MODEL_NAME) # Load model with appropriate dtype model = WhisperForConditionalGeneration.from_pretrained( MODEL_NAME, torch_dtype=torch.float32, # Use float32 for stability device_map="auto" if torch.cuda.is_available() else None ) model_dtype = torch.float32 # Move to GPU if available if torch.cuda.is_available(): model = model.cuda() print(f"✅ Model loaded on GPU: {torch.cuda.get_device_name()}") else: print("✅ Model loaded on CPU") return True except Exception as e: print(f"❌ Error loading model: {e}") # Fallback to base Whisper model try: print("🔄 Falling back to base Whisper model...") fallback_model = "openai/whisper-small" processor = WhisperProcessor.from_pretrained(fallback_model) model = WhisperForConditionalGeneration.from_pretrained( fallback_model, torch_dtype=torch.float32 ) model_dtype = torch.float32 if torch.cuda.is_available(): model = model.cuda() print(f"✅ Fallback model loaded: {fallback_model}") return True except Exception as e2: print(f"❌ Fallback model loading failed: {e2}") return False # Load model on startup print("🚀 Initializing Whisper Transcription Service...") model_loaded = load_model() # ============================================================================= # CORE TRANSCRIPTION FUNCTIONS # ============================================================================= def transcribe_audio_chunk(audio_chunk, sr=16000): """Transcribe a single audio chunk""" try: # Process with processor inputs = processor( audio_chunk, sampling_rate=sr, return_tensors="pt" ) input_features = inputs.input_features # Handle dtype matching if model_dtype == torch.float16: input_features = input_features.half() else: input_features = input_features.float() # Move to same device as model input_features = input_features.to(model.device) # Generate transcription with torch.no_grad(): try: predicted_ids = model.generate( input_features, language="en", task="transcribe", max_length=448, num_beams=1, do_sample=False, use_cache=True, no_repeat_ngram_size=2 ) transcription = processor.batch_decode( predicted_ids, skip_special_tokens=True )[0] return transcription except RuntimeError as gen_error: if "Input type" in str(gen_error) and "bias type" in str(gen_error): # Handle dtype mismatch model.float() input_features = input_features.float() predicted_ids = model.generate( input_features, language="en", task="transcribe", max_length=448, num_beams=1, do_sample=False, no_repeat_ngram_size=2 ) transcription = processor.batch_decode( predicted_ids, skip_special_tokens=True )[0] return transcription else: raise gen_error except Exception as e: print(f"❌ Chunk transcription failed: {e}") return None def process_audio_with_timestamps(audio_array, sr=16000, chunk_length=15): """Process audio with timestamps using robust chunking""" try: total_duration = len(audio_array) / sr # Check duration limit (3 minutes = 180 seconds) if total_duration > 180: return { "error": f"⚠️ Audio too long ({total_duration:.1f}s). Maximum allowed: 3 minutes (180s)", "success": False } chunk_samples = chunk_length * sr overlap_samples = int(2 * sr) # 2-second overlap all_segments = [] start = 0 chunk_index = 0 progress_updates = [] while start < len(audio_array): # Define chunk boundaries end = min(start + chunk_samples, len(audio_array)) # Add overlap for better transcription chunk_start_with_overlap = max(0, start - overlap_samples // 2) chunk_end_with_overlap = min(len(audio_array), end + overlap_samples // 2) chunk_audio = audio_array[chunk_start_with_overlap:chunk_end_with_overlap] # Calculate time boundaries start_time = start / sr end_time = end / sr # Update progress progress = (chunk_index + 1) / max(1, int(np.ceil(len(audio_array) / chunk_samples))) * 100 progress_updates.append(f"Processing chunk {chunk_index + 1}: {start_time:.1f}s - {end_time:.1f}s ({progress:.0f}%)") # Transcribe chunk transcription = transcribe_audio_chunk(chunk_audio, sr) if transcription and transcription.strip(): clean_text = transcription.strip() segment = { "start": round(start_time, 2), "end": round(end_time, 2), "text": clean_text, "duration": round(end_time - start_time, 2) } all_segments.append(segment) # Move to next chunk start = end chunk_index += 1 # Remove overlaps between segments cleaned_segments = remove_segment_overlaps(all_segments) if cleaned_segments: full_text = " ".join([seg["text"] for seg in cleaned_segments]) result = { "success": True, "text": full_text, "segments": cleaned_segments, "metadata": { "total_duration": round(total_duration, 2), "num_segments": len(cleaned_segments), "chunk_length": chunk_length, "processing_time": time.time() } } return result else: return { "error": "❌ No transcription could be generated", "success": False } except Exception as e: return { "error": f"❌ Processing failed: {str(e)}", "success": False } def remove_segment_overlaps(segments): """Remove overlapping text between segments""" if len(segments) <= 1: return segments cleaned_segments = [segments[0]] for i in range(1, len(segments)): current_segment = segments[i].copy() previous_text = cleaned_segments[-1]["text"] current_text = current_segment["text"] # Simple overlap detection prev_words = previous_text.lower().split() curr_words = current_text.lower().split() overlap_length = 0 max_check = min(8, len(prev_words), len(curr_words)) for j in range(1, max_check + 1): if prev_words[-j:] == curr_words[:j]: overlap_length = j if overlap_length > 0: remaining_words = current_text.split()[overlap_length:] if remaining_words: current_segment["text"] = " ".join(remaining_words) cleaned_segments.append(current_segment) else: cleaned_segments.append(current_segment) return cleaned_segments # ============================================================================= # GRADIO INTERFACE FUNCTIONS # ============================================================================= def transcribe_file(audio_file): """Handle file upload transcription""" if not model_loaded: return "❌ Model not loaded. Please refresh the page.", None, None if audio_file is None: return "⚠️ Please upload an audio file.", None, None try: # Load audio file audio_array, sr = librosa.load(audio_file, sr=16000) # Check duration duration = len(audio_array) / sr if duration > 180: # 3 minutes return f"⚠️ Audio too long ({duration:.1f}s). Maximum allowed: 3 minutes.", None, None # Process with timestamps result = process_audio_with_timestamps(audio_array, sr) if result["success"]: # Format output formatted_text = format_transcription_output(result) # Create downloadable files json_file = create_json_download(result, audio_file) srt_file = create_srt_download(result, audio_file) return formatted_text, json_file, srt_file else: return result["error"], None, None except Exception as e: return f"❌ Error processing file: {str(e)}", None, None def transcribe_microphone(audio_data): """Handle microphone recording transcription""" if not model_loaded: return "❌ Model not loaded. Please refresh the page.", None, None if audio_data is None: return "⚠️ No audio recorded. Please record something first.", None, None try: # Extract sample rate and audio array from Gradio audio data sr, audio_array = audio_data # Convert to float32 and normalize if audio_array.dtype != np.float32: audio_array = audio_array.astype(np.float32) if audio_array.max() > 1.0: audio_array = audio_array / 32768.0 # Convert from int16 to float32 # Resample to 16kHz if needed if sr != 16000: audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=16000) sr = 16000 # Check duration duration = len(audio_array) / sr if duration > 180: # 3 minutes return f"⚠️ Recording too long ({duration:.1f}s). Maximum allowed: 3 minutes.", None, None if duration < 0.5: # Less than 0.5 seconds return "⚠️ Recording too short. Please record for at least 0.5 seconds.", None, None # Process with timestamps result = process_audio_with_timestamps(audio_array, sr) if result["success"]: # Format output formatted_text = format_transcription_output(result) # Create downloadable files json_file = create_json_download(result, "microphone_recording") srt_file = create_srt_download(result, "microphone_recording") return formatted_text, json_file, srt_file else: return result["error"], None, None except Exception as e: return f"❌ Error processing recording: {str(e)}", None, None def format_transcription_output(result): """Format transcription result for display""" output = [] # Header output.append("🎯 TRANSCRIPTION RESULTS") output.append("=" * 50) # Metadata metadata = result["metadata"] output.append(f"📊 Duration: {metadata['total_duration']}s") output.append(f"📝 Segments: {metadata['num_segments']}") output.append("") # Full text output.append("📄 FULL TRANSCRIPT:") output.append("-" * 30) output.append(result["text"]) output.append("") # Timestamped segments output.append("🕐 TIMESTAMPED SEGMENTS:") output.append("-" * 30) for i, segment in enumerate(result["segments"], 1): start_min = int(segment["start"] // 60) start_sec = int(segment["start"] % 60) end_min = int(segment["end"] // 60) end_sec = int(segment["end"] % 60) time_str = f"{start_min:02d}:{start_sec:02d} - {end_min:02d}:{end_sec:02d}" output.append(f"{i:2d}. [{time_str}] {segment['text']}") return "\n".join(output) def create_json_download(result, source_name): """Create JSON file for download""" try: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"transcription_{timestamp}.json" # Add metadata result["metadata"]["source"] = os.path.basename(str(source_name)) result["metadata"]["generated_at"] = datetime.now().isoformat() result["metadata"]["model"] = MODEL_NAME with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False, encoding='utf-8') as f: json.dump(result, f, indent=2, ensure_ascii=False) return f.name except Exception as e: print(f"Error creating JSON download: {e}") return None def create_srt_download(result, source_name): """Create SRT subtitle file for download""" try: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"subtitles_{timestamp}.srt" srt_content = [] for i, segment in enumerate(result["segments"], 1): start_time = format_time_srt(segment["start"]) end_time = format_time_srt(segment["end"]) srt_content.extend([ str(i), f"{start_time} --> {end_time}", segment["text"], "" ]) with tempfile.NamedTemporaryFile(mode='w', suffix='.srt', delete=False, encoding='utf-8') as f: f.write("\n".join(srt_content)) return f.name except Exception as e: print(f"Error creating SRT download: {e}") return None def format_time_srt(seconds): """Format seconds to SRT time format""" hours = int(seconds // 3600) minutes = int((seconds % 3600) // 60) secs = int(seconds % 60) millis = int((seconds % 1) * 1000) return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}" # ============================================================================= # GRADIO INTERFACE # ============================================================================= def create_gradio_interface(): """Create the Gradio interface""" # Custom CSS for better styling css = """ .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; } .title { text-align: center; color: #2d3748; margin-bottom: 2rem; } .subtitle { text-align: center; color: #4a5568; margin-bottom: 1rem; } .output-text { font-family: 'Courier New', monospace; background-color: #f7fafc; padding: 1rem; border-radius: 8px; border: 1px solid #e2e8f0; } .warning { background-color: #fff3cd; border: 1px solid #ffeaa7; color: #856404; padding: 10px; border-radius: 4px; margin: 10px 0; } """ with gr.Blocks(css=css, title="🎙️ Whisper Speech Transcription") as interface: # Header gr.HTML("""

🎙️ Whisper Speech Transcription

Upload an audio file or record your voice to get an AI-powered transcription with timestamps

""") # Warning about limits gr.HTML("""
⚠️ Important: Maximum audio length is 3 minutes (180 seconds). Longer files will be rejected to ensure fair usage for all users.
""") # Model status status_color = "green" if model_loaded else "red" status_text = "✅ Model loaded and ready" if model_loaded else "❌ Model loading failed" gr.HTML(f'

{status_text}

') with gr.Tabs(): # Tab 1: File Upload with gr.TabItem("📁 Upload Audio File"): with gr.Row(): with gr.Column(): audio_file_input = gr.Audio( label="Upload Audio File", type="filepath", sources=["upload"] ) file_transcribe_btn = gr.Button( "🚀 Transcribe File", variant="primary", size="lg" ) with gr.Row(): file_output = gr.Textbox( label="Transcription Results", lines=15, placeholder="Your transcription will appear here...", elem_classes=["output-text"] ) with gr.Row(): with gr.Column(): json_download = gr.File( label="📄 Download JSON", visible=False ) with gr.Column(): srt_download = gr.File( label="📄 Download SRT Subtitles", visible=False ) # Tab 2: Voice Recording with gr.TabItem("🎤 Record Voice"): with gr.Row(): with gr.Column(): audio_mic_input = gr.Audio( label="Record Your Voice", sources=["microphone"], type="numpy" ) mic_transcribe_btn = gr.Button( "🚀 Transcribe Recording", variant="primary", size="lg" ) with gr.Row(): mic_output = gr.Textbox( label="Transcription Results", lines=15, placeholder="Your transcription will appear here...", elem_classes=["output-text"] ) with gr.Row(): with gr.Column(): json_download_mic = gr.File( label="📄 Download JSON", visible=False ) with gr.Column(): srt_download_mic = gr.File( label="📄 Download SRT Subtitles", visible=False ) # Footer gr.HTML("""

📋 Output Formats

JSON: Complete transcription data with timestamps and metadata

SRT: Standard subtitle format for video players

Display: Formatted text with timestamped segments


Powered by Whisper AI | Maximum 3 minutes per audio | English language optimized

""") # Event handlers def update_file_outputs(result_text, json_file, srt_file): json_visible = json_file is not None srt_visible = srt_file is not None return ( result_text, gr.update(value=json_file, visible=json_visible), gr.update(value=srt_file, visible=srt_visible) ) file_transcribe_btn.click( fn=transcribe_file, inputs=[audio_file_input], outputs=[file_output, json_download, srt_download] ).then( fn=update_file_outputs, inputs=[file_output, json_download, srt_download], outputs=[file_output, json_download, srt_download] ) mic_transcribe_btn.click( fn=transcribe_microphone, inputs=[audio_mic_input], outputs=[mic_output, json_download_mic, srt_download_mic] ).then( fn=update_file_outputs, inputs=[mic_output, json_download_mic, srt_download_mic], outputs=[mic_output, json_download_mic, srt_download_mic] ) return interface # ============================================================================= # LAUNCH APPLICATION # ============================================================================= if __name__ == "__main__": # Create and launch the interface interface = create_gradio_interface() # Launch configuration interface.launch( share=True, # Creates a public URL server_name="0.0.0.0", # Allows external access server_port=7860, # Standard Gradio port show_error=True, # enable_queue=True, # Handle multiple users max_threads=10 # Limit concurrent processing )