{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: dialogue_diarization_demo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio torch torchaudio pyannote.audio openai-whisper librosa numpy transformers speechbrain"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# type: ignore\n", "import gradio as gr\n", "from pyannote.audio import Pipeline\n", "import whisper\n", "\n", "diarization_pipeline = None\n", "whisper_model = None\n", "\n", "\n", "# Models are loaded lazily on first use and cached in the globals above.\n", "def load_models():\n", "    global diarization_pipeline, whisper_model  # noqa: PLW0603\n", "\n", "    if diarization_pipeline is None:\n", "        diarization_pipeline = Pipeline.from_pretrained(\n", "            \"pyannote/speaker-diarization-3.1\", use_auth_token=True\n", "        )\n", "\n", "    if whisper_model is None:\n", "        whisper_model = whisper.load_model(\"base\")\n", "\n", "\n", "def real_diarization(audio_file_path: str) -> list[dict[str, str]]:\n", "    try:\n", "        load_models()\n", "\n", "        if diarization_pipeline is None or whisper_model is None:\n", "            raise Exception(\"Failed to load models\")\n", "\n", "        # Run speaker diarization and Whisper transcription on the same file.\n", "        diarization = diarization_pipeline(audio_file_path)\n", "\n", "        transcription = whisper_model.transcribe(audio_file_path)\n", "        segments = transcription[\"segments\"]\n", "\n", "        dialogue_segments = []\n", "        speaker_mapping = {}\n", "        speaker_counter = 1\n", "\n", "        for segment in segments:\n", "            start_time = segment[\"start\"]\n", "            end_time = segment[\"end\"]\n", "            text = segment[\"text\"].strip()\n", "\n", "            # Assign each transcript segment to the first overlapping diarization turn.\n", "            speaker = \"Speaker 1\"\n", "            for turn, _, speaker_label in diarization.itertracks(yield_label=True):\n", "                if (\n", "                    turn.start <= start_time <= turn.end\n", "                    or turn.start <= end_time <= turn.end\n", "                ):\n", "                    if speaker_label not in speaker_mapping:\n", "                        speaker_mapping[speaker_label] = f\"Speaker {speaker_counter}\"\n", "                        speaker_counter += 1\n", "                    speaker = speaker_mapping[speaker_label]\n", "                    break\n", "\n", "            if text:\n", "                dialogue_segments.append({\"speaker\": speaker, \"text\": text})\n", "\n", "        return dialogue_segments\n", "\n", "    except Exception as e:\n", "        print(f\"Error in diarization: {str(e)}\")\n", "        return []\n", "\n", "\n", "def process_audio(audio_file):\n", "    if audio_file is None:\n", "        gr.Warning(\"Please upload an audio file first.\")\n", "        return []\n", "\n", "    try:\n", "        dialogue_segments = real_diarization(audio_file)\n", "        return dialogue_segments\n", "    except Exception as e:\n", "        gr.Error(f\"Error processing audio: {str(e)}\")\n", "        return []\n", "\n", "\n", "speakers = [\n", "    \"Speaker 1\",\n", "    \"Speaker 2\",\n", "    \"Speaker 3\",\n", "    \"Speaker 4\",\n", "    \"Speaker 5\",\n", "    \"Speaker 6\",\n", "]\n", "tags = [\n", "    \"(pause)\",\n", "    \"(background noise)\",\n", "    \"(unclear)\",\n", "    \"(overlap)\",\n", "    \"(phone ringing)\",\n", "    \"(door closing)\",\n", "    \"(music)\",\n", "    \"(applause)\",\n", "    \"(laughter)\",\n", "]\n", "\n", "\n", "def format_speaker(speaker, text):\n", "    return f\"{speaker}: {text}\"\n", "\n", "\n", "with gr.Blocks(title=\"Audio Diarization Demo\") as demo:\n", "    with gr.Row():\n", "        with gr.Column(scale=1):\n", "            audio_input = gr.Audio(\n", "                label=\"Upload Audio File\",\n", "                type=\"filepath\",\n", "                sources=[\"upload\", \"microphone\"],\n", "            )\n", "\n", "            process_btn = gr.Button(\"\ud83d\udd0d Analyze Speakers\", variant=\"primary\", size=\"lg\")\n", "\n", "        with gr.Column(scale=2):\n", "            dialogue_output = gr.Dialogue(\n", "                speakers=speakers,\n", "                tags=tags,\n", "                formatter=format_speaker,\n", "                label=\"AI-generated speaker-separated conversation\",\n", "                value=[],\n", "            )\n", "\n", "    process_btn.click(fn=process_audio, inputs=[audio_input], outputs=[dialogue_output])\n", "\n", "if __name__ == \"__main__\":\n", "    demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}