|
import io |
|
import os |
|
import tempfile |
|
import threading |
|
import base64 |
|
import logging |
|
from urllib.parse import urlparse |
|
|
|
import dash |
|
from dash import dcc, html, Input, Output, State, callback_context |
|
import dash_bootstrap_components as dbc |
|
from dash.exceptions import PreventUpdate |
|
|
|
import requests |
|
from pytube import YouTube |
|
from pydub import AudioSegment |
|
import openai |
|
|
|
|
|
# Resolve VideoFileClip across moviepy layouts: first from the
# ``moviepy.editor`` module, then from the top-level ``moviepy`` package.
# ``from ... import name`` raises ImportError both when the module is missing
# and when the name is absent, so a single exception type covers every failure
# mode and no AttributeError can escape.
try:
    from moviepy.editor import VideoFileClip
except ImportError:
    try:
        from moviepy import VideoFileClip
    except ImportError:
        # Use a named logger rather than the module-level logging.error():
        # the latter implicitly configures the root logger, which would turn
        # a later logging.basicConfig(...) call into a no-op.
        logging.getLogger(__name__).error(
            "Failed to import VideoFileClip from moviepy. Please check the installation."
        )
        # Sentinel for "video support unavailable".
        VideoFileClip = None
|
|
|
|
|
# Root logging: timestamped records at INFO level and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The Dash application, styled with the Bootstrap theme.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# The OpenAI key is mandatory: log the problem and abort start-up without it.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    logger.error("OPENAI_API_KEY not found in environment variables")
    raise ValueError("OPENAI_API_KEY not set")

openai.api_key = OPENAI_API_KEY
|
|
|
|
|
|
|
def process_media(contents, filename, url):
    """Transcribe an uploaded file or a remote media URL.

    Args:
        contents: Upload payload of the form ``"<header>,<base64 data>"``, or
            a falsy value when nothing was uploaded.
        filename: Name of the uploaded file; only its extension is used, as
            the suffix of the temporary file. May be None.
        url: Media URL handed to ``download_media`` when there is no upload.

    Returns:
        The value returned by ``transcribe_audio`` for the media.

    Raises:
        ValueError: If neither ``contents`` nor ``url`` is provided.
        Exception: Any failure while decoding, downloading, extracting audio
            or transcribing is logged and re-raised unchanged.
    """
    logger.info("Starting media processing")
    video_suffixes = ('.mp4', '.avi', '.mov', '.flv', '.wmv')
    temp_file_path = None
    audio_file_path = None
    try:
        if contents:
            # Split on the first comma only: everything after it is payload.
            _, content_string = contents.split(',', 1)
            decoded = base64.b64decode(content_string)
            # Tolerate a missing filename instead of crashing in splitext().
            suffix = os.path.splitext(filename or "")[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file.write(decoded)
                temp_file_path = temp_file.name
            logger.info(f"File uploaded: {temp_file_path}")
        elif url:
            # NOTE(review): download_media presumably returns the path of a
            # local file this function owns and must delete — confirm.
            temp_file_path = download_media(url)
        else:
            logger.error("No input provided")
            raise ValueError("No input provided")

        if temp_file_path.lower().endswith(video_suffixes):
            logger.info("Video file detected, extracting audio")
            audio_file_path = extract_audio(temp_file_path)
            return transcribe_audio(audio_file_path)

        logger.info("Audio file detected, transcribing directly")
        return transcribe_audio(temp_file_path)
    except Exception as e:
        logger.error(f"Error in process_media: {str(e)}")
        raise
    finally:
        # Remove intermediate files on success AND on failure, so a failed
        # extraction or transcription no longer leaves them on disk.
        for path in (audio_file_path, temp_file_path):
            if not path:
                continue
            try:
                os.unlink(path)
            except OSError as cleanup_error:
                # Cleanup is best-effort; it must not mask the real outcome.
                logger.warning(f"Could not remove temporary file {path}: {cleanup_error}")
|
|
|
def _build_layout():
    """Assemble the page: a title row followed by the transcription card."""
    title_row = dbc.Row([
        dbc.Col([
            html.H1("Audio/Video Transcription App", className="text-center my-4"),
        ])
    ])

    # Drag-and-drop target for a single audio/video file.
    upload_box = dcc.Upload(
        id='upload-media',
        children=html.Div([
            'Drag and Drop or ',
            html.A('Select Audio/Video File'),
        ]),
        style={
            'width': '100%',
            'height': '60px',
            'lineHeight': '60px',
            'borderWidth': '1px',
            'borderStyle': 'dashed',
            'borderRadius': '5px',
            'textAlign': 'center',
            'margin': '10px',
        },
        multiple=False,
    )

    # Everything inside the card, top to bottom.
    card_children = [
        upload_box,
        html.Div(id='file-info', className="mt-2"),
        dbc.Input(id="media-url", type="text", placeholder="Enter audio/video URL or YouTube link", className="my-3"),
        dbc.Button("Transcribe", id="transcribe-button", color="primary", className="w-100 mb-3"),
        dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
        html.Div(id="progress-indicator", className="text-center mt-3"),
        # Hidden until a callback changes its style.
        dbc.Button("Download Transcript", id="download-button", color="secondary", className="w-100 mt-3", style={'display': 'none'}),
        dcc.Download(id="download-transcript"),
        dcc.Store(id="transcription-store"),
        # Polling timer; starts disabled.
        dcc.Interval(id='progress-interval', interval=500, n_intervals=0, disabled=True),
    ]

    card_row = dbc.Row([
        dbc.Col([
            dbc.Card([
                dbc.CardBody(card_children),
            ]),
        ], width=12),
    ])

    return dbc.Container([title_row, card_row], fluid=True)


app.layout = _build_layout()
|
|
|
@app.callback(
    Output("file-info", "children"),
    Input("upload-media", "filename"),
    Input("upload-media", "last_modified"),
)
def update_file_info(filename, last_modified):
    """Return the text shown under the upload box.

    ``last_modified`` is only a trigger for the callback; its value is not
    used. Returns an empty string while no file has been uploaded.
    """
    if filename is None:
        return ""
    return f"File uploaded: {filename}"
|
|
|
# Hand-off slot between the background worker thread and the polling callback.
# "job_id" identifies the most recent request, so a slow worker that has been
# superseded cannot overwrite a newer result. "outcome" is None while the job
# is running and a (succeeded, text) tuple once it has finished.
# NOTE(review): this is process-wide state, so it tracks one job at a time and
# is not shared between server processes — confirm that fits the deployment.
_transcription_job = {"job_id": 0, "outcome": None}
_transcription_job_lock = threading.Lock()


@app.callback(
    Output("transcription-output", "children"),
    Output("download-button", "style"),
    Output("progress-indicator", "children"),
    Output("progress-interval", "disabled"),
    Output("transcription-store", "data"),
    Input("transcribe-button", "n_clicks"),
    Input("progress-interval", "n_intervals"),
    State("upload-media", "contents"),
    State("upload-media", "filename"),
    State("media-url", "value"),
    State("transcription-store", "data"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, n_intervals, contents, filename, url, stored_transcript):
    """Start a transcription on button click and report it on interval ticks.

    A click launches ``process_media`` in a background thread and enables the
    polling interval. Each tick then either animates a "Processing" message
    or, once the worker has published its outcome, shows the transcript (or
    the error text) and disables the interval again.

    Returns:
        A 5-tuple for the callback outputs, in order: output children,
        download-button style, progress-indicator children, whether the
        interval is disabled, and the data kept in the transcription store.

    Raises:
        PreventUpdate: If the button was clicked with neither an uploaded
            file nor a URL.
    """
    trigger = callback_context.triggered_id

    if trigger == "transcribe-button":
        if not contents and not url:
            raise PreventUpdate

        # Claim a fresh job id and clear any outcome left by an earlier job.
        with _transcription_job_lock:
            _transcription_job["job_id"] += 1
            job_id = _transcription_job["job_id"]
            _transcription_job["outcome"] = None

        def transcribe():
            try:
                outcome = (True, process_media(contents, filename, url))
            except Exception as e:
                logger.error(f"Transcription failed: {str(e)}")
                outcome = (False, f"An error occurred: {str(e)}")
            # Publish the outcome; a thread's return value would be discarded.
            with _transcription_job_lock:
                if _transcription_job["job_id"] == job_id:
                    _transcription_job["outcome"] = outcome

        thread = threading.Thread(target=transcribe)
        thread.start()
        return html.Div("Processing..."), {'display': 'none'}, "", False, None

    if trigger == "progress-interval":
        if stored_transcript:
            return display_transcript(stored_transcript), {'display': 'block'}, "", True, stored_transcript

        with _transcription_job_lock:
            outcome = _transcription_job["outcome"]

        if outcome is None:
            # Still running: cycle through 0-3 trailing dots.
            dots = "." * (n_intervals % 4)
            return html.Div("Processing" + dots), {'display': 'none'}, "", False, None

        succeeded, text = outcome
        if succeeded:
            logger.info("Transcription successful")
            return display_transcript(text), {'display': 'block'}, "", True, text
        # The failure was already logged by the worker; show it and stop polling.
        return html.Div(text), {'display': 'none'}, "", True, None

    return dash.no_update, dash.no_update, dash.no_update, dash.no_update, dash.no_update
|
|
|
def display_transcript(transcript):
    """Wrap ``transcript`` in a card with a heading and a preformatted body.

    Args:
        transcript: Text to display.

    Returns:
        A ``dbc.Card`` holding the heading and the transcript.
    """
    # Dash style dictionaries take camelCased property names (as the layout's
    # upload box already does), not hyphenated CSS names.
    wrap_style = {"whiteSpace": "pre-wrap", "wordWrap": "break-word"}
    return dbc.Card([
        dbc.CardBody([
            html.H5("Transcription Result"),
            html.Pre(transcript, style=wrap_style),
        ])
    ])
|
|
|
@app.callback(
    Output("download-transcript", "data"),
    Input("download-button", "n_clicks"),
    State("transcription-store", "data"),
    prevent_initial_call=True,
)
def download_transcript(n_clicks, transcript):
    """Offer the stored transcript as ``transcript.txt``.

    Raises:
        PreventUpdate: If the store holds no transcript.
    """
    if transcript:
        return {"content": transcript, "filename": "transcript.txt"}
    raise PreventUpdate
|
|
|
if __name__ == '__main__':
    # Script entry point: log, serve until the server stops, log again.
    logger.info("Starting the Dash application...")
    # NOTE(review): debug=True together with host 0.0.0.0 serves the debug
    # tooling on every network interface — confirm this is intended outside
    # local development.
    run_options = {"debug": True, "host": "0.0.0.0", "port": 7860}
    app.run(**run_options)
    logger.info("Dash application has finished running.")