File size: 6,936 Bytes
4e7ec06
1100e65
b09f327
639051f
4e7ec06
 
7a3a01f
110c781
 
4e7ec06
110c781
 
 
953582f
4e7ec06
ce4312e
 
 
639051f
4e7ec06
82b85b5
4e7ec06
 
 
82b85b5
4e7ec06
 
4ed1e63
4e7ec06
53bdf99
110c781
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53bdf99
a123d64
53bdf99
110c781
4e7ec06
dce154d
4e7ec06
110c781
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e7ec06
110c781
4e7ec06
110c781
 
 
 
 
 
 
 
 
4e7ec06
 
110c781
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e7ec06
110c781
 
4e7ec06
b3174ad
dce154d
110c781
b3174ad
04933a2
53bdf99
110c781
4e7ec06
 
 
110c781
 
 
 
53bdf99
110c781
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e7ec06
 
 
110c781
4e7ec06
110c781
53bdf99
26cf8bb
4e7ec06
 
 
26cf8bb
4e7ec06
 
 
 
81f702f
6575bf4
4e7ec06
31b9df5
4e7ec06
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import base64
import io
import logging
import os
import tempfile
import threading
from urllib.parse import urlparse

import dash_bootstrap_components as dbc
import moviepy.editor as mp
import openai
import requests
from dash import Dash, dcc, html, Input, Output, State, callback, callback_context
from pydub import AudioSegment
from pytube import YouTube

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Initialize the Dash app
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Global variables
# Result of the most recent successful process_media() call: the formatted
# transcript text and the same text as an in-memory byte buffer for download.
# NOTE(review): this is process-wide state, so every browser session shares
# (and overwrites) the same transcript.
generated_file = None
transcription_text = ""

# Set up OpenAI API key
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    # Surface the misconfiguration at startup rather than only when the first
    # transcription request is rejected.
    logger.warning("OPENAI_API_KEY is not set; transcription requests are expected to fail.")

# Layout
# Inline CSS for the drag-and-drop target: a full-width box with a dashed,
# rounded border and vertically centred label text.
_UPLOAD_STYLE = {
    'width': '100%',
    'height': '60px',
    'lineHeight': '60px',
    'borderWidth': '1px',
    'borderStyle': 'dashed',
    'borderRadius': '5px',
    'textAlign': 'center',
    'margin': '10px',
}

# Single-file upload widget; its `contents`/`filename` feed update_output.
_upload_area = dcc.Upload(
    id='upload-media',
    children=html.Div(['Drag and Drop or ', html.A('Select Audio/Video File')]),
    style=_UPLOAD_STYLE,
    multiple=False,
)

# Everything inside the main card, top to bottom: inputs, status, preview,
# and the download control.
_card_body = dbc.CardBody([
    _upload_area,
    html.Div(id='output-media-upload'),
    dbc.Input(
        id="url-input",
        type="text",
        placeholder="Enter audio/video URL (including YouTube)",
        className="mb-3",
    ),
    dbc.Button("Process URL", id="process-url-button", color="primary", className="mb-3"),
    dbc.Spinner(html.Div(id='transcription-status'), color="primary", type="grow"),
    html.H4("Diarized Transcription Preview", className="mt-4"),
    html.Div(id='transcription-preview', style={'whiteSpace': 'pre-wrap'}),
    html.Br(),
    dbc.Button(
        "Download Transcription",
        id="btn-download",
        color="primary",
        className="mt-3",
        disabled=True,
    ),
    dcc.Download(id="download-transcription"),
])

_page_title = html.H1(
    "Audio/Video Transcription and Diarization App",
    className="text-center my-4",
)

app.layout = dbc.Container([_page_title, dbc.Card([_card_body])], fluid=True)

# Extensions whose audio track is extracted with moviepy.
_VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.flv', '.wmv')
# Extensions decoded directly with pydub.
_AUDIO_EXTENSIONS = ('.wav', '.mp3', '.ogg', '.flac')
# Seconds to wait on a remote server before abandoning a plain URL download.
_DOWNLOAD_TIMEOUT = 60


def _fetch_media(source, is_url):
    """Write the media behind *source* to a new temp file and return its path.

    *source* is raw file bytes when ``is_url`` is False, otherwise a URL
    (YouTube links go through pytube, anything else through requests).
    The caller owns the returned file and must delete it.

    Raises:
        ValueError: a YouTube video exposes no audio-only stream.
        requests.HTTPError / requests.RequestException: plain URL download failed.
    """
    if is_url and ('youtube.com' in source or 'youtu.be' in source):
        stream = YouTube(source).streams.filter(only_audio=True).first()
        if stream is None:
            raise ValueError("No audio stream available for this YouTube video.")
        # Only reserve a unique name here and close the handle straight away:
        # pytube writes the path itself, and a lingering open handle would leak
        # (and block that write on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
            path = tmp.name
        try:
            stream.download(output_path=os.path.dirname(path),
                            filename=os.path.basename(path))
        except Exception:
            os.unlink(path)
            raise
        return path

    if is_url:
        response = requests.get(source, timeout=_DOWNLOAD_TIMEOUT)
        response.raise_for_status()
        data = response.content
        # Keep the URL path's extension (if any) so format routing can use it.
        suffix = os.path.splitext(urlparse(source).path)[1]
    else:
        # Uploaded bytes come without a filename, so no extension is known.
        data = source
        suffix = ''
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(data)
    return tmp.name


def _convert_to_wav(src_path, wav_path):
    """Decode the media file at *src_path* into a WAV file at *wav_path*.

    Returns False (writing nothing) when a file with no recognised extension
    cannot be decoded as audio; any other failure propagates.
    """
    ext = os.path.splitext(src_path)[1].lower()
    if ext in _VIDEO_EXTENSIONS:
        # AudioFileClip reads only the audio track, so it also copes with
        # audio-only containers such as the YouTube audio streams fetched by
        # _fetch_media, which a VideoFileClip is not meant for.
        clip = mp.AudioFileClip(src_path)
        try:
            clip.write_audiofile(wav_path)
        finally:
            clip.close()
        return True

    if ext in _AUDIO_EXTENSIONS:
        audio = AudioSegment.from_file(src_path)
    else:
        # No usable extension (always the case for uploads): let ffmpeg probe
        # the content through pydub instead of rejecting it outright.
        try:
            audio = AudioSegment.from_file(src_path)
        except Exception as exc:
            logger.warning("Could not decode %s as audio/video: %s", src_path, exc)
            return False
    # export() hands back the file object it opened; close it so the WAV is
    # flushed and can be deleted later.
    audio.export(wav_path, format="wav").close()
    return True


def _format_transcript(result):
    """Render a verbose_json transcription as "Speaker X: text" paragraphs.

    Falls back to the result's plain ``text`` (or a placeholder) when it has
    no segments.
    """
    segments = result.get('segments')
    if not segments:
        return result.get('text', 'No transcription available.')
    # NOTE(review): Whisper segments do not appear to carry a 'speaker' key,
    # so every entry is likely labelled "Speaker Unknown"; genuine diarization
    # would need a separate model. Confirm against the API reference.
    return "".join(
        f"Speaker {segment.get('speaker', 'Unknown')}: {segment.get('text', '')}\n\n"
        for segment in segments
    )


def process_media(file_path, is_url=False):
    """Fetch, transcribe and format one media item, storing the result globally.

    Args:
        file_path: raw uploaded file bytes when ``is_url`` is False, otherwise
            the URL to download (YouTube links are fetched with pytube).
        is_url: selects how ``file_path`` is interpreted.

    Returns:
        ``(status_message, success)``. On success the module globals
        ``transcription_text`` and ``generated_file`` hold the formatted
        transcript and a BytesIO of its encoded bytes. Errors are logged and
        reported through the message instead of being raised.
    """
    global generated_file, transcription_text
    src_path = None
    wav_path = None
    try:
        src_path = _fetch_media(file_path, is_url)
        wav_path = src_path + ".wav"
        if not _convert_to_wav(src_path, wav_path):
            return "Unsupported file format. Please upload an audio or video file.", False

        # NOTE(review): the transcription endpoint enforces an upload size limit
        # (25 MB per OpenAI docs — confirm); uncompressed WAV reaches that after
        # a few minutes of audio.
        with open(wav_path, "rb") as audio_file:
            # One verbose_json request carries both the segments and the full
            # text used as the fallback, so no separate plain request is needed.
            result = openai.Audio.transcribe("whisper-1", audio_file, response_format="verbose_json")

        transcription_text = _format_transcript(result)
        generated_file = io.BytesIO(transcription_text.encode())
        return "Transcription and diarization completed successfully!", True
    except Exception as e:
        logger.exception("Error during processing: %s", e)
        return f"An error occurred: {str(e)}", False
    finally:
        for path in (src_path, wav_path):
            if path and os.path.exists(path):
                os.unlink(path)

@app.callback(
    [Output('output-media-upload', 'children'),
     Output('transcription-status', 'children'),
     Output('transcription-preview', 'children'),
     Output('btn-download', 'disabled')],
    [Input('upload-media', 'contents'),
     Input('process-url-button', 'n_clicks')],
    [State('upload-media', 'filename'),
     State('url-input', 'value')]
)
def update_output(contents, n_clicks, filename, url):
    """Process whichever input fired (file upload or URL button) and report back.

    Returns a 4-tuple for the outputs above: upload message, status message,
    transcript preview (first 1000 characters, with "..." appended when the
    text is longer) and the download button's ``disabled`` flag.
    ``filename`` is received but not currently used.
    """
    idle = ("No file uploaded or URL processed.", "", "", True)

    ctx = callback_context
    if not ctx.triggered:
        return idle

    # prop_id looks like "<component-id>.<property>"; keep the component id.
    trigger_id = ctx.triggered[0]['prop_id'].split('.')[0]

    if trigger_id == 'upload-media' and contents is not None:
        # Upload contents are "<header>,<base64 payload>"; split only once so
        # the payload is taken whole.
        _, content_string = contents.split(',', 1)
        status_message, success = process_media(base64.b64decode(content_string))
    elif trigger_id == 'process-url-button' and url and url.strip():
        status_message, success = process_media(url.strip(), is_url=True)
    else:
        return idle

    if not success:
        return "Processing failed.", status_message, "", True

    preview = transcription_text[:1000] + "..." if len(transcription_text) > 1000 else transcription_text
    return "File processed successfully.", status_message, preview, False

@app.callback(
    Output("download-transcription", "data"),
    Input("btn-download", "n_clicks"),
    prevent_initial_call=True,
)
def download_transcription(n_clicks):
    """Send the most recent transcript to the browser as a .txt download.

    Returns None (no download) when the button has not been clicked or when
    no transcript has been generated yet, instead of failing on a missing
    buffer.
    """
    if n_clicks is None or generated_file is None:
        return None
    return dcc.send_bytes(generated_file.getvalue(), "diarized_transcription.txt")

if __name__ == '__main__':
    print("Starting the Dash application...")
    # Debug mode turns on Dash dev tools and the Flask/Werkzeug debugger. The
    # server listens on all interfaces, so debugging must be an explicit
    # opt-in (e.g. DASH_DEBUG=1) rather than the default.
    debug = os.getenv("DASH_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(debug=debug, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")