# type: ignore
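"""Gradio demo: speaker diarization with pyannote.audio plus Whisper transcription,
rendered with the gr.Dialogue component."""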
import gradio as gr
from pyannote.audio import Pipeline
import whisper

# Models are loaded lazily and cached in module-level globals so repeated
# requests reuse the same diarization pipeline and Whisper model.
diarization_pipeline = None
whisper_model = None


def load_models():
    global diarization_pipeline, whisper_model  # noqa: PLW0603

    if diarization_pipeline is None:
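        # pyannote/speaker-diarization-3.1 is a gated checkpoint; use_auth_token=True
        # picks up the Hugging Face token from the local login or HF_TOKEN.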
        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1", use_auth_token=True
        )

    if whisper_model is None:
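        # "base" is a small, CPU-friendly Whisper checkpoint; larger checkpoints
        # ("small", "medium") improve accuracy at the cost of speed.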
        whisper_model = whisper.load_model("base")


def real_diarization(audio_file_path: str) -> list[dict[str, str]]:
    try:
        load_models()

        if diarization_pipeline is None or whisper_model is None:
            raise RuntimeError("Failed to load diarization and transcription models")

        # Who spoke when: pyannote returns speaker turns with start/end times.
        diarization = diarization_pipeline(audio_file_path)

        # What was said: Whisper returns timed segments with "start", "end", "text".
        transcription = whisper_model.transcribe(audio_file_path)
        segments = transcription["segments"]

        dialogue_segments = []
        speaker_mapping = {}  # pyannote label (e.g. "SPEAKER_00") -> "Speaker N"
        speaker_counter = 1

        for segment in segments:
            start_time = segment["start"]
            end_time = segment["end"]
            text = segment["text"].strip()

            speaker = "Speaker 1"
            for turn, _, speaker_label in diarization.itertracks(yield_label=True):
                if (
                    turn.start <= start_time <= turn.end
                    or turn.start <= end_time <= turn.end
                ):
                    if speaker_label not in speaker_mapping:
                        speaker_mapping[speaker_label] = f"Speaker {speaker_counter}"
                        speaker_counter += 1
                    speaker = speaker_mapping[speaker_label]
                    break

            if text:
                dialogue_segments.append({"speaker": speaker, "text": text})

        return dialogue_segments

    except Exception as e:
        print(f"Error in diarization: {str(e)}")
        return []


def process_audio(audio_file: str | None) -> list[dict[str, str]]:
    if audio_file is None:
        gr.Warning("Please upload an audio file first.")
        return []

    try:
        return real_diarization(audio_file)
    except Exception as e:
        # gr.Error is only shown in the UI when raised, not when merely constructed.
        raise gr.Error(f"Error processing audio: {e}") from e


# Speaker names and annotation tags offered by the gr.Dialogue component when
# editing the generated transcript.
speakers = [
    "Speaker 1",
    "Speaker 2",
    "Speaker 3",
    "Speaker 4",
    "Speaker 5",
    "Speaker 6",
]
tags = [
    "(pause)",
    "(background noise)",
    "(unclear)",
    "(overlap)",
    "(phone ringing)",
    "(door closing)",
    "(music)",
    "(applause)",
    "(laughter)",
]


def format_speaker(speaker, text):
    # Controls how each dialogue line is rendered, e.g. "Speaker 1: Hello".
    return f"{speaker}: {text}"


with gr.Blocks(title="Audio Diarization Demo") as demo:
    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="Upload Audio File",
                type="filepath",
                sources=["upload", "microphone"],
            )

            process_btn = gr.Button("🔍 Analyze Speakers", variant="primary", size="lg")

        with gr.Column(scale=2):
            dialogue_output = gr.Dialogue(
                speakers=speakers,
                tags=tags,
                formatter=format_speaker,
                label="AI-generated speaker-separated conversation",
                value=[],
            )

    process_btn.click(fn=process_audio, inputs=[audio_input], outputs=[dialogue_output])

if __name__ == "__main__":
    demo.launch()