bluenevus commited on
Commit
249a3c0
·
verified ·
1 Parent(s): 6cffe75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -108
app.py CHANGED
@@ -1,132 +1,53 @@
1
- import gradio as gr
2
- import torchaudio
3
- import torchaudio.transforms as T
4
- from transformers import pipeline
5
- import requests
6
- from pydub import AudioSegment
7
- from pydub.silence import split_on_silence
8
- import io
9
- import os
10
- from bs4 import BeautifulSoup
11
- import re
12
- import numpy as np
13
- from moviepy.video.io.VideoFileClip import VideoFileClip
14
- import soundfile as sf
15
- from spellchecker import SpellChecker
16
- import tempfile
17
 
18
- # Load the transcription model
19
- transcription_pipeline = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
20
- spell = SpellChecker()
21
 
22
- def download_audio_from_url(url):
23
- try:
24
- if "share" in url:
25
- print("Processing shareable link...")
26
- response = requests.get(url)
27
- soup = BeautifulSoup(response.content, 'html.parser')
28
- video_tag = soup.find('video')
29
- if video_tag and 'src' in video_tag.attrs:
30
- video_url = video_tag['src']
31
- print(f"Extracted video URL: {video_url}")
32
- else:
33
- raise ValueError("Direct video URL not found in the shareable link.")
34
- else:
35
- video_url = url
36
-
37
- print(f"Downloading video from URL: {video_url}")
38
- response = requests.get(video_url)
39
- audio_bytes = response.content
40
- print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
41
- return audio_bytes
42
- except Exception as e:
43
- print(f"Error in download_audio_from_url: {str(e)}")
44
- raise
45
-
46
- def correct_spelling(text):
47
- words = text.split()
48
- corrected_words = [spell.correction(word) or word for word in words]
49
- return ' '.join(corrected_words)
50
 
51
- def format_transcript(transcript):
52
- sentences = transcript.split('.')
53
- formatted_transcript = []
54
- current_speaker = None
55
- for sentence in sentences:
56
- if ':' in sentence:
57
- speaker, content = sentence.split(':', 1)
58
- if speaker != current_speaker:
59
- formatted_transcript.append(f"\n\n{speaker.strip()}:{content.strip()}.")
60
- current_speaker = speaker
61
- else:
62
- formatted_transcript.append(f"{content.strip()}.")
63
- else:
64
- formatted_transcript.append(sentence.strip() + '.')
65
- return ' '.join(formatted_transcript)
66
-
67
- def transcribe_audio(video_bytes):
68
  try:
69
- with open("temp_video.mp4", "wb") as f:
70
- f.write(video_bytes)
71
-
72
- video = VideoFileClip("temp_video.mp4")
73
- audio = video.audio
74
-
75
- audio.write_audiofile("temp_audio.wav", fps=16000, nbytes=2, codec='pcm_s16le')
76
-
77
- audio_data, sample_rate = sf.read("temp_audio.wav")
78
 
79
- if len(audio_data.shape) > 1:
80
- audio_data = audio_data.mean(axis=1)
81
 
82
- audio_data = audio_data.astype(np.float32) / np.max(np.abs(audio_data))
 
83
 
84
- result = transcription_pipeline(audio_data)
85
- transcript = result['text']
86
-
87
- transcript = correct_spelling(transcript)
88
- transcript = format_transcript(transcript)
89
-
90
- os.remove("temp_video.mp4")
91
- os.remove("temp_audio.wav")
92
-
93
- return transcript
94
  except Exception as e:
95
  print(f"Error in transcribe_audio: {str(e)}")
96
  raise
97
 
 
98
  def transcribe_video(url):
99
  try:
100
  print(f"Attempting to download audio from URL: {url}")
101
  audio_bytes = download_audio_from_url(url)
102
  print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")
103
 
 
 
 
 
 
104
  print("Starting audio transcription...")
105
- transcript = transcribe_audio(audio_bytes)
106
  print("Transcription completed successfully")
107
 
 
 
 
108
  return transcript
109
  except Exception as e:
110
  error_message = f"An error occurred: {str(e)}"
111
  print(error_message)
112
- return error_message
113
-
114
- def download_transcript(transcript):
115
- with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as temp_file:
116
- temp_file.write(transcript)
117
- temp_file_path = temp_file.name
118
- return temp_file_path
119
-
120
- # Create the Gradio interface
121
- with gr.Blocks(title="Video Transcription") as demo:
122
- gr.Markdown("# Video Transcription")
123
- video_url = gr.Textbox(label="Video URL")
124
- transcribe_button = gr.Button("Transcribe")
125
- transcript_output = gr.Textbox(label="Transcript", lines=20)
126
- download_button = gr.Button("Download Transcript")
127
- download_link = gr.File(label="Download Transcript")
128
-
129
- transcribe_button.click(fn=transcribe_video, inputs=video_url, outputs=transcript_output)
130
- download_button.click(fn=download_transcript, inputs=transcript_output, outputs=download_link)
131
-
132
- demo.launch()
 
1
+ import torch
2
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ # Check if CUDA is available and set the device
5
+ device = "cuda" if torch.cuda.is_available() else "cpu"
6
+ print(f"Using device: {device}")
7
 
8
+ # Load the Whisper model and processor
9
+ model_name = "openai/whisper-base"
10
+ processor = WhisperProcessor.from_pretrained(model_name)
11
+ model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ def transcribe_audio(audio_file):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  try:
15
+ # Load and preprocess the audio
16
+ audio_input, sample_rate = sf.read(audio_file)
17
+ input_features = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_features.to(device)
 
 
 
 
 
 
18
 
19
+ # Generate token ids
20
+ predicted_ids = model.generate(input_features)
21
 
22
+ # Decode token ids to text
23
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
24
 
25
+ return transcription[0]
 
 
 
 
 
 
 
 
 
26
  except Exception as e:
27
  print(f"Error in transcribe_audio: {str(e)}")
28
  raise
29
 
30
+ # Update the transcribe_video function to use the new transcribe_audio function
31
  def transcribe_video(url):
32
  try:
33
  print(f"Attempting to download audio from URL: {url}")
34
  audio_bytes = download_audio_from_url(url)
35
  print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")
36
 
37
+ # Save audio bytes to a temporary file
38
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
39
+ temp_audio.write(audio_bytes)
40
+ temp_audio_path = temp_audio.name
41
+
42
  print("Starting audio transcription...")
43
+ transcript = transcribe_audio(temp_audio_path)
44
  print("Transcription completed successfully")
45
 
46
+ # Clean up the temporary file
47
+ os.unlink(temp_audio_path)
48
+
49
  return transcript
50
  except Exception as e:
51
  error_message = f"An error occurred: {str(e)}"
52
  print(error_message)
53
+ return error_message