camparchimedes commited on
Commit
4012d3e
·
verified ·
1 Parent(s): be5abfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -15
app.py CHANGED
@@ -22,7 +22,7 @@ import warnings
22
  from pydub import AudioSegment
23
  import torch
24
  import torchaudio
25
- from transformers import pipeline, WhisperTokenizer, WhisperForConditionalGeneration, WhisperProcessor
26
  from huggingface_hub import model_info
27
  import spacy
28
  import networkx as nx
@@ -68,31 +68,53 @@ generate_kwargs = {
68
  "forced_decoder_ids": None
69
  }
70
 
71
- # Initialize pipeline
72
- asr = pipeline("automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, device=device, torch_dtype=torch_dtype)
73
-
74
- # Transcribe audio
75
- def transcribe_audio(audio_file):
76
  if audio_file.endswith(".m4a"):
77
  audio_file = convert_to_wav(audio_file)
78
 
79
  start_time = time.time()
80
 
81
- # ASR pipeline on audio
82
- with torch.no_grad():
83
- text = asr(audio_file, chunk_length_s=30, generate_kwargs=generate_kwargs)["text"]
84
 
85
- output_time = time.time() - start_time
 
 
86
 
87
- # Load with torchaudio() for TRF
88
- waveform, sample_rate = torchaudio.load(audio_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  # Audio duration (in seconds)
91
  audio_duration = waveform.shape[1] / sample_rate
92
 
93
- # Find audio duration@pipeline's internal method
94
- #audio_duration = pipe.feature_extractor.sampling_rate * len(pipe.feature_extractor(audio_file)["input_features"][0]) / pipe.feature_extractor.sampling_rate
95
-
96
  # Real-time Factor (RTF)
97
  rtf = output_time / audio_duration
98
 
@@ -109,6 +131,7 @@ def transcribe_audio(audio_file):
109
 
110
  return text, result
111
 
 
112
  # Clean and preprocess text for summarization
113
  def clean_text(text):
114
  text = re.sub(r'https?:\/\/.*[\r\n]*', '', text)
 
22
  from pydub import AudioSegment
23
  import torch
24
  import torchaudio
25
+ from transformers import WhisperTokenizer, WhisperForConditionalGeneration, WhisperProcessor
26
  from huggingface_hub import model_info
27
  import spacy
28
  import networkx as nx
 
68
  "forced_decoder_ids": None
69
  }
70
 
71
+ # Transcribe
72
+ def transcribe_audio(audio_file, chunk_length_s=30):
 
 
 
73
  if audio_file.endswith(".m4a"):
74
  audio_file = convert_to_wav(audio_file)
75
 
76
  start_time = time.time()
77
 
78
+ # Load the audio waveform using torchaudio
79
+ waveform, sample_rate = torchaudio.load(audio_file)
 
80
 
81
+ # Calculate the number of chunks
82
+ chunk_size = chunk_length_s * sample_rate
83
+ num_chunks = waveform.shape[1] // chunk_size + int(waveform.shape[1] % chunk_size != 0)
84
 
85
+ # Initialize an empty list to store the transcribed text from each chunk
86
+ full_text = []
87
+
88
+ for i in range(num_chunks):
89
+ start = i * chunk_size
90
+ end = min((i + 1) * chunk_size, waveform.shape[1])
91
+ chunk_waveform = waveform[:, start:end]
92
+
93
+ # Process the chunk
94
+ audio_input = processor(chunk_waveform, sampling_rate=sample_rate, return_tensors="pt")
95
+
96
+ # Generate attention mask
97
+ input_features = audio_input.input_features
98
+ attention_mask = torch.ones(input_features.shape, dtype=torch.long)
99
+
100
+ # ASR model inference on the chunk
101
+ with torch.no_grad():
102
+ generated_ids = model.generate(
103
+ input_features=input_features.to(device),
104
+ attention_mask=attention_mask.to(device),
105
+ **generate_kwargs
106
+ )
107
+ chunk_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
108
+ full_text.append(chunk_text)
109
+
110
+ # Combine the transcribed text from all chunks
111
+ text = " ".join(full_text)
112
+
113
+ output_time = time.time() - start_time
114
 
115
  # Audio duration (in seconds)
116
  audio_duration = waveform.shape[1] / sample_rate
117
 
 
 
 
118
  # Real-time Factor (RTF)
119
  rtf = output_time / audio_duration
120
 
 
131
 
132
  return text, result
133
 
134
+
135
  # Clean and preprocess text for summarization
136
  def clean_text(text):
137
  text = re.sub(r'https?:\/\/.*[\r\n]*', '', text)