# app.py — commit 42ed1b1 "Update app.py" by bluenevus (verified), 7.42 kB.
import io
import os
import tempfile
import threading
from urllib.parse import urljoin

import dash
import dash_bootstrap_components as dbc
import numpy as np
import requests
import torch
from bs4 import BeautifulSoup
from dash import dcc, html, Input, Output, State
from dash.exceptions import PreventUpdate
from pydub import AudioSegment
from pytube import YouTube
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
print("Script started")

# Check if CUDA is available and set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Half precision saves GPU memory, but many CPU kernels have no float16
# implementation (generation then fails with "... not implemented for 'Half'"),
# so only use float16 when running on CUDA.
model_dtype = torch.float16 if device == "cuda" else torch.float32

# Load the Whisper speech-recognition model and its processor
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(
    whisper_model_name, torch_dtype=model_dtype
).to(device)

# Load the Qwen instruction-tuned LLM used for speaker separation
qwen_model_name = "Qwen/Qwen2.5-3B-Instruct"
qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name)
qwen_model = AutoModelForCausalLM.from_pretrained(
    qwen_model_name, torch_dtype=model_dtype
).to(device)
def _http_get_bytes(url, timeout):
    """GET *url* and return the response body, raising on HTTP error statuses."""
    response = requests.get(url, timeout=timeout)
    # Without this check an error page (e.g. a 404 HTML body) would be handed
    # to the audio decoder as if it were media.
    response.raise_for_status()
    return response.content


def _download_youtube_audio(url):
    """Download the first audio-only stream of a YouTube video and return its bytes."""
    yt = YouTube(url)
    audio_stream = yt.streams.filter(only_audio=True).first()
    if audio_stream is None:
        raise ValueError("No audio-only stream found for this YouTube video.")
    # pytube's download() takes a *directory* as output_path and returns the
    # path of the file it wrote, so download into a throwaway directory that
    # is removed (with the file) when the block exits.
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = audio_stream.download(output_path=temp_dir)
        with open(file_path, "rb") as audio_file:
            return audio_file.read()


def _download_shared_video(url, timeout):
    """Fetch a share page, locate its first <video src=...>, and download that file."""
    soup = BeautifulSoup(_http_get_bytes(url, timeout), 'html.parser')
    video_tag = soup.find('video')
    if not (video_tag and 'src' in video_tag.attrs):
        raise ValueError("Direct video URL not found in the shareable link.")
    # src may be relative to the share page; resolve it against the page URL.
    video_url = urljoin(url, video_tag['src'])
    print(f"Extracted video URL: {video_url}")
    return _http_get_bytes(video_url, timeout)


def download_audio_from_url(url, timeout=60):
    """Download the media behind *url* and return its raw bytes.

    URLs are dispatched by substring, checked in this order:

    * ``youtube.com`` / ``youtu.be``: the first audio-only stream is
      downloaded with pytube.
    * anything containing ``"share"``: the page is fetched and the ``src`` of
      its first ``<video>`` tag is downloaded.
    * anything else: the URL itself is downloaded.

    The bytes are returned as served (container/codec unchanged); decoding is
    left to the caller.

    Args:
        url: Video URL or share-page URL.
        timeout: Seconds passed as ``timeout`` to each ``requests.get`` call.

    Returns:
        The downloaded bytes.

    Raises:
        ValueError: if no audio stream / ``<video src>`` can be found.
        Any pytube or requests error (including ``requests.HTTPError`` for
        error statuses) is logged and re-raised unchanged.
    """
    try:
        if "youtube.com" in url or "youtu.be" in url:
            print("Processing YouTube URL...")
            audio_bytes = _download_youtube_audio(url)
        elif "share" in url:
            print("Processing shareable link...")
            audio_bytes = _download_shared_video(url, timeout)
        else:
            print(f"Downloading video from URL: {url}")
            audio_bytes = _http_get_bytes(url, timeout)
        print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        print(f"Error in download_audio_from_url: {str(e)}")
        raise
def _load_whisper_samples(audio_file):
    """Decode *audio_file* to mono 16 kHz float32 samples scaled to [-1, 1)."""
    audio = AudioSegment.from_file(audio_file)
    # Force 16-bit PCM so the integer -> float scale factor is always 2**15,
    # whatever sample width the source was decoded at.
    audio = audio.set_channels(1).set_frame_rate(16000).set_sample_width(2)
    samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
    # Whisper's feature extractor expects a float waveform in [-1, 1]; raw
    # int16 values would be read as a massively over-driven signal.
    return samples / 32768.0


def transcribe_audio(audio_file, chunk_length_s=30):
    """Transcribe an audio file with Whisper and return the text.

    The audio is converted to mono 16 kHz and transcribed in consecutive
    windows of ``chunk_length_s`` seconds; the per-window texts are joined
    with single spaces. Window boundaries are fixed, so a word straddling a
    boundary may be split or dropped.

    Args:
        audio_file: Path or file-like object accepted by
            ``pydub.AudioSegment.from_file``.
        chunk_length_s: Window length in seconds, in (0, 30]. The Whisper
            processor pads/truncates every input to 30 s, so longer windows
            would silently lose audio.

    Returns:
        The transcription ("" for audio with no samples).

    Raises:
        ValueError: if ``chunk_length_s`` is outside (0, 30].
        Any decoding/model error is logged and re-raised.
    """
    if not 0 < chunk_length_s <= 30:
        raise ValueError("chunk_length_s must be in the range (0, 30]")
    try:
        print("Loading audio file...")
        samples = _load_whisper_samples(audio_file)
        print("Starting transcription...")
        chunk_size = int(chunk_length_s * 16000)
        pieces = []
        for start in range(0, len(samples), chunk_size):
            chunk = samples[start:start + chunk_size]
            # Cast to the model's dtype: the model may be float16 on CUDA
            # while the processor always produces float32 features.
            input_features = whisper_processor(
                chunk, sampling_rate=16000, return_tensors="pt"
            ).input_features.to(device, dtype=whisper_model.dtype)
            predicted_ids = whisper_model.generate(input_features)
            text = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            pieces.append(text.strip())
        transcription = " ".join(piece for piece in pieces if piece)
        print(f"Transcription complete. Length: {len(transcription)} characters")
        return transcription
    except Exception as e:
        print(f"Error in transcribe_audio: {str(e)}")
        raise
def separate_speakers(transcription, max_new_tokens=1000):
    """Ask the Qwen model to split a transcript into labelled speaker turns.

    The instruction prompt asks for "Speaker 1", "Speaker 2", ... labels with a
    blank line between different speakers. The labelling is the LLM's guess
    from text alone; no audio information is used.

    Args:
        transcription: Plain transcript text.
        max_new_tokens: Upper bound on generated tokens; long transcripts may
            be cut off if this is too small.

    Returns:
        Only the model's generated answer, stripped ("" for empty input).
    """
    if not (transcription and transcription.strip()):
        return ""
    prompt = f"""Analyze the following transcribed text and separate it into different speakers. Identify potential speaker changes based on context, content shifts, or dialogue patterns. Format the output as follows:
1. Label speakers as "Speaker 1", "Speaker 2", etc.
2. Start each speaker's text on a new line beginning with their label.
3. Separate different speakers' contributions with a blank line.
4. If the same speaker continues, do not insert a blank line or repeat the speaker label.
Now, please process the following transcribed text:
{transcription}
"""
    # Qwen2.5-Instruct is a chat model: wrap the request in its chat template
    # (with the assistant-turn prompt appended) so it answers the instruction
    # rather than merely continuing the text.
    chat_text = qwen_tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = qwen_tokenizer(chat_text, return_tensors="pt").to(device)
    outputs = qwen_model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion; decode only the new tokens so the
    # echoed instruction and transcript are not returned as part of the answer.
    prompt_length = inputs["input_ids"].shape[-1]
    return qwen_tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
def transcribe_video(url):
    """Download *url*, transcribe its audio and label the speakers.

    Pipeline: download_audio_from_url -> decode to a temporary WAV file ->
    transcribe_audio -> separate_speakers.

    This function never raises: any exception is logged and returned as a
    string starting with ``"An error occurred: "``, so callers distinguish
    failure by that prefix.

    Returns:
        The speaker-separated transcript, or the error message.
    """
    try:
        print(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")
        # Reserve a temp path and close our handle right away, so the file can
        # be reopened by name on every platform; it is always deleted below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            wav_path = temp_audio.name
        try:
            # pydub's export() returns the file object it opened for writing;
            # close it so the handle does not leak and unlink() can succeed.
            AudioSegment.from_file(io.BytesIO(audio_bytes)).export(wav_path, format="wav").close()
            transcript = transcribe_audio(wav_path)
        finally:
            os.unlink(wav_path)
        print("Separating speakers...")
        return separate_speakers(transcript)
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])


def _build_layout():
    """Assemble the page: a centred title above one card with the URL form.

    The card holds the URL input, the "Transcribe" trigger, the result area
    (wrapped in a spinner while the callback runs) and the download target
    that the transcription callback writes to.
    """
    form_controls = [
        dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
        dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3"),
        dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
        dcc.Download(id="download-transcript"),
    ]
    heading = html.H1("Video Transcription with Speaker Separation", className="text-center mb-4")
    form_card = dbc.Card([dbc.CardBody(form_controls)])
    full_width_column = dbc.Col([heading, form_card], width=12)
    return dbc.Container([dbc.Row([full_width_column])], fluid=True)


app.layout = _build_layout()
@app.callback(
    Output("transcription-output", "children"),
    Output("download-transcript", "data"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, url):
    """Transcribe the entered URL when "Transcribe" is clicked.

    Returns a pair matching the two outputs above:

    * success: a card showing the transcript, plus a ``dcc.Download`` payload
      (``content``/``filename``) that makes the browser save ``transcript.txt``;
    * failure or empty result: the message/empty text itself and ``None``
      (no download).

    Raises:
        PreventUpdate: when the URL box is empty or whitespace, so the page
            is left unchanged.
    """
    if not url or not url.strip():
        raise PreventUpdate
    # Run the job synchronously. The callback must wait for the result anyway,
    # and threading.Thread discards its target's return value, so a
    # started-then-joined thread could never deliver the transcript.
    try:
        transcript = transcribe_video(url.strip())
    except Exception as e:
        transcript = f"An error occurred: {str(e)}"
    if not transcript or transcript.startswith("An error occurred"):
        return transcript, None
    download_data = dict(content=transcript, filename="transcript.txt")
    # NOTE(review): no callback for "btn-download" is visible in this file;
    # the download is triggered by returning download_data here, not by that button.
    return dbc.Card([
        dbc.CardBody([
            html.H5("Transcription Result with Speaker Separation"),
            # Dash/React style dicts use camelCase CSS property names.
            html.Pre(transcript, style={"whiteSpace": "pre-wrap", "wordWrap": "break-word"}),
            dbc.Button("Download Transcript", id="btn-download", color="secondary", className="mt-3")
        ])
    ]), download_data
print("Reached end of script definitions")

if __name__ == '__main__':
    print("Starting the Dash application...")
    # use_reloader=False: with debug=True, Flask's reloader re-runs this whole
    # module in a child process, which would load both models a second time.
    # NOTE(review): debug=True while listening on 0.0.0.0 exposes debug
    # tooling to the network; consider disabling it outside local development.
    app.run(debug=True, use_reloader=False, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")