bluenevus committed on
Commit
81f702f
·
verified ·
1 Parent(s): df42ab3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -88
app.py CHANGED
@@ -1,22 +1,15 @@
1
  import io
2
- import re
3
  import torch
4
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
5
  import requests
6
  from bs4 import BeautifulSoup
7
  import tempfile
8
  import os
9
- import soundfile as sf
10
- from spellchecker import SpellChecker
11
  from pydub import AudioSegment
12
- import librosa
13
- import numpy as np
14
- from pyannote.audio import Pipeline
15
  import dash
16
  from dash import dcc, html, Input, Output, State
17
  import dash_bootstrap_components as dbc
18
  from dash.exceptions import PreventUpdate
19
- import base64
20
  import threading
21
  from pytube import YouTube
22
 
@@ -31,8 +24,6 @@ model_name = "openai/whisper-small"
31
  processor = WhisperProcessor.from_pretrained(model_name)
32
  model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
33
 
34
- spell = SpellChecker()
35
-
36
  def download_audio_from_url(url):
37
  try:
38
  if "youtube.com" in url or "youtu.be" in url:
@@ -66,92 +57,35 @@ def download_audio_from_url(url):
66
  print(f"Error in download_audio_from_url: {str(e)}")
67
  raise
68
 
69
- def correct_spelling(text):
70
- words = text.split()
71
- corrected_words = [spell.correction(word) or word for word in words]
72
- return ' '.join(corrected_words)
73
-
74
- def format_transcript_with_speakers(transcript, diarization):
75
- formatted_transcript = []
76
- current_speaker = None
77
- for segment, _, speaker in diarization.itertracks(yield_label=True):
78
- start = segment.start
79
- end = segment.end
80
- if speaker != current_speaker:
81
- if current_speaker is not None:
82
- formatted_transcript.append("\n") # Add a blank line between speakers
83
- formatted_transcript.append(f"Speaker {speaker}:\n")
84
- current_speaker = speaker
85
- segment_text = transcript[start:end].strip()
86
- if segment_text:
87
- formatted_transcript.append(f"{segment_text}\n")
88
- return "".join(formatted_transcript)
89
-
90
- def transcribe_audio(audio_file, pipeline):
91
  try:
92
- if pipeline is None:
93
- raise ValueError("Speaker diarization pipeline is not initialized")
94
-
95
  print("Loading audio file...")
96
- audio_input, sr = librosa.load(audio_file, sr=16000)
97
- audio_input = audio_input.astype(np.float32)
98
- print(f"Audio duration: {len(audio_input) / sr:.2f} seconds")
99
-
100
- # Apply speaker diarization
101
- print("Applying speaker diarization...")
102
- diarization = pipeline(audio_file)
103
- print("Speaker diarization complete.")
104
-
105
- chunk_length = 30 * sr
106
- overlap = 5 * sr
107
- transcriptions = []
108
 
109
  print("Starting transcription...")
110
- for i in range(0, len(audio_input), chunk_length - overlap):
111
- chunk = audio_input[i:i+chunk_length]
112
- input_features = processor(chunk, sampling_rate=16000, return_tensors="pt").input_features.to(device)
113
- predicted_ids = model.generate(input_features)
114
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
115
- transcriptions.extend(transcription)
116
- print(f"Processed {i / sr:.2f} to {(i + chunk_length) / sr:.2f} seconds")
117
-
118
- full_transcription = " ".join(transcriptions)
119
- print(f"Transcription complete. Full transcription length: {len(full_transcription)} characters")
120
-
121
- print("Applying formatting with speaker diarization...")
122
- formatted_transcription = format_transcript_with_speakers(full_transcription, diarization)
123
-
124
- return formatted_transcription
125
  except Exception as e:
126
  print(f"Error in transcribe_audio: {str(e)}")
127
  raise
128
 
129
- def transcribe_video(url, pipeline):
130
  try:
131
  print(f"Attempting to download audio from URL: {url}")
132
  audio_bytes = download_audio_from_url(url)
133
  print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")
134
 
135
- # Convert audio bytes to AudioSegment
136
- audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
137
-
138
- print(f"Audio duration: {len(audio) / 1000} seconds")
139
-
140
- # Save as WAV file
141
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
142
- audio.export(temp_audio.name, format="wav")
143
- temp_audio_path = temp_audio.name
144
-
145
- print("Starting audio transcription...")
146
- transcript = transcribe_audio(temp_audio_path, pipeline)
147
- print(f"Transcription completed. Transcript length: {len(transcript)} characters")
148
 
149
- # Clean up the temporary file
150
- os.unlink(temp_audio_path)
151
-
152
- # Apply spelling correction
153
- transcript = correct_spelling(transcript)
154
-
155
  return transcript
156
  except Exception as e:
157
  error_message = f"An error occurred: {str(e)}"
@@ -189,13 +123,7 @@ def update_transcription(n_clicks, url):
189
 
190
  def transcribe():
191
  try:
192
- # Initialize the speaker diarization pipeline without token
193
- pipeline = Pipeline.from_pretrained("collinbarnwell/pyannote-speaker-diarization-31")
194
- if pipeline is None:
195
- raise ValueError("Failed to initialize the speaker diarization pipeline")
196
- print("Speaker diarization pipeline initialized successfully")
197
-
198
- transcript = transcribe_video(url, pipeline)
199
  return transcript
200
  except Exception as e:
201
  return f"An error occurred: {str(e)}"
@@ -218,7 +146,9 @@ def update_transcription(n_clicks, url):
218
  ]), download_data
219
  else:
220
  return transcript, None
221
-
 
 
222
  if __name__ == '__main__':
223
  print("Starting the Dash application...")
224
  app.run(debug=True, host='0.0.0.0', port=7860)
 
1
  import io
 
2
  import torch
3
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
4
  import requests
5
  from bs4 import BeautifulSoup
6
  import tempfile
7
  import os
 
 
8
  from pydub import AudioSegment
 
 
 
9
  import dash
10
  from dash import dcc, html, Input, Output, State
11
  import dash_bootstrap_components as dbc
12
  from dash.exceptions import PreventUpdate
 
13
  import threading
14
  from pytube import YouTube
15
 
 
24
  processor = WhisperProcessor.from_pretrained(model_name)
25
  model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
26
 
 
 
27
  def download_audio_from_url(url):
28
  try:
29
  if "youtube.com" in url or "youtu.be" in url:
 
57
  print(f"Error in download_audio_from_url: {str(e)}")
58
  raise
59
 
60
def transcribe_audio(audio_file):
    """Transcribe an audio file to text using the module-level Whisper model.

    The audio is decoded with pydub, down-mixed to mono, resampled to 16 kHz
    and converted to a float32 waveform in [-1.0, 1.0), which is the input the
    Whisper feature extractor expects. Because the feature extractor pads or
    truncates every call to a 30-second window, the waveform is transcribed in
    consecutive 30-second chunks so that audio beyond the first 30 seconds is
    not silently dropped.

    Args:
        audio_file: Path (or file-like object) accepted by
            ``AudioSegment.from_file``.

    Returns:
        The transcription as a single string; chunk transcriptions are joined
        with a single space. Returns an empty string when the file contains
        no samples.

    Raises:
        Exception: Any error raised while decoding or transcribing is logged
            to stdout and re-raised unchanged.
    """
    # Local import: numpy is not imported at module level in this file.
    import numpy as np

    sample_rate = 16000
    # One feature-extractor call covers at most 30 seconds of audio.
    chunk_samples = 30 * sample_rate

    try:
        print("Loading audio file...")
        audio = AudioSegment.from_file(audio_file)
        # Force 16-bit samples so the normalisation constant below is correct
        # regardless of the source file's bit depth.
        audio = audio.set_channels(1).set_frame_rate(sample_rate).set_sample_width(2)
        # Scale signed 16-bit PCM integers into the float range [-1.0, 1.0).
        samples = np.array(audio.get_array_of_samples(), dtype=np.float32) / 32768.0

        print("Starting transcription...")
        pieces = []
        for start in range(0, len(samples), chunk_samples):
            chunk = samples[start:start + chunk_samples]
            input_features = processor(
                chunk, sampling_rate=sample_rate, return_tensors="pt"
            ).input_features.to(device)
            predicted_ids = model.generate(input_features)
            decoded = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            pieces.append(decoded[0].strip())

        # Skip chunks that decoded to nothing (e.g. silence) to avoid double spaces.
        transcription = " ".join(piece for piece in pieces if piece)

        print(f"Transcription complete. Length: {len(transcription)} characters")
        return transcription
    except Exception as e:
        print(f"Error in transcribe_audio: {str(e)}")
        raise
77
 
78
+ def transcribe_video(url):
79
  try:
80
  print(f"Attempting to download audio from URL: {url}")
81
  audio_bytes = download_audio_from_url(url)
82
  print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")
83
 
 
 
 
 
 
 
84
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
85
+ AudioSegment.from_file(io.BytesIO(audio_bytes)).export(temp_audio.name, format="wav")
86
+ transcript = transcribe_audio(temp_audio.name)
 
 
 
 
87
 
88
+ os.unlink(temp_audio.name)
 
 
 
 
 
89
  return transcript
90
  except Exception as e:
91
  error_message = f"An error occurred: {str(e)}"
 
123
 
124
  def transcribe():
125
  try:
126
+ transcript = transcribe_video(url)
 
 
 
 
 
 
127
  return transcript
128
  except Exception as e:
129
  return f"An error occurred: {str(e)}"
 
146
  ]), download_data
147
  else:
148
  return transcript, None
149
+
150
# Marker printed every time this module's top-level code runs, after all of
# the definitions above have been executed.
print("Reached end of script definitions")

if __name__ == '__main__':
    print("Starting the Dash application...")
    # The server listens on every interface (0.0.0.0), so debug mode is
    # opt-in: with debug=True the Flask/Werkzeug debugger and hot reloader
    # would be reachable by anyone who can reach the port. Set DASH_DEBUG=1
    # (or "true"/"yes") in the environment for local development.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860)