Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -22,7 +22,7 @@ import warnings
|
|
22 |
from pydub import AudioSegment
|
23 |
import torch
|
24 |
import torchaudio
|
25 |
-
from transformers import
|
26 |
from huggingface_hub import model_info
|
27 |
import spacy
|
28 |
import networkx as nx
|
@@ -68,31 +68,53 @@ generate_kwargs = {
|
|
68 |
"forced_decoder_ids": None
|
69 |
}
|
70 |
|
71 |
-
#
|
72 |
-
|
73 |
-
|
74 |
-
# Transcribe audio
|
75 |
-
def transcribe_audio(audio_file):
|
76 |
if audio_file.endswith(".m4a"):
|
77 |
audio_file = convert_to_wav(audio_file)
|
78 |
|
79 |
start_time = time.time()
|
80 |
|
81 |
-
#
|
82 |
-
|
83 |
-
text = asr(audio_file, chunk_length_s=30, generate_kwargs=generate_kwargs)["text"]
|
84 |
|
85 |
-
|
|
|
|
|
86 |
|
87 |
-
#
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
# Audio duration (in seconds)
|
91 |
audio_duration = waveform.shape[1] / sample_rate
|
92 |
|
93 |
-
# Find audio duration@pipeline's internal method
|
94 |
-
#audio_duration = pipe.feature_extractor.sampling_rate * len(pipe.feature_extractor(audio_file)["input_features"][0]) / pipe.feature_extractor.sampling_rate
|
95 |
-
|
96 |
# Real-time Factor (RTF)
|
97 |
rtf = output_time / audio_duration
|
98 |
|
@@ -109,6 +131,7 @@ def transcribe_audio(audio_file):
|
|
109 |
|
110 |
return text, result
|
111 |
|
|
|
112 |
# Clean and preprocess text for summarization
|
113 |
def clean_text(text):
|
114 |
text = re.sub(r'https?:\/\/.*[\r\n]*', '', text)
|
|
|
22 |
from pydub import AudioSegment
|
23 |
import torch
|
24 |
import torchaudio
|
25 |
+
from transformers import WhisperTokenizer, WhisperForConditionalGeneration, WhisperProcessor
|
26 |
from huggingface_hub import model_info
|
27 |
import spacy
|
28 |
import networkx as nx
|
|
|
68 |
"forced_decoder_ids": None
|
69 |
}
|
70 |
|
71 |
+
# Transcribe
|
72 |
+
def transcribe_audio(audio_file, chunk_length_s=30):
|
|
|
|
|
|
|
73 |
if audio_file.endswith(".m4a"):
|
74 |
audio_file = convert_to_wav(audio_file)
|
75 |
|
76 |
start_time = time.time()
|
77 |
|
78 |
+
# Load the audio waveform using torchaudio
|
79 |
+
waveform, sample_rate = torchaudio.load(audio_file)
|
|
|
80 |
|
81 |
+
# Calculate the number of chunks
|
82 |
+
chunk_size = chunk_length_s * sample_rate
|
83 |
+
num_chunks = waveform.shape[1] // chunk_size + int(waveform.shape[1] % chunk_size != 0)
|
84 |
|
85 |
+
# Initialize an empty list to store the transcribed text from each chunk
|
86 |
+
full_text = []
|
87 |
+
|
88 |
+
for i in range(num_chunks):
|
89 |
+
start = i * chunk_size
|
90 |
+
end = min((i + 1) * chunk_size, waveform.shape[1])
|
91 |
+
chunk_waveform = waveform[:, start:end]
|
92 |
+
|
93 |
+
# Process the chunk
|
94 |
+
audio_input = processor(chunk_waveform, sampling_rate=sample_rate, return_tensors="pt")
|
95 |
+
|
96 |
+
# Generate attention mask
|
97 |
+
input_features = audio_input.input_features
|
98 |
+
attention_mask = torch.ones(input_features.shape, dtype=torch.long)
|
99 |
+
|
100 |
+
# ASR model inference on the chunk
|
101 |
+
with torch.no_grad():
|
102 |
+
generated_ids = model.generate(
|
103 |
+
input_features=input_features.to(device),
|
104 |
+
attention_mask=attention_mask.to(device),
|
105 |
+
**generate_kwargs
|
106 |
+
)
|
107 |
+
chunk_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
108 |
+
full_text.append(chunk_text)
|
109 |
+
|
110 |
+
# Combine the transcribed text from all chunks
|
111 |
+
text = " ".join(full_text)
|
112 |
+
|
113 |
+
output_time = time.time() - start_time
|
114 |
|
115 |
# Audio duration (in seconds)
|
116 |
audio_duration = waveform.shape[1] / sample_rate
|
117 |
|
|
|
|
|
|
|
118 |
# Real-time Factor (RTF)
|
119 |
rtf = output_time / audio_duration
|
120 |
|
|
|
131 |
|
132 |
return text, result
|
133 |
|
134 |
+
|
135 |
# Clean and preprocess text for summarization
|
136 |
def clean_text(text):
|
137 |
text = re.sub(r'https?:\/\/.*[\r\n]*', '', text)
|