bluenevus committed
Commit 26cf8bb · verified · 1 Parent(s): 5adda7d

Update app.py

Files changed (1)
  1. app.py +36 -16
app.py CHANGED
@@ -22,12 +22,12 @@ print(f"Using device: {device}")
  # Load the Whisper model and processor
  whisper_model_name = "openai/whisper-small"
  whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
- whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name, torch_dtype=torch.float16).to(device)
+ whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name).to(device)

  # Load the Qwen model and tokenizer
  qwen_model_name = "Qwen/Qwen2.5-3B-Instruct"
- qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name)
- qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, torch_dtype=torch.float16).to(device)
+ qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
+ qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, trust_remote_code=True).to(device)

  def download_audio_from_url(url):
      try:
@@ -69,9 +69,10 @@ def transcribe_audio(audio_file):
          audio = audio.set_channels(1).set_frame_rate(16000)
          audio_array = torch.tensor(audio.get_array_of_samples()).float()

+         print(f"Audio duration: {len(audio) / 1000:.2f} seconds")
          print("Starting transcription...")
-         input_features = whisper_processor(audio_array, sampling_rate=16000, return_tensors="pt").input_features.to(device).to(torch.float16)
-         predicted_ids = whisper_model.generate(input_features, language='en', task='translate')
+         input_features = whisper_processor(audio_array, sampling_rate=16000, return_tensors="pt").input_features.to(device)
+         predicted_ids = whisper_model.generate(input_features)
          transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)

          print(f"Transcription complete. Length: {len(transcription[0])} characters")
@@ -81,6 +82,7 @@ def transcribe_audio(audio_file):
          raise

  def separate_speakers(transcription):
+     print("Starting speaker separation...")
      prompt = f"""Analyze the following transcribed text and separate it into different speakers. Identify potential speaker changes based on context, content shifts, or dialogue patterns. Format the output as follows:

  1. Label speakers as "Speaker 1", "Speaker 2", etc.
@@ -94,12 +96,14 @@ Now, please process the following transcribed text:
  """

      inputs = qwen_tokenizer(prompt, return_tensors="pt").to(device)
-     outputs = qwen_model.generate(**inputs, max_new_tokens=1000)
+     with torch.no_grad():
+         outputs = qwen_model.generate(**inputs, max_new_tokens=4000)
      result = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)

      # Extract the processed text (remove the instruction part)
      processed_text = result.split("Now, please process the following transcribed text:")[-1].strip()

+     print("Speaker separation complete.")
      return processed_text

  def transcribe_video(url):
@@ -134,7 +138,10 @@ app.layout = dbc.Container([
                  dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
                  dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3"),
                  dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
-                 dcc.Download(id="download-transcript")
+                 html.Div([
+                     dbc.Button("Download Transcript", id="download-button", color="secondary", className="mt-3", style={'display': 'none'}),
+                     dcc.Download(id="download-transcript")
+                 ])
              ])
          ])
      ], width=12)
@@ -143,7 +150,7 @@ app.layout = dbc.Container([

  @app.callback(
      Output("transcription-output", "children"),
-     Output("download-transcript", "data"),
+     Output("download-button", "style"),
      Input("transcribe-button", "n_clicks"),
      State("video-url", "value"),
      prevent_initial_call=True
@@ -157,28 +164,41 @@ def update_transcription(n_clicks, url):
              transcript = transcribe_video(url)
              return transcript
          except Exception as e:
-             return f"An error occurred: {str(e)}"
+             import traceback
+             return f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"

      # Run transcription in a separate thread
      thread = threading.Thread(target=transcribe)
      thread.start()
-     thread.join()
+     thread.join(timeout=600)  # 10 minutes timeout
+
+     if thread.is_alive():
+         return "Transcription timed out after 10 minutes", {'display': 'none'}

      transcript = thread.result if hasattr(thread, 'result') else "Transcription failed"

      if transcript and not transcript.startswith("An error occurred"):
-         download_data = dict(content=transcript, filename="transcript.txt")
          return dbc.Card([
              dbc.CardBody([
                  html.H5("Transcription Result with Speaker Separation"),
-                 html.Pre(transcript, style={"white-space": "pre-wrap", "word-wrap": "break-word"}),
-                 dbc.Button("Download Transcript", id="btn-download", color="secondary", className="mt-3")
+                 html.Pre(transcript, style={"white-space": "pre-wrap", "word-wrap": "break-word"})
              ])
-         ]), download_data
+         ]), {'display': 'block'}
      else:
-         return transcript, None
+         return transcript, {'display': 'none'}

- print("Reached end of script definitions")
+ @app.callback(
+     Output("download-transcript", "data"),
+     Input("download-button", "n_clicks"),
+     State("transcription-output", "children"),
+     prevent_initial_call=True
+ )
+ def download_transcript(n_clicks, transcription_output):
+     if not transcription_output:
+         raise PreventUpdate
+
+     transcript = transcription_output['props']['children'][0]['props']['children'][1]['props']['children']
+     return dict(content=transcript, filename="transcript.txt")

  if __name__ == '__main__':
      print("Starting the Dash application...")
 