File size: 4,734 Bytes
3bb71cf
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: dialogue_diarization_demo\n",
    "\n",
    "Upload or record audio: `pyannote.audio` separates the speakers, Whisper transcribes the speech, and the result is shown as a speaker-labelled conversation in `gr.Dialogue`.\n",
    "\n",
    "The diarization pipeline is loaded with `use_auth_token=True`, so a Hugging Face token with access to the gated `pyannote/speaker-diarization-3.1` model must be available (for example via `huggingface-cli login`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q gradio torch torchaudio pyannote.audio openai-whisper librosa numpy transformers speechbrain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# type: ignore\n",
    "import gradio as gr\n",
    "from pyannote.audio import Pipeline\n",
    "import whisper\n",
    "\n",
    "# Model checkpoints used by the demo.\n",
    "DIARIZATION_MODEL_ID = \"pyannote/speaker-diarization-3.1\"\n",
    "WHISPER_MODEL_NAME = \"base\"\n",
    "# Upper bound on distinct speakers. The Dialogue speaker list below is built from\n",
    "# the same constant, so every label the diarization step emits is selectable.\n",
    "MAX_SPEAKERS = 6\n",
    "\n",
    "diarization_pipeline = None\n",
    "whisper_model = None\n",
    "\n",
    "\n",
    "def load_models():\n",
    "    \"\"\"Lazily load the diarization pipeline and the Whisper model into module globals.\n",
    "\n",
    "    Models that are already loaded are not reloaded, so repeated calls are cheap.\n",
    "    \"\"\"\n",
    "    global diarization_pipeline, whisper_model  # noqa: PLW0603\n",
    "\n",
    "    if diarization_pipeline is None:\n",
    "        # The pyannote checkpoint is gated; use_auth_token=True uses the locally\n",
    "        # saved Hugging Face token. from_pretrained can return None instead of\n",
    "        # raising when the model cannot be accessed, hence the None check later.\n",
    "        diarization_pipeline = Pipeline.from_pretrained(\n",
    "            DIARIZATION_MODEL_ID, use_auth_token=True\n",
    "        )\n",
    "\n",
    "    if whisper_model is None:\n",
    "        whisper_model = whisper.load_model(WHISPER_MODEL_NAME)\n",
    "\n",
    "\n",
    "def _best_speaker_label(start, end, turns):\n",
    "    \"\"\"Return the label with the largest total overlap with [start, end].\n",
    "\n",
    "    `turns` is a sequence of (turn_start, turn_end, label) tuples in seconds.\n",
    "    Returns None when no turn overlaps the interval by a positive amount.\n",
    "    \"\"\"\n",
    "    overlap_by_label = {}\n",
    "    for turn_start, turn_end, label in turns:\n",
    "        overlap = min(end, turn_end) - max(start, turn_start)\n",
    "        if overlap > 0:\n",
    "            overlap_by_label[label] = overlap_by_label.get(label, 0.0) + overlap\n",
    "    if not overlap_by_label:\n",
    "        return None\n",
    "    return max(overlap_by_label, key=overlap_by_label.get)\n",
    "\n",
    "\n",
    "def real_diarization(audio_file_path: str) -> list[dict[str, str]]:\n",
    "    \"\"\"Transcribe an audio file and attribute each transcript segment to a speaker.\n",
    "\n",
    "    Each Whisper segment is assigned the diarization label that overlaps it the\n",
    "    most in total, rather than the first turn containing its start or end. This\n",
    "    labels segments correctly when they span a whole short turn or straddle two turns.\n",
    "\n",
    "    Args:\n",
    "        audio_file_path: Path to the audio file to analyse.\n",
    "\n",
    "    Returns:\n",
    "        A list of {\"speaker\": \"Speaker N\", \"text\": ...} dicts in transcript order.\n",
    "        Segments with empty text are skipped. Speakers are numbered in order of\n",
    "        first appearance.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: If either model failed to load. Errors raised by the models\n",
    "            themselves propagate so the caller can report them.\n",
    "    \"\"\"\n",
    "    load_models()\n",
    "\n",
    "    if diarization_pipeline is None or whisper_model is None:\n",
    "        raise RuntimeError(\"Failed to load models\")\n",
    "\n",
    "    # Cap the number of clusters so labels fit the Dialogue speaker list.\n",
    "    diarization = diarization_pipeline(audio_file_path, max_speakers=MAX_SPEAKERS)\n",
    "    # Materialise the speaker turns once rather than re-iterating them per segment.\n",
    "    turns = [\n",
    "        (turn.start, turn.end, speaker_label)\n",
    "        for turn, _, speaker_label in diarization.itertracks(yield_label=True)\n",
    "    ]\n",
    "\n",
    "    transcription = whisper_model.transcribe(audio_file_path)\n",
    "\n",
    "    dialogue_segments = []\n",
    "    speaker_mapping = {}\n",
    "    # Used until the first segment that overlaps a diarization turn.\n",
    "    speaker = \"Speaker 1\"\n",
    "\n",
    "    for segment in transcription[\"segments\"]:\n",
    "        text = segment[\"text\"].strip()\n",
    "        if not text:\n",
    "            continue\n",
    "\n",
    "        label = _best_speaker_label(segment[\"start\"], segment[\"end\"], turns)\n",
    "        if label is not None:\n",
    "            if label not in speaker_mapping:\n",
    "                speaker_mapping[label] = f\"Speaker {len(speaker_mapping) + 1}\"\n",
    "            speaker = speaker_mapping[label]\n",
    "        # Otherwise the segment lies in a gap between turns: keep the previous speaker.\n",
    "\n",
    "        dialogue_segments.append({\"speaker\": speaker, \"text\": text})\n",
    "\n",
    "    return dialogue_segments\n",
    "\n",
    "\n",
    "def process_audio(audio_file):\n",
    "    \"\"\"Click handler for the Analyze button.\n",
    "\n",
    "    Args:\n",
    "        audio_file: Filepath from the gr.Audio input, or None if nothing was provided.\n",
    "\n",
    "    Returns:\n",
    "        The dialogue segments for the gr.Dialogue output. Returns an empty list,\n",
    "        and shows a warning, when no audio was provided.\n",
    "\n",
    "    Raises:\n",
    "        gr.Error: If diarization or transcription fails, so the message reaches the UI.\n",
    "    \"\"\"\n",
    "    if audio_file is None:\n",
    "        gr.Warning(\"Please upload an audio file first.\")\n",
    "        return []\n",
    "\n",
    "    try:\n",
    "        return real_diarization(audio_file)\n",
    "    except Exception as e:\n",
    "        print(f\"Error in diarization: {str(e)}\")\n",
    "        # gr.Error is an exception class: it is only shown in the UI when raised.\n",
    "        raise gr.Error(f\"Error processing audio: {str(e)}\") from e\n",
    "\n",
    "\n",
    "speakers = [f\"Speaker {i}\" for i in range(1, MAX_SPEAKERS + 1)]\n",
    "tags = [\n",
    "    \"(pause)\",\n",
    "    \"(background noise)\",\n",
    "    \"(unclear)\",\n",
    "    \"(overlap)\",\n",
    "    \"(phone ringing)\",\n",
    "    \"(door closing)\",\n",
    "    \"(music)\",\n",
    "    \"(applause)\",\n",
    "    \"(laughter)\",\n",
    "]\n",
    "\n",
    "\n",
    "def format_speaker(speaker, text):\n",
    "    \"\"\"Format one dialogue line as \"<speaker>: <text>\".\"\"\"\n",
    "    return f\"{speaker}: {text}\"\n",
    "\n",
    "\n",
    "with gr.Blocks(title=\"Audio Diarization Demo\") as demo:\n",
    "    with gr.Row():\n",
    "        with gr.Column(scale=1):\n",
    "            audio_input = gr.Audio(\n",
    "                label=\"Upload Audio File\",\n",
    "                type=\"filepath\",\n",
    "                sources=[\"upload\", \"microphone\"],\n",
    "            )\n",
    "\n",
    "            process_btn = gr.Button(\"\ud83d\udd0d Analyze Speakers\", variant=\"primary\", size=\"lg\")\n",
    "\n",
    "        with gr.Column(scale=2):\n",
    "            dialogue_output = gr.Dialogue(\n",
    "                speakers=speakers,\n",
    "                tags=tags,\n",
    "                formatter=format_speaker,\n",
    "                label=\"AI-generated speaker-separated conversation\",\n",
    "                value=[],\n",
    "            )\n",
    "\n",
    "    process_btn.click(fn=process_audio, inputs=[audio_input], outputs=[dialogue_output])\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch()\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}