ricardo-lsantos's picture
Fixed temp mp3 filename
298f1dc
import streamlit as st
import models.file as mf
import models.video as mv
import models.whisper as mw
import models.subtitles as ms
import models.transcript as mt
def sidebar():
    """Draw the device selector in the sidebar and return the selection.

    Returns:
        The option picked in the select box: either 'CPU' or 'GPU'.
    """
    device_options = ('CPU', 'GPU')
    chosen = st.sidebar.selectbox('Select Device', device_options)
    # Echo the choice back so the user can see what is currently selected.
    st.sidebar.write('You selected:', chosen)
    return chosen
def load_model():
    """Return the pipeline kept in Streamlit session state, creating it once.

    On the first call of a session the pipeline is built with
    ``mw.get_pipe`` for the device reported by ``mw.get_device`` and stored
    under the ``'pipeline'`` key; later calls return the stored object.

    Returns:
        The object stored at ``st.session_state['pipeline']``.
    """
    state = st.session_state
    # Fast path: a pipeline was already created earlier in this session.
    if 'pipeline' in state:
        return state['pipeline']

    state['pipeline'] = mw.get_pipe(mw.get_device())
    return state['pipeline']
def app():
    """Render the transcription page and process an uploaded media file.

    Draws the sidebar and page header, obtains the pipeline from
    ``load_model()``, and shows an uploader limited to mp4, wav and mp3.
    Once a file is uploaded:

    1. The audio is obtained: via ``mv.get_audio_from_video`` when
       ``mf.get_file_type`` reports ``'video'``, or by reading the upload
       directly when it reports ``'audio'``.
    2. ``mw.get_prediction`` produces the transcript and
       ``mw.get_prediction_with_timelines`` produces the subtitles.
    3. Both results are stored in ``st.session_state`` (keys
       ``'transcript'`` and ``'subtitles'``), saved through
       ``mt.save_transcript`` / ``ms.save_subtitles`` as ``<name>.txt`` and
       ``<name>.srt``, and offered through download buttons.

    If the file type is neither ``'video'`` nor ``'audio'`` an error is
    displayed and nothing is transcribed.

    Returns:
        None.
    """
    # Called for its widgets only: the selected device is not consumed here,
    # because load_model() obtains its device from mw.get_device().
    sidebar()
    pipeline = load_model()

    st.title('Transcript Small Whisper')
    st.write('Welcome to the Home page!')

    file = st.file_uploader("Upload Files", type=['mp4', 'wav', 'mp3'])
    if file is None:
        return

    # NOTE(review): Streamlit reruns the script on widget interaction, so
    # pressing a download button presumably repeats the whole transcription.
    # Consider caching the results per uploaded file -- TODO confirm.
    progress_bar = st.progress(0)
    status_text = st.empty()

    status_text.text('Uploading file...')
    progress_bar.progress(10)
    st.write(file.name)
    status_text.text('File uploaded!')
    file_details = {
        "FileName": file.name,
        "FileType": file.type,
        "FileSize": file.size,
    }
    st.write(file_details)

    # Output names are derived once so the saved files and the download
    # buttons always agree.
    subtitles_name = file.name + '.srt'
    transcript_name = file.name + '.txt'

    # Classify the upload once; the branches are mutually exclusive.
    file_kind = mf.get_file_type(file)
    if file_kind == 'video':
        status_text.text('Extracting audio from video...')
        # Swap only the final extension, whatever its letter case, so that
        # e.g. "clip.MP4" still maps to a distinct "clip.mp3".
        mp3_name = file.name.rsplit('.', 1)[0] + '.mp3'
        audio = mv.get_audio_from_video(file, mp3_name)
    elif file_kind == 'audio':
        status_text.text('Extracting audio from audio...')
        audio = file.read()
    else:
        # Without this branch `audio` would be unbound and the transcription
        # call below would fail with UnboundLocalError.
        status_text.text('Unsupported file type.')
        st.error(f'Unsupported file type: {file.type}')
        return
    progress_bar.progress(30)

    status_text.text('Transcribing audio...')
    transcript = mw.get_prediction(pipeline, audio)
    progress_bar.progress(60)

    status_text.text('Subtitling audio...')
    subtitles = mw.get_prediction_with_timelines(transcript, subtitles_name)
    progress_bar.progress(90)

    status_text.text('Saving subtitles...')
    st.session_state['subtitles'] = subtitles
    ms.save_subtitles(subtitles, subtitles_name)

    status_text.text('Saving transcript...')
    st.session_state['transcript'] = transcript
    mt.save_transcript(transcript, transcript_name)

    status_text.text('Done!')
    progress_bar.progress(100)

    st.download_button(
        label="Download Transcript",
        data=st.session_state['transcript'],
        file_name=transcript_name,
        mime='text/plain',
    )
    st.download_button(
        label="Download Subtitles",
        data=st.session_state['subtitles'],
        file_name=subtitles_name,
        mime='text/plain',
    )
# Render the page only when this file is run as a script, not when imported.
if __name__ == "__main__":
    app()