camparchimedes committed on
Commit
cb06cac
·
verified ·
1 Parent(s): a086817

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -17
app.py CHANGED
@@ -108,24 +108,13 @@ def transcribe_audio(audio_file, chunk_length_s=30):
108
  inputs = processor(chunk_waveform.squeeze(0).numpy(), sampling_rate=sample_rate, return_tensors="pt")
109
  input_features = inputs.input_features
110
 
111
- # Create attention mask
112
- attention_mask = torch.ones(inputs.input_features.shape[:2], dtype=torch.long, device=device)
113
-
114
- # ASR model inference on the chunk
115
- with torch.no_grad():
116
- generated_ids = model.generate(
117
- input_features=input_features.to(device),
118
- attention_mask=attention_mask.to(device),
119
- **generate_kwargs
120
- )
121
- # Process the chunk with the tokenizer
122
- inputs = processor(chunk_waveform.squeeze(0).numpy(), sampling_rate=sample_rate, return_tensors="pt")
123
-
124
- input_features = inputs.input_features
125
-
126
  # Create attention mask
127
  attention_mask = torch.ones(inputs.input_features.shape[:2], dtype=torch.long, device=device)
128
-
 
 
 
 
129
  # ASR model inference on the chunk
130
  with torch.no_grad():
131
  generated_ids = model.generate(
@@ -161,9 +150,9 @@ with torch.no_grad():
161
  "An RTF of less than 1 means the transcription process is faster than real-time (expected)."
162
  )
163
 
164
-
165
  return text, result
166
 
 
167
  # Clean and preprocess/@summarization
168
  def clean_text(text):
169
  text = re.sub(r'https?:\/\/.*[\r\n]*', '', text)
 
108
  inputs = processor(chunk_waveform.squeeze(0).numpy(), sampling_rate=sample_rate, return_tensors="pt")
109
  input_features = inputs.input_features
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # Create attention mask
112
  attention_mask = torch.ones(inputs.input_features.shape[:2], dtype=torch.long, device=device)
113
+
114
+ # Check the dimensions and values of the attention mask
115
+ assert attention_mask.shape == (1, input_features.shape[1]), "Attention mask dimensions do not match the input features."
116
+ assert (attention_mask.sum().item() == input_features.shape[1]), "Attention mask has incorrect values."
117
+
118
  # ASR model inference on the chunk
119
  with torch.no_grad():
120
  generated_ids = model.generate(
 
150
  "An RTF of less than 1 means the transcription process is faster than real-time (expected)."
151
  )
152
 
 
153
  return text, result
154
 
155
+
156
  # Clean and preprocess/@summarization
157
  def clean_text(text):
158
  text = re.sub(r'https?:\/\/.*[\r\n]*', '', text)