Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -108,24 +108,13 @@ def transcribe_audio(audio_file, chunk_length_s=30):
|
|
108 |
inputs = processor(chunk_waveform.squeeze(0).numpy(), sampling_rate=sample_rate, return_tensors="pt")
|
109 |
input_features = inputs.input_features
|
110 |
|
111 |
-
# Create attention mask
|
112 |
-
attention_mask = torch.ones(inputs.input_features.shape[:2], dtype=torch.long, device=device)
|
113 |
-
|
114 |
-
# ASR model inference on the chunk
|
115 |
-
with torch.no_grad():
|
116 |
-
generated_ids = model.generate(
|
117 |
-
input_features=input_features.to(device),
|
118 |
-
attention_mask=attention_mask.to(device),
|
119 |
-
**generate_kwargs
|
120 |
-
)
|
121 |
-
# Process the chunk with the tokenizer
|
122 |
-
inputs = processor(chunk_waveform.squeeze(0).numpy(), sampling_rate=sample_rate, return_tensors="pt")
|
123 |
-
|
124 |
-
input_features = inputs.input_features
|
125 |
-
|
126 |
# Create attention mask
|
127 |
attention_mask = torch.ones(inputs.input_features.shape[:2], dtype=torch.long, device=device)
|
128 |
-
|
|
|
|
|
|
|
|
|
129 |
# ASR model inference on the chunk
|
130 |
with torch.no_grad():
|
131 |
generated_ids = model.generate(
|
@@ -161,9 +150,9 @@ with torch.no_grad():
|
|
161 |
"An RTF of less than 1 means the transcription process is faster than real-time (expected)."
|
162 |
)
|
163 |
|
164 |
-
|
165 |
return text, result
|
166 |
|
|
|
167 |
# Clean and preprocess/@summarization
|
168 |
def clean_text(text):
|
169 |
text = re.sub(r'https?:\/\/.*[\r\n]*', '', text)
|
|
|
108 |
inputs = processor(chunk_waveform.squeeze(0).numpy(), sampling_rate=sample_rate, return_tensors="pt")
|
109 |
input_features = inputs.input_features
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
# Create attention mask
|
112 |
attention_mask = torch.ones(inputs.input_features.shape[:2], dtype=torch.long, device=device)
|
113 |
+
|
114 |
+
# Check the dimensions and values of the attention mask
|
115 |
+
assert attention_mask.shape == (1, input_features.shape[1]), "Attention mask dimensions do not match the input features."
|
116 |
+
assert (attention_mask.sum().item() == input_features.shape[1]), "Attention mask has incorrect values."
|
117 |
+
|
118 |
# ASR model inference on the chunk
|
119 |
with torch.no_grad():
|
120 |
generated_ids = model.generate(
|
|
|
150 |
"An RTF of less than 1 means the transcription process is faster than real-time (expected)."
|
151 |
)
|
152 |
|
|
|
153 |
return text, result
|
154 |
|
155 |
+
|
156 |
# Clean and preprocess/@summarization
|
157 |
def clean_text(text):
|
158 |
text = re.sub(r'https?:\/\/.*[\r\n]*', '', text)
|