import io
import logging
import os
import tempfile
import threading

import numpy as np
import librosa
import requests
import torch
from bs4 import BeautifulSoup
from pydub import AudioSegment
from pytube import YouTube
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM

import dash
import dash_bootstrap_components as dbc
from dash import dcc, html, Input, Output, State
from dash.exceptions import PreventUpdate
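
# Pipeline overview: download the audio behind a URL, transcribe it with Whisper
# in overlapping chunks, have a Qwen instruct model split the transcript into
# speaker turns, and serve the whole flow from a small Dash web UI.
# Note: pydub relies on an ffmpeg binary being available on PATH to decode the
# downloaded media.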

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

logger.info("Script started")

# Prefer a GPU when one is available; everything falls back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
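
# Both models are loaded once at import time so every request reuses the same
# weights instead of reloading them per transcription.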

whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name).to(device)

qwen_model_name = "Qwen/Qwen2.5-1.5B-Instruct"
qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, trust_remote_code=True).to(device)


def download_audio_from_url(url):
    """Return the raw media bytes behind a YouTube link, shareable page, or direct URL."""
    try:
        if "youtube.com" in url or "youtu.be" in url:
            logger.info("Processing YouTube URL...")
            yt = YouTube(url)
            audio_stream = yt.streams.filter(only_audio=True).first()
            # pytube expects a directory and a filename rather than a full file path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                temp_path = temp_file.name
            audio_stream.download(output_path=os.path.dirname(temp_path), filename=os.path.basename(temp_path))
            with open(temp_path, "rb") as f:
                audio_bytes = f.read()
            os.unlink(temp_path)
        elif "share" in url:
            logger.info("Processing shareable link...")
            response = requests.get(url)
            soup = BeautifulSoup(response.content, 'html.parser')
            video_tag = soup.find('video')
            if video_tag and 'src' in video_tag.attrs:
                video_url = video_tag['src']
                logger.info(f"Extracted video URL: {video_url}")
            else:
                raise ValueError("Direct video URL not found in the shareable link.")
            response = requests.get(video_url)
            audio_bytes = response.content
        else:
            logger.info(f"Downloading video from URL: {url}")
            response = requests.get(url)
            audio_bytes = response.content

        logger.info(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        logger.error(f"Error in download_audio_from_url: {str(e)}")
        raise
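
# Whisper is trained on 30-second windows, so long recordings are transcribed in
# 30-second chunks with a 5-second overlap to avoid losing words at chunk edges.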


def transcribe_audio(audio_file):
    """Transcribe an audio file with Whisper and return the raw, unseparated transcript."""
    try:
        logger.info("Loading audio file...")
        audio_input, sr = librosa.load(audio_file, sr=16000)
        audio_input = audio_input.astype(np.float32)
        logger.info(f"Audio duration: {len(audio_input) / sr:.2f} seconds")

        chunk_length = 30 * sr
        overlap = 5 * sr
        transcriptions = []

        logger.info("Starting transcription...")
        for i in range(0, len(audio_input), chunk_length - overlap):
            chunk = audio_input[i:i + chunk_length]
            input_features = whisper_processor(chunk, sampling_rate=16000, return_tensors="pt").input_features.to(device)
            predicted_ids = whisper_model.generate(input_features)
            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
            transcriptions.extend(transcription)
            logger.info(f"Processed {i / sr:.2f} to {(i + chunk_length) / sr:.2f} seconds")

        full_transcription = " ".join(transcriptions)
        logger.info(f"Transcription complete. Full transcription length: {len(full_transcription)} characters")

        # Speaker separation is applied by the caller (transcribe_video), which can
        # fall back to this raw transcript if separation fails; doing it here as well
        # would run the Qwen pass twice on the same text.
        return full_transcription
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}")
        raise
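
# "Speaker separation" here is purely text-based: the LLM infers turn boundaries
# from context and phrasing in the transcript, so it is a heuristic rather than
# true acoustic speaker diarization.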


def separate_speakers(transcription):
    """Ask the Qwen model to reformat a flat transcript into labelled speaker turns."""
    logger.info("Starting speaker separation...")
    prompt = f"""Analyze the following transcribed text and separate it into different speakers. Identify potential speaker changes based on context, content shifts, or dialogue patterns. Format the output as follows:

1. Label speakers as "Speaker 1", "Speaker 2", etc.
2. Start each speaker's text on a new line beginning with their label.
3. Separate different speakers' contributions with a blank line.
4. If the same speaker continues, do not insert a blank line or repeat the speaker label.
5. Do not include any additional explanations or metadata.

Now, please process the following transcribed text:

{transcription}
"""

    inputs = qwen_tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = qwen_model.generate(**inputs, max_new_tokens=4000)

    # Decode only the newly generated tokens; decoding the full sequence would echo
    # the prompt (including the original transcript) back into the result.
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    processed_text = qwen_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    logger.info("Speaker separation complete.")
    return processed_text


def transcribe_video(url):
    """End-to-end pipeline: download audio, transcribe it, then try to separate speakers."""
    try:
        logger.info(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        logger.info(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")

        # Convert whatever container was downloaded into a WAV file that librosa can load.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            AudioSegment.from_file(io.BytesIO(audio_bytes)).export(temp_audio.name, format="wav")
            transcript = transcribe_audio(temp_audio.name)

        os.unlink(temp_audio.name)

        if len(transcript) < 10:
            raise ValueError("Transcription too short, possibly failed")

        logger.info("Separating speakers...")
        try:
            diarized_transcript = separate_speakers(transcript)
            logger.info(f"Speaker separation complete. Result length: {len(diarized_transcript)} characters")
            if len(diarized_transcript) < 10:
                logger.warning("Speaker separation result too short, using original transcript")
                return transcript
            return diarized_transcript
        except Exception as e:
            logger.error(f"Error during speaker separation: {str(e)}")
            logger.info("Returning original transcript without speaker separation")
            return transcript
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        logger.error(error_message)
        return error_message
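
# Dash UI: a single text input for the video URL, a Transcribe button, a spinner
# around the transcription output, and an initially hidden download button.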

app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

app.layout = dbc.Container([
    dbc.Row([
        dbc.Col([
            html.H1("Video Transcription with Speaker Separation", className="text-center mb-4"),
            html.Div("If you can see this, the app is working!", className="text-center mb-4"),
            dbc.Card([
                dbc.CardBody([
                    dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
                    dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3"),
                    dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
                    html.Div([
                        dbc.Button("Download Transcript", id="download-button", color="secondary", className="mt-3", style={'display': 'none'}),
                        dcc.Download(id="download-transcript")
                    ])
                ])
            ])
        ], width=12)
    ])
], fluid=True)


@app.callback(
    Output("transcription-output", "children"),
    Output("download-button", "style"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, url):
    if not url:
        raise PreventUpdate

    transcript = transcribe_video(url)

    if transcript and not transcript.startswith("An error occurred"):
        return dbc.Card([
            dbc.CardBody([
                html.H5("Transcription Result with Speaker Separation"),
                html.Pre(transcript, style={"white-space": "pre-wrap", "word-wrap": "break-word"})
            ])
        ]), {'display': 'block'}
    else:
        return transcript, {'display': 'none'}
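
# Dash passes the component tree to callbacks as nested dicts, so the transcript
# text is recovered by walking the Card -> CardBody -> Pre structure built above.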


@app.callback(
    Output("download-transcript", "data"),
    Input("download-button", "n_clicks"),
    State("transcription-output", "children"),
    prevent_initial_call=True
)
def download_transcript(n_clicks, transcription_output):
    if not transcription_output:
        raise PreventUpdate

    transcript = transcription_output['props']['children'][0]['props']['children'][1]['props']['children']
    return dict(content=transcript, filename="transcript.txt")
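
# Port 7860 is a common default for hosted demos (e.g. Hugging Face Spaces);
# adjust the host/port below if running elsewhere.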

if __name__ == '__main__':
    logger.info("Starting the Dash application...")
    app.run(debug=True, host='0.0.0.0', port=7860)
    logger.info("Dash application has finished running.")