import io
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import requests
from bs4 import BeautifulSoup
import tempfile
import os
from spellchecker import SpellChecker
from pydub import AudioSegment
import librosa
import numpy as np
from pyannote.audio import Pipeline
import dash
from dash import dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate
import threading
from pytube import YouTube
print("Script started")
# Check if CUDA is available and set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the Whisper model and processor
model_name = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
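# whisper-small trades accuracy for speed; larger checkpoints such as openai/whisper-medium can be swapped in if quality matters more than latency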
spell = SpellChecker()
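# Fetch raw media bytes from a YouTube link, a shareable page containing a <video> tag, or a direct file URL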
def download_audio_from_url(url):
try:
if "youtube.com" in url or "youtu.be" in url:
print("Processing YouTube URL...")
yt = YouTube(url)
audio_stream = yt.streams.filter(only_audio=True).first()
            # pytube's download() expects a directory for output_path and returns the saved file's path
            with tempfile.TemporaryDirectory() as temp_dir:
                downloaded_path = audio_stream.download(output_path=temp_dir)
                with open(downloaded_path, "rb") as f:
                    audio_bytes = f.read()
elif "share" in url:
print("Processing shareable link...")
            response = requests.get(url)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, 'html.parser')
video_tag = soup.find('video')
if video_tag and 'src' in video_tag.attrs:
video_url = video_tag['src']
print(f"Extracted video URL: {video_url}")
else:
raise ValueError("Direct video URL not found in the shareable link.")
            response = requests.get(video_url)
            response.raise_for_status()
            audio_bytes = response.content
else:
print(f"Downloading video from URL: {url}")
            response = requests.get(url)
            response.raise_for_status()
            audio_bytes = response.content
print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
return audio_bytes
except Exception as e:
print(f"Error in download_audio_from_url: {str(e)}")
raise
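# Word-by-word spell check; keeps the original word whenever pyspellchecker has no suggestion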
def correct_spelling(text):
words = text.split()
corrected_words = [spell.correction(word) or word for word in words]
return ' '.join(corrected_words)
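# Interleave pyannote speaker labels with the Whisper transcript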
def format_transcript_with_speakers(transcript, diarization):
    # Whisper returns plain text without word timestamps here, so map each diarization
    # segment's start/end times to character offsets proportionally. This is a rough
    # alignment, not a word-level one.
    segments = list(diarization.itertracks(yield_label=True))
    if not segments:
        return transcript
    total_duration = max(segment.end for segment, _, _ in segments)
    formatted_transcript = []
    current_speaker = None
    for segment, _, speaker in segments:
        start = int(len(transcript) * segment.start / total_duration)
        end = int(len(transcript) * segment.end / total_duration)
        if speaker != current_speaker:
            if current_speaker is not None:
                formatted_transcript.append("\n")  # Add a blank line between speakers
            formatted_transcript.append(f"Speaker {speaker}:\n")
            current_speaker = speaker
        segment_text = transcript[start:end].strip()
        if segment_text:
            formatted_transcript.append(f"{segment_text}\n")
    return "".join(formatted_transcript)
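# Transcribe a 16 kHz WAV file with Whisper in overlapping 30-second chunks and label speakers via diarization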
def transcribe_audio(audio_file, pipeline):
try:
if pipeline is None:
raise ValueError("Speaker diarization pipeline is not initialized")
print("Loading audio file...")
audio_input, sr = librosa.load(audio_file, sr=16000)
audio_input = audio_input.astype(np.float32)
print(f"Audio duration: {len(audio_input) / sr:.2f} seconds")
# Apply speaker diarization
print("Applying speaker diarization...")
diarization = pipeline(audio_file)
print("Speaker diarization complete.")
chunk_length = 30 * sr
overlap = 5 * sr
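        # 30 s matches Whisper's maximum input window; the 5 s overlap is joined naively
        # below, so words near chunk boundaries may appear twice in the output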
transcriptions = []
print("Starting transcription...")
for i in range(0, len(audio_input), chunk_length - overlap):
chunk = audio_input[i:i+chunk_length]
input_features = processor(chunk, sampling_rate=16000, return_tensors="pt").input_features.to(device)
predicted_ids = model.generate(input_features)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
transcriptions.extend(transcription)
print(f"Processed {i / sr:.2f} to {(i + chunk_length) / sr:.2f} seconds")
full_transcription = " ".join(transcriptions)
print(f"Transcription complete. Full transcription length: {len(full_transcription)} characters")
print("Applying formatting with speaker diarization...")
formatted_transcription = format_transcript_with_speakers(full_transcription, diarization)
return formatted_transcription
except Exception as e:
print(f"Error in transcribe_audio: {str(e)}")
raise
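# End-to-end pipeline for one URL: download, convert to WAV, transcribe with speaker labels, then spell-correct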
def transcribe_video(url, pipeline):
try:
print(f"Attempting to download audio from URL: {url}")
audio_bytes = download_audio_from_url(url)
print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")
# Convert audio bytes to AudioSegment
audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
print(f"Audio duration: {len(audio) / 1000} seconds")
# Save as WAV file
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
audio.export(temp_audio.name, format="wav")
temp_audio_path = temp_audio.name
print("Starting audio transcription...")
transcript = transcribe_audio(temp_audio_path, pipeline)
print(f"Transcription completed. Transcript length: {len(transcript)} characters")
# Clean up the temporary file
os.unlink(temp_audio_path)
# Apply spelling correction
transcript = correct_spelling(transcript)
return transcript
except Exception as e:
error_message = f"An error occurred: {str(e)}"
print(error_message)
return error_message
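# Dash UI: a URL input, a transcribe button, a spinner-wrapped output area, and a download target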
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
app.layout = dbc.Container([
dbc.Row([
dbc.Col([
html.H1("Video Transcription", className="text-center mb-4"),
dbc.Card([
dbc.CardBody([
dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3"),
dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
dcc.Download(id="download-transcript")
])
])
], width=12)
])
], fluid=True)
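# Clicking "Transcribe" runs the full pipeline and returns both the rendered transcript and the downloadable file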
@app.callback(
Output("transcription-output", "children"),
Output("download-transcript", "data"),
Input("transcribe-button", "n_clicks"),
State("video-url", "value"),
prevent_initial_call=True
)
def update_transcription(n_clicks, url):
if not url:
raise PreventUpdate
    # Run the transcription in a worker thread and capture its output in a shared dict,
    # since threading.Thread does not expose the target's return value; joining right
    # away keeps the callback synchronous.
    result = {}
    def transcribe():
        try:
            # Initialize the speaker diarization pipeline without a token
            pipeline = Pipeline.from_pretrained("collinbarnwell/pyannote-speaker-diarization-31")
            if pipeline is None:
                raise ValueError("Failed to initialize the speaker diarization pipeline")
            print("Speaker diarization pipeline initialized successfully")
            result["transcript"] = transcribe_video(url, pipeline)
        except Exception as e:
            result["transcript"] = f"An error occurred: {str(e)}"
    thread = threading.Thread(target=transcribe)
    thread.start()
    thread.join()
    transcript = result.get("transcript", "Transcription failed")
if transcript and not transcript.startswith("An error occurred"):
download_data = dict(content=transcript, filename="transcript.txt")
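        # Returning download_data triggers dcc.Download immediately; the "Download Transcript"
        # button below is not wired to its own callback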
return dbc.Card([
dbc.CardBody([
html.H5("Transcription Result"),
html.Pre(transcript, style={"white-space": "pre-wrap", "word-wrap": "break-word"}),
dbc.Button("Download Transcript", id="btn-download", color="secondary", className="mt-3")
])
]), download_data
else:
return transcript, None
if __name__ == '__main__':
print("Starting the Dash application...")
app.run(debug=True, host='0.0.0.0', port=7860)
print("Dash application has finished running.") |