bluenevus committed on
Commit
4e7ec06
·
verified ·
1 Parent(s): 4ed1e63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -168
app.py CHANGED
@@ -1,109 +1,41 @@
 
1
  import io
2
  import os
3
- import tempfile
4
  import threading
5
- import base64
6
- import logging
7
- from urllib.parse import urlparse
8
-
9
- import dash
10
- from dash import dcc, html, Input, Output, State, callback_context
11
  import dash_bootstrap_components as dbc
12
- from dash.exceptions import PreventUpdate
13
-
14
- import requests
15
- from pytube import YouTube
16
- from pydub import AudioSegment
17
  import openai
 
18
 
19
- # Try different import statements for moviepy
20
- try:
21
- from moviepy.editor import VideoFileClip
22
- except ImportError:
23
- try:
24
- import moviepy.editor as mpy
25
- VideoFileClip = mpy.VideoFileClip
26
- except ImportError:
27
- try:
28
- import moviepy
29
- VideoFileClip = moviepy.VideoFileClip
30
- except ImportError:
31
- logging.error("Failed to import VideoFileClip from moviepy. Please check the installation.")
32
- VideoFileClip = None
33
-
34
- # Set up logging
35
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
36
  logger = logging.getLogger(__name__)
37
 
38
  # Initialize the Dash app
39
- app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
40
-
41
- # Retrieve the OpenAI API key from Hugging Face Spaces
42
- OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
43
- if not OPENAI_API_KEY:
44
- logger.error("OPENAI_API_KEY not found in environment variables")
45
- raise ValueError("OPENAI_API_KEY not set")
46
-
47
- openai.api_key = OPENAI_API_KEY
48
-
49
- def process_media(contents, filename, url):
50
- logger.info("Starting media processing")
51
- try:
52
- if contents:
53
- content_type, content_string = contents.split(',')
54
- decoded = base64.b64decode(content_string)
55
- suffix = os.path.splitext(filename)[1]
56
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
57
- temp_file.write(decoded)
58
- temp_file_path = temp_file.name
59
- logger.info(f"File uploaded: {temp_file_path}")
60
- elif url:
61
- temp_file_path = download_media(url)
62
- else:
63
- logger.error("No input provided")
64
- raise ValueError("No input provided")
65
 
66
- if temp_file_path.lower().endswith(('.mp4', '.avi', '.mov', '.flv', '.wmv')):
67
- logger.info("Video file detected, extracting audio")
68
- audio_file_path = extract_audio(temp_file_path)
69
- transcript = transcribe_audio(audio_file_path)
70
- os.unlink(audio_file_path)
71
- else:
72
- logger.info("Audio file detected, transcribing directly")
73
- transcript = transcribe_audio(temp_file_path)
74
 
75
- os.unlink(temp_file_path)
76
- return transcript
77
- except Exception as e:
78
- logger.error(f"Error in process_media: {str(e)}")
79
- raise
80
-
81
- def transcribe_audio(file_path):
82
- logger.info(f"Transcribing audio: {file_path}")
83
- try:
84
- with open(file_path, "rb") as audio_file:
85
- transcript = openai.Audio.transcribe("whisper-1", audio_file)
86
- logger.info("Transcription completed successfully")
87
- return transcript["text"]
88
- except Exception as e:
89
- logger.error(f"Error during transcription: {str(e)}")
90
- raise
91
 
 
92
  app.layout = dbc.Container([
 
93
  dbc.Row([
94
- dbc.Col([
95
- html.H1("Audio/Video Transcription App", className="text-center my-4"),
96
- ])
97
- ]),
98
- dbc.Row([
99
  dbc.Col([
100
  dbc.Card([
101
  dbc.CardBody([
102
  dcc.Upload(
103
- id='upload-media',
104
  children=html.Div([
105
  'Drag and Drop or ',
106
- html.A('Select Audio/Video File')
107
  ]),
108
  style={
109
  'width': '100%',
@@ -117,100 +49,107 @@ app.layout = dbc.Container([
117
  },
118
  multiple=False
119
  ),
120
- html.Div(id='file-info', className="mt-2"),
121
- dbc.Input(id="media-url", type="text", placeholder="Enter audio/video URL or YouTube link", className="my-3"),
122
- dbc.Button("Transcribe", id="transcribe-button", color="primary", className="w-100 mb-3"),
123
- dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
124
- html.Div(id="progress-indicator", className="text-center mt-3"),
125
- dbc.Button("Download Transcript", id="download-button", color="secondary", className="w-100 mt-3", style={'display': 'none'}),
126
- dcc.Download(id="download-transcript"),
127
- dcc.Store(id="transcription-store"),
128
- dcc.Interval(id='progress-interval', interval=500, n_intervals=0, disabled=True)
 
 
 
 
 
129
  ])
130
  ])
131
- ], width=12)
132
  ])
133
  ], fluid=True)
134
 
135
- @app.callback(
136
- Output("file-info", "children"),
137
- Input("upload-media", "filename"),
138
- Input("upload-media", "last_modified")
139
- )
140
- def update_file_info(filename, last_modified):
141
- if filename is not None:
142
- return f"File uploaded: {filename}"
143
- return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  @app.callback(
146
- Output("transcription-output", "children"),
147
- Output("download-button", "style"),
148
- Output("progress-indicator", "children"),
149
- Output("progress-interval", "disabled"),
150
- Output("transcription-store", "data"),
151
- Input("transcribe-button", "n_clicks"),
152
- Input("progress-interval", "n_intervals"),
153
- State("upload-media", "contents"),
154
- State("upload-media", "filename"),
155
- State("media-url", "value"),
156
- State("transcription-store", "data"),
157
- prevent_initial_call=True
158
  )
159
- def update_transcription(n_clicks, n_intervals, contents, filename, url, stored_transcript):
160
- ctx = callback_context
161
- if ctx.triggered_id == "transcribe-button":
162
- if not contents and not url:
163
- raise PreventUpdate
164
-
165
- def transcribe():
166
- try:
167
- return process_media(contents, filename, url)
168
- except Exception as e:
169
- logger.error(f"Transcription failed: {str(e)}")
170
- return f"An error occurred: {str(e)}"
171
-
172
- thread = threading.Thread(target=transcribe)
173
- thread.start()
174
- return html.Div("Processing..."), {'display': 'none'}, "", False, None
175
-
176
- elif ctx.triggered_id == "progress-interval":
177
- if stored_transcript:
178
- return display_transcript(stored_transcript), {'display': 'block'}, "", True, stored_transcript
179
- dots = "." * (n_intervals % 4)
180
- return html.Div("Processing" + dots), {'display': 'none'}, "", False, None
181
-
182
- thread = threading.current_thread()
183
- if hasattr(thread, 'result'):
184
- transcript = thread.result
185
- if transcript and not transcript.startswith("An error occurred"):
186
- logger.info("Transcription successful")
187
- return display_transcript(transcript), {'display': 'block'}, "", True, transcript
188
- else:
189
- logger.error(f"Transcription failed: {transcript}")
190
- return html.Div(transcript), {'display': 'none'}, "", True, None
191
-
192
- return dash.no_update, dash.no_update, dash.no_update, dash.no_update, dash.no_update
193
-
194
- def display_transcript(transcript):
195
- return dbc.Card([
196
- dbc.CardBody([
197
- html.H5("Transcription Result"),
198
- html.Pre(transcript, style={"white-space": "pre-wrap", "word-wrap": "break-word"})
199
- ])
200
- ])
201
 
202
  @app.callback(
203
- Output("download-transcript", "data"),
204
- Input("download-button", "n_clicks"),
205
- State("transcription-store", "data"),
206
- prevent_initial_call=True
207
  )
208
- def download_transcript(n_clicks, transcript):
209
- if not transcript:
210
- raise PreventUpdate
211
- return dict(content=transcript, filename="transcript.txt")
212
 
213
  if __name__ == '__main__':
214
- logger.info("Starting the Dash application...")
215
  app.run(debug=True, host='0.0.0.0', port=7860)
216
- logger.info("Dash application has finished running.")
 
1
+ import base64
2
  import io
3
  import os
 
4
  import threading
5
+ from dash import Dash, dcc, html, Input, Output, State, callback
 
 
 
 
 
6
  import dash_bootstrap_components as dbc
7
+ import tempfile
8
+ import logging
 
 
 
9
  import openai
10
+ from pydub import AudioSegment
11
 
12
+ # Configure logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
  logger = logging.getLogger(__name__)
15
 
16
  # Initialize the Dash app
17
+ app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ # Global variables
20
+ generated_file = None
21
+ transcription_text = ""
 
 
 
 
 
22
 
23
+ # Set up OpenAI API key
24
+ openai.api_key = os.getenv("OPENAI_API_KEY")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ # Layout
27
  app.layout = dbc.Container([
28
+ html.H1("Audio Transcription and Diarization App", className="text-center my-4"),
29
  dbc.Row([
30
+ # Left card for input
 
 
 
 
31
  dbc.Col([
32
  dbc.Card([
33
  dbc.CardBody([
34
  dcc.Upload(
35
+ id='upload-audio',
36
  children=html.Div([
37
  'Drag and Drop or ',
38
+ html.A('Select Audio File')
39
  ]),
40
  style={
41
  'width': '100%',
 
49
  },
50
  multiple=False
51
  ),
52
+ html.Div(id='output-audio-upload'),
53
+ dbc.Spinner(html.Div(id='transcription-status'), color="primary", type="grow"),
54
+ ])
55
+ ], className="mb-4")
56
+ ], md=6),
57
+ # Right card for output
58
+ dbc.Col([
59
+ dbc.Card([
60
+ dbc.CardBody([
61
+ html.H4("Diarized Transcription Preview", className="card-title"),
62
+ html.Div(id='transcription-preview', style={'whiteSpace': 'pre-wrap'}),
63
+ html.Br(),
64
+ dbc.Button("Download Transcription", id="btn-download", color="primary", className="mt-3", disabled=True),
65
+ dcc.Download(id="download-transcription")
66
  ])
67
  ])
68
+ ], md=6)
69
  ])
70
  ], fluid=True)
71
 
72
+ def transcribe_and_diarize_audio(contents, filename):
73
+ global generated_file, transcription_text
74
+ try:
75
+ content_type, content_string = contents.split(',')
76
+ decoded = base64.b64decode(content_string)
77
+
78
+ with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(filename)[1]) as temp_audio_file:
79
+ temp_audio_file.write(decoded)
80
+ temp_audio_file_path = temp_audio_file.name
81
+
82
+ logger.info(f"File uploaded: {temp_audio_file_path}")
83
+
84
+ if filename.lower().endswith(('.wav', '.mp3', '.ogg', '.flac')):
85
+ logger.info("Audio file detected, transcribing with OpenAI")
86
+
87
+ # Convert audio to wav format if needed
88
+ audio = AudioSegment.from_file(temp_audio_file_path)
89
+ wav_path = temp_audio_file_path + ".wav"
90
+ audio.export(wav_path, format="wav")
91
+
92
+ with open(wav_path, "rb") as audio_file:
93
+ transcript = openai.Audio.transcribe("whisper-1", audio_file)
94
+
95
+ transcription_text = transcript["text"]
96
+
97
+ # Perform diarization (speaker separation)
98
+ diarized_transcript = openai.Audio.transcribe("whisper-1", audio_file, speaker_detection=2)
99
+
100
+ # Format the diarized transcript
101
+ formatted_transcript = ""
102
+ for segment in diarized_transcript["segments"]:
103
+ formatted_transcript += f"Speaker {segment['speaker']}: {segment['text']}\n\n"
104
+
105
+ transcription_text = formatted_transcript
106
+ logger.info("Transcription and diarization completed successfully")
107
+
108
+ # Prepare the transcription for download
109
+ generated_file = io.BytesIO(transcription_text.encode())
110
+ return "Transcription and diarization completed successfully!", True
111
+ else:
112
+ return "Unsupported file format. Please upload an audio file.", False
113
+ except Exception as e:
114
+ logger.error(f"Error during transcription and diarization: {str(e)}")
115
+ return f"An error occurred during transcription and diarization: {str(e)}", False
116
+ finally:
117
+ if os.path.exists(temp_audio_file_path):
118
+ os.unlink(temp_audio_file_path)
119
+ if os.path.exists(wav_path):
120
+ os.unlink(wav_path)
121
 
122
  @app.callback(
123
+ [Output('output-audio-upload', 'children'),
124
+ Output('transcription-status', 'children'),
125
+ Output('transcription-preview', 'children'),
126
+ Output('btn-download', 'disabled')],
127
+ [Input('upload-audio', 'contents')],
128
+ [State('upload-audio', 'filename')]
 
 
 
 
 
 
129
  )
130
+ def update_output(contents, filename):
131
+ if contents is None:
132
+ return "No file uploaded.", "", "", True
133
+
134
+ status_message, success = transcribe_and_diarize_audio(contents, filename)
135
+
136
+ if success:
137
+ preview = transcription_text[:1000] + "..." if len(transcription_text) > 1000 else transcription_text
138
+ return f"File {filename} processed successfully.", status_message, preview, False
139
+ else:
140
+ return f"File {filename} could not be processed.", status_message, "", True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  @app.callback(
143
+ Output("download-transcription", "data"),
144
+ Input("btn-download", "n_clicks"),
145
+ prevent_initial_call=True,
 
146
  )
147
+ def download_transcription(n_clicks):
148
+ if n_clicks is None:
149
+ return None
150
+ return dcc.send_bytes(generated_file.getvalue(), "diarized_transcription.txt")
151
 
152
  if __name__ == '__main__':
153
+ print("Starting the Dash application...")
154
  app.run(debug=True, host='0.0.0.0', port=7860)
155
+ print("Dash application has finished running.")