import io
import os
import tempfile
import threading

import torch
import requests
from bs4 import BeautifulSoup
from pydub import AudioSegment
from pytube import YouTube
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM

import dash
from dash import dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate

print("Script started")

# Check if CUDA is available and set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Whisper model and processor
whisper_model_name = "openai/whisper-small.en"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name).to(device)

# Load the Qwen model and tokenizer
qwen_model_name = "Qwen/Qwen2.5-3B-Instruct"
qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
qwen_model = AutoModelForCausalLM.from_pretrained(
    qwen_model_name, trust_remote_code=True, torch_dtype=torch.float16
).to(device)


def download_audio_from_url(url):
    try:
        if "youtube.com" in url or "youtu.be" in url:
            print("Processing YouTube URL...")
            yt = YouTube(url)
            audio_stream = yt.streams.filter(only_audio=True).first()
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                temp_path = temp_file.name
            # pytube treats output_path as a directory, so split the temp path
            # into directory and filename instead of passing the full path.
            audio_stream.download(output_path=os.path.dirname(temp_path),
                                  filename=os.path.basename(temp_path))
            with open(temp_path, "rb") as f:
                audio_bytes = f.read()
            os.unlink(temp_path)
        elif "share" in url:
            print("Processing shareable link...")
            response = requests.get(url)
            soup = BeautifulSoup(response.content, 'html.parser')
            video_tag = soup.find('video')
            if video_tag and 'src' in video_tag.attrs:
                video_url = video_tag['src']
                print(f"Extracted video URL: {video_url}")
            else:
                raise ValueError("Direct video URL not found in the shareable link.")
            response = requests.get(video_url)
            audio_bytes = response.content
        else:
            print(f"Downloading video from URL: {url}")
            response = requests.get(url)
            audio_bytes = response.content
        print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        print(f"Error in download_audio_from_url: {str(e)}")
        raise


def transcribe_audio(audio_file):
    try:
        print("Loading audio file...")
        audio = AudioSegment.from_file(audio_file)
        audio = audio.set_channels(1).set_frame_rate(16000)
        # Normalize 16-bit samples to [-1, 1]; the Whisper feature extractor
        # expects float audio at this scale, not raw int16 values.
        audio_array = torch.tensor(audio.get_array_of_samples()).float() / 32768.0
        print(f"Audio duration: {len(audio) / 1000:.2f} seconds")

        print("Starting transcription...")
        # Note: Whisper processes at most 30 seconds per call, so longer audio
        # is truncated here. A chunked variant is sketched below.
        input_features = whisper_processor(
            audio_array, sampling_rate=16000, return_tensors="pt"
        ).input_features.to(device)
        predicted_ids = whisper_model.generate(input_features)
        transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)

        print(f"Transcription complete. Length: {len(transcription[0])} characters")
        return transcription[0]
    except Exception as e:
        print(f"Error in transcribe_audio: {str(e)}")
        raise
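
# Optional: a minimal sketch of chunked transcription for recordings longer
# than Whisper's 30-second window, assuming the same 16 kHz mono pipeline as
# transcribe_audio above. Chunk boundaries are fixed-length, so words can be
# cut at the seams; overlapping chunks or silence-based splitting would be
# more robust. Not wired into the app, included for illustration only.
def transcribe_audio_chunked(audio_file, chunk_seconds=30):
    audio = AudioSegment.from_file(audio_file).set_channels(1).set_frame_rate(16000)
    chunk_ms = chunk_seconds * 1000  # pydub slices in milliseconds
    pieces = []
    for start in range(0, len(audio), chunk_ms):
        chunk = audio[start:start + chunk_ms]
        samples = torch.tensor(chunk.get_array_of_samples()).float() / 32768.0
        features = whisper_processor(
            samples, sampling_rate=16000, return_tensors="pt"
        ).input_features.to(device)
        ids = whisper_model.generate(features)
        pieces.append(whisper_processor.batch_decode(ids, skip_special_tokens=True)[0])
    return " ".join(pieces)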
def separate_speakers(transcription):
    print("Starting speaker separation...")
    prompt = f"""Analyze the following transcribed text and separate it into different speakers. Identify potential speaker changes based on context, content shifts, or dialogue patterns. Format the output as follows:
1. Label speakers as "Speaker 1", "Speaker 2", etc. Use the dialogue context to infer which speaker is saying each line, since speaker labels are not present in the text.
2. Start each speaker's text on a new line beginning with their label.
3. Separate different speakers' contributions with a blank line.
4. If the same speaker continues, do not insert a blank line or repeat the speaker label.

Now, please process the following transcribed text:

{transcription}
"""
    inputs = qwen_tokenizer(prompt, return_tensors="pt").to(device)
    # Do not cast the inputs to float16: input_ids and attention_mask are
    # integer tensors, and converting them breaks generation. Only the model
    # weights are loaded in float16.
    with torch.no_grad():
        outputs = qwen_model.generate(**inputs, max_new_tokens=4000)

    # generate() returns the prompt tokens followed by the completion, so
    # slice off the prompt by length; string-splitting the decoded output
    # would leave the echoed transcription in the result.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    processed_text = qwen_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    print("Speaker separation complete.")
    return processed_text


def transcribe_video(url):
    try:
        print(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            AudioSegment.from_file(io.BytesIO(audio_bytes)).export(temp_audio.name, format="wav")
        transcript = transcribe_audio(temp_audio.name)
        os.unlink(temp_audio.name)

        print("Separating speakers...")
        separated_transcript = separate_speakers(transcript)
        return separated_transcript
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message
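
# Optional: Qwen2.5-3B-Instruct is a chat-tuned model, so wrapping the request
# in its chat template via transformers' apply_chat_template usually steers it
# better than a raw completion prompt. A minimal sketch of the same
# speaker-separation step; the system/user split below is one reasonable
# choice, not the app's original prompt. Not wired into the pipeline.
def separate_speakers_chat(transcription):
    messages = [
        {"role": "system", "content": "You separate transcribed text into labeled speakers."},
        {"role": "user", "content": "Label speakers as 'Speaker 1', 'Speaker 2', etc., "
                                    "starting each speaker's text on a new line with a blank "
                                    "line between different speakers:\n\n" + transcription},
    ]
    text = qwen_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = qwen_tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = qwen_model.generate(**inputs, max_new_tokens=4000)
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return qwen_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()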
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

app.layout = dbc.Container([
    dbc.Row([
        dbc.Col([
            html.H1("Video Transcription with Speaker Separation", className="text-center mb-4"),
            dbc.Card([
                dbc.CardBody([
                    dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
                    dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3"),
                    dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
                    html.Div([
                        dbc.Button("Download Transcript", id="download-button", color="secondary",
                                   className="mt-3", style={'display': 'none'}),
                        dcc.Download(id="download-transcript")
                    ])
                ])
            ])
        ], width=12)
    ])
], fluid=True)


@app.callback(
    Output("transcription-output", "children"),
    Output("download-button", "style"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, url):
    if not url:
        raise PreventUpdate

    # Thread objects have no .result attribute, so collect the return value
    # in a shared container the worker writes into.
    result_container = {}

    def transcribe():
        try:
            result_container['transcript'] = transcribe_video(url)
        except Exception as e:
            import traceback
            result_container['transcript'] = (
                f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            )

    # Run transcription in a separate thread
    thread = threading.Thread(target=transcribe)
    thread.start()
    thread.join(timeout=600)  # 10 minutes timeout

    if thread.is_alive():
        return "Transcription timed out after 10 minutes", {'display': 'none'}

    transcript = result_container.get('transcript', "Transcription failed")
    if transcript and not transcript.startswith("An error occurred"):
        return dbc.Card([
            dbc.CardBody([
                html.H5("Transcription Result with Speaker Separation"),
                html.Pre(transcript, style={"white-space": "pre-wrap", "word-wrap": "break-word"})
            ])
        ]), {'display': 'block'}
    else:
        return transcript, {'display': 'none'}


@app.callback(
    Output("download-transcript", "data"),
    Input("download-button", "n_clicks"),
    State("transcription-output", "children"),
    prevent_initial_call=True
)
def download_transcript(n_clicks, transcription_output):
    if not transcription_output:
        raise PreventUpdate
    # Dig the transcript string back out of the serialized component tree:
    # Card -> CardBody -> [H5, Pre] -> Pre.children. This is fragile and only
    # works because update_transcription always builds that exact structure.
    transcript = transcription_output['props']['children'][0]['props']['children'][1]['props']['children']
    return dict(content=transcript, filename="transcript.txt")


if __name__ == '__main__':
    print("Starting the Dash application...")
    app.run(debug=True, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")
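
# Usage note (a deployment assumption, not part of the original app): run the
# script directly and open http://localhost:7860. Two caveats worth knowing:
# debug=True starts Werkzeug's reloader, which executes this module twice and
# therefore loads both models into memory twice (passing use_reloader=False to
# app.run avoids the double load); and update_transcription blocks a server
# worker for up to 10 minutes per request, so concurrent users will queue.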