File size: 9,151 Bytes
1100e65 249a3c0 fce37ea b09f327 53bdf99 b09f327 af532e7 53bdf99 8af57a0 5ee7955 a62f407 5ee7955 6575bf4 077c90e 170241f 249a3c0 a18a113 249a3c0 261d49a fce37ea 26cf8bb fce37ea 1df2592 26cf8bb 54c226c 17ca647 9a1f744 ffd5e97 a80c887 9a1f744 ffd5e97 a80c887 ffd5e97 a80c887 9a1f744 ffd5e97 9a1f744 b09f327 53bdf99 fce37ea 5ee7955 53bdf99 26cf8bb 53bdf99 26cf8bb 53bdf99 df42ab3 53bdf99 836768f 81f702f 836768f 5ee7955 53bdf99 26cf8bb 53bdf99 da7b836 53bdf99 836768f 53bdf99 fce37ea 26cf8bb 53bdf99 26cf8bb 53bdf99 26cf8bb 81f702f 26cf8bb 81f702f 6575bf4 5ee7955 31b9df5 5ee7955 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
import io
import logging
import os
import tempfile
import threading
from urllib.parse import urljoin

import dash
import dash_bootstrap_components as dbc
import librosa
import numpy as np
import requests
import torch
from bs4 import BeautifulSoup
from dash import dcc, html, Input, Output, State
from dash.exceptions import PreventUpdate
from pydub import AudioSegment
from pytube import YouTube
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
# Set up logging
# basicConfig configures the root logger at DEBUG level; `logger` is the
# module-level logger used by the functions below.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
print("Script started")
# Check if CUDA is available and set the device
# `device` is reused below for both models and for the input tensors passed to them.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the Whisper model and processor
# NOTE: this runs at import time, so importing the module loads the model weights.
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name).to(device)
# Load the Qwen model and tokenizer
# The Qwen model is used by separate_speakers() to label speakers in the transcript.
qwen_model_name = "Qwen/Qwen2.5-1.5B-Instruct"
qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, trust_remote_code=True).to(device)
def download_audio_from_url(url):
    """Fetch the raw media bytes behind *url*.

    Three kinds of URL are handled:

    * YouTube links (``youtube.com`` / ``youtu.be``): the first audio-only
      stream is downloaded with pytube.
    * Links containing ``share``: the page is fetched and the ``src`` of its
      first ``<video>`` tag is downloaded.
    * Anything else: the URL itself is downloaded.

    Args:
        url: Address of the video or audio resource.

    Returns:
        The downloaded media content as ``bytes``.

    Raises:
        ValueError: If a YouTube video exposes no audio-only stream, or a
            shareable page contains no ``<video src=...>``.
        Exception: Any other download failure is printed and re-raised.
    """
    # requests has no default timeout; without one a stalled server would
    # block the caller forever.
    request_timeout = 60
    try:
        if "youtube.com" in url or "youtu.be" in url:
            print("Processing YouTube URL...")
            yt = YouTube(url)
            audio_stream = yt.streams.filter(only_audio=True).first()
            if audio_stream is None:
                raise ValueError("No audio-only stream found for the YouTube URL.")
            # pytube treats output_path as a *directory*, so download into a
            # scratch directory and read back the file path it reports.
            with tempfile.TemporaryDirectory() as temp_dir:
                downloaded_path = audio_stream.download(output_path=temp_dir)
                with open(downloaded_path, "rb") as media_file:
                    audio_bytes = media_file.read()
        elif "share" in url:
            print("Processing shareable link...")
            response = requests.get(url, timeout=request_timeout)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, 'html.parser')
            video_tag = soup.find('video')
            if video_tag and 'src' in video_tag.attrs:
                # The src attribute may be relative to the page; an absolute
                # src is returned unchanged by urljoin.
                video_url = urljoin(url, video_tag['src'])
                print(f"Extracted video URL: {video_url}")
            else:
                raise ValueError("Direct video URL not found in the shareable link.")
            response = requests.get(video_url, timeout=request_timeout)
            # Fail on HTTP errors instead of returning an error page as "audio".
            response.raise_for_status()
            audio_bytes = response.content
        else:
            print(f"Downloading video from URL: {url}")
            response = requests.get(url, timeout=request_timeout)
            response.raise_for_status()
            audio_bytes = response.content
        print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        print(f"Error in download_audio_from_url: {str(e)}")
        raise
def transcribe_audio(audio_file):
    """Transcribe an audio file with Whisper and label its speakers.

    The audio is loaded at 16 kHz and processed in 30-second windows that
    overlap by 5 seconds. The per-window transcriptions are joined with
    spaces and the joined text is passed to ``separate_speakers``.

    Args:
        audio_file: Path (or file-like object accepted by ``librosa.load``)
            of the audio to transcribe.

    Returns:
        The speaker-separated transcript returned by ``separate_speakers``.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        logger.info("Loading audio file...")
        audio_input, sr = librosa.load(audio_file, sr=16000)
        audio_input = audio_input.astype(np.float32)
        total_samples = len(audio_input)
        logger.info(f"Audio duration: {total_samples / sr:.2f} seconds")
        chunk_length = 30 * sr
        overlap = 5 * sr
        transcriptions = []
        logger.info("Starting transcription...")
        for start in range(0, total_samples, chunk_length - overlap):
            end = min(start + chunk_length, total_samples)
            chunk = audio_input[start:end]
            input_features = whisper_processor(chunk, sampling_rate=sr, return_tensors="pt").input_features.to(device)
            predicted_ids = whisper_model.generate(input_features)
            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
            transcriptions.extend(transcription)
            logger.info(f"Processed {start / sr:.2f} to {end / sr:.2f} seconds")
            if end >= total_samples:
                # This window already reached the end of the audio; the next
                # one would start inside it and only repeat its tail.
                break
        full_transcription = " ".join(transcriptions)
        logger.info(f"Transcription complete. Full transcription length: {len(full_transcription)} characters")
        logger.info("Applying speaker separation using Qwen...")
        separated_transcript = separate_speakers(full_transcription)
        return separated_transcript
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}")
        raise
def separate_speakers(transcription):
    """Ask the Qwen model to split a transcript into labelled speakers.

    Args:
        transcription: Plain transcribed text.

    Returns:
        The text generated by the model in response to the formatting
        instructions, stripped of surrounding whitespace. The prompt and the
        input transcription are not included in the returned text.
    """
    print("Starting speaker separation...")
    prompt = f"""Analyze the following transcribed text and separate it into different speakers. Identify potential speaker changes based on context, content shifts, or dialogue patterns. Format the output as follows:
1. Label speakers as "Speaker 1", "Speaker 2", etc.
2. Start each speaker's text on a new line beginning with their label.
3. Separate different speakers' contributions with a blank line.
4. If the same speaker continues, do not insert a blank line or repeat the speaker label.
Now, please process the following transcribed text:
{transcription}
"""
    inputs = qwen_tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = qwen_model.generate(**inputs, max_new_tokens=4000)
    # For a causal LM the generated sequence starts with the prompt tokens.
    # Decode only what follows them; splitting the decoded text on the
    # instruction marker would still leave the raw transcription in front of
    # the model's answer.
    prompt_token_count = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][prompt_token_count:]
    processed_text = qwen_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    print("Speaker separation complete.")
    return processed_text
def transcribe_video(url):
    """Download the media at *url* and return its speaker-separated transcript.

    The downloaded bytes are converted to a temporary WAV file with pydub,
    transcribed with ``transcribe_audio`` (which also performs the speaker
    separation), and the temporary file is removed afterwards.

    Args:
        url: Address of the video or audio resource.

    Returns:
        The speaker-separated transcript on success. On any failure, a string
        starting with ``"An error occurred: "`` is returned instead of raising.
    """
    try:
        print(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")
        # Only reserve a unique path here; the handle is closed before pydub
        # writes to it so the export does not compete with an open file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_audio_path = temp_audio.name
        try:
            AudioSegment.from_file(io.BytesIO(audio_bytes)).export(temp_audio_path, format="wav")
            transcript = transcribe_audio(temp_audio_path)
        finally:
            # Remove the temp file even when conversion or transcription fails.
            os.unlink(temp_audio_path)
        if len(transcript) < 10:
            raise ValueError("Transcription too short, possibly failed")
        # transcribe_audio() has already run separate_speakers() on the text;
        # running it again would feed already-separated output back to the LLM.
        return transcript
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])


def _build_layout():
    """Assemble the page: title, URL form, transcription area and download controls."""
    page_title = html.H1("Video Transcription with Speaker Separation", className="text-center mb-4")
    # Debug element
    debug_banner = html.Div("If you can see this, the app is working!", className="text-center mb-4")

    url_field = dbc.Input(id="video-url", type="text", placeholder="Enter video URL")
    transcribe_btn = dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3")
    # The spinner wraps the output container that the transcription callback fills.
    result_area = dbc.Spinner(html.Div(id="transcription-output", className="mt-3"))
    # The download button starts hidden; its style is an output of the transcription callback.
    download_controls = html.Div([
        dbc.Button("Download Transcript", id="download-button", color="secondary", className="mt-3", style={'display': 'none'}),
        dcc.Download(id="download-transcript"),
    ])

    form_card = dbc.Card([
        dbc.CardBody([url_field, transcribe_btn, result_area, download_controls]),
    ])
    main_column = dbc.Col([page_title, debug_banner, form_card], width=12)
    return dbc.Container([dbc.Row([main_column])], fluid=True)


app.layout = _build_layout()
@app.callback(
    Output("transcription-output", "children"),
    Output("download-button", "style"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, url):
    """Transcribe the entered URL and show the result.

    Args:
        n_clicks: Click count of the "Transcribe" button (only used as trigger).
        url: Current value of the URL input.

    Returns:
        A ``(children, style)`` tuple: the result card and a visible download
        button on success, or a message string and a hidden download button on
        failure or timeout.

    Raises:
        PreventUpdate: If no URL was entered.
    """
    if not url:
        raise PreventUpdate

    # threading.Thread discards the target's return value, so the worker
    # writes its result into this holder instead.
    outcome = {}

    def transcribe():
        try:
            outcome["transcript"] = transcribe_video(url)
        except Exception as e:
            logger.exception("Error in transcription:")
            outcome["transcript"] = f"An error occurred: {str(e)}"

    # Run transcription in a separate thread; daemon so that a job abandoned
    # after the timeout cannot keep the process alive at shutdown.
    thread = threading.Thread(target=transcribe, daemon=True)
    thread.start()
    thread.join(timeout=600)  # 10 minutes timeout
    if thread.is_alive():
        return "Transcription timed out after 10 minutes", {'display': 'none'}

    transcript = outcome.get("transcript", "Transcription failed")
    if transcript and not transcript.startswith("An error occurred"):
        return dbc.Card([
            dbc.CardBody([
                html.H5("Transcription Result with Speaker Separation"),
                # Dash style dictionaries use camelCased CSS property names.
                html.Pre(transcript, style={"whiteSpace": "pre-wrap", "wordWrap": "break-word"})
            ])
        ]), {'display': 'block'}
    return transcript, {'display': 'none'}
@app.callback(
    Output("download-transcript", "data"),
    Input("download-button", "n_clicks"),
    State("transcription-output", "children"),
    prevent_initial_call=True
)
def download_transcript(n_clicks, transcription_output):
    """Offer the displayed transcript as ``transcript.txt``.

    Args:
        n_clicks: Click count of the download button (only used as trigger).
        transcription_output: Current children of the transcription output
            container, as serialized by Dash.

    Returns:
        A dict with ``content`` and ``filename`` for ``dcc.Download``.

    Raises:
        PreventUpdate: If there is no output, or the output is not the result
            card (for example a plain error message).
    """
    if not transcription_output:
        raise PreventUpdate
    try:
        # Card -> CardBody (first child) -> Pre (second child) -> transcript text
        transcript = transcription_output['props']['children'][0]['props']['children'][1]['props']['children']
    except (TypeError, KeyError, IndexError):
        # The output holds a message string or an unexpected component tree,
        # so there is no transcript to download.
        raise PreventUpdate from None
    return dict(content=transcript, filename="transcript.txt")
if __name__ == '__main__':
    logger.info("Starting the Dash application...")
    # The server listens on all interfaces, so the debugger and hot reloader
    # stay off unless explicitly requested. Debug mode would expose the dev
    # tools to the network, and its reloader starts a second process that
    # loads both models again.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860)
    logger.info("Dash application has finished running.")