# meeting-minutes-ai/utils/speech_processor.py
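# --- Legacy version (commented out): Whisper ASR plus pyannote speaker diarization, kept for reference ---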
# import torch
# import torchaudio
# from transformers import (
# WhisperProcessor,
# WhisperForConditionalGeneration,
# pipeline
# )
# from pyannote.audio import Pipeline
# import librosa
# import numpy as np
# from pydub import AudioSegment
# import tempfile
# import os # ADD THIS LINE - FIX FOR THE ERROR
# class SpeechProcessor:
#     def __init__(self):
#         # Load Whisper for ASR
#         self.whisper_processor = WhisperProcessor.from_pretrained(
#             "openai/whisper-medium"
#         )
#         self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
#             "openai/whisper-medium"
#         )
#
#         # Load speaker diarization
#         try:
#             self.diarization_pipeline = Pipeline.from_pretrained(
#                 "pyannote/speaker-diarization-3.1",
#                 use_auth_token=os.environ.get("HF_TOKEN")  # Now os is imported
#             )
#         except Exception as e:
#             print(f"Warning: Could not load diarization model: {e}")
#             self.diarization_pipeline = None
#
#     def process_audio(self, audio_path, language="id"):
#         """
#         Process an audio file for ASR and speaker diarization
#         """
#         # Convert to WAV if needed
#         audio_path = self._ensure_wav_format(audio_path)
#
#         # Load audio
#         waveform, sample_rate = torchaudio.load(audio_path)
#
#         # Speaker diarization
#         if self.diarization_pipeline:
#             try:
#                 diarization = self.diarization_pipeline(audio_path)
#
#                 # Process each speaker segment
#                 transcript_segments = []
#                 for turn, _, speaker in diarization.itertracks(yield_label=True):
#                     # Extract segment audio
#                     start_sample = int(turn.start * sample_rate)
#                     end_sample = int(turn.end * sample_rate)
#                     segment_waveform = waveform[:, start_sample:end_sample]
#
#                     # ASR on segment
#                     text = self._transcribe_segment(
#                         segment_waveform,
#                         sample_rate,
#                         language
#                     )
#
#                     transcript_segments.append({
#                         "start": round(turn.start, 2),
#                         "end": round(turn.end, 2),
#                         "speaker": speaker,
#                         "text": text
#                     })
#
#                 return self._merge_consecutive_segments(transcript_segments)
#             except Exception as e:
#                 print(f"Diarization failed, falling back to simple transcription: {e}")
#
#         # Fallback: simple transcription without diarization
#         return self._simple_transcription(waveform, sample_rate, language)
#
#     def _simple_transcription(self, waveform, sample_rate, language):
#         """Fallback transcription without speaker diarization"""
#         # Process in 30-second chunks
#         chunk_length = 30 * sample_rate
#         segments = []
#
#         for i in range(0, waveform.shape[1], chunk_length):
#             chunk = waveform[:, i:i + chunk_length]
#             text = self._transcribe_segment(chunk, sample_rate, language)
#
#             if text.strip():
#                 segments.append({
#                     "start": i / sample_rate,
#                     "end": min((i + chunk_length) / sample_rate, waveform.shape[1] / sample_rate),
#                     "speaker": "SPEAKER_01",
#                     "text": text
#                 })
#
#         return segments
#
#     def _transcribe_segment(self, waveform, sample_rate, language):
#         """
#         Transcribe an audio segment using Whisper
#         """
#         # Resample if needed
#         if sample_rate != 16000:
#             resampler = torchaudio.transforms.Resample(sample_rate, 16000)
#             waveform = resampler(waveform)
#
#         # Prepare input
#         input_features = self.whisper_processor(
#             waveform.squeeze().numpy(),
#             sampling_rate=16000,
#             return_tensors="pt"
#         ).input_features
#
#         # Generate transcription
#         forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(
#             language=language,
#             task="transcribe"
#         )
#         predicted_ids = self.whisper_model.generate(
#             input_features,
#             forced_decoder_ids=forced_decoder_ids,
#             max_length=448
#         )
#         transcription = self.whisper_processor.batch_decode(
#             predicted_ids,
#             skip_special_tokens=True
#         )[0]
#
#         return transcription.strip()
#
#     def _ensure_wav_format(self, audio_path):
#         """
#         Convert audio to WAV format if needed
#         """
#         if not audio_path.endswith('.wav'):
#             audio = AudioSegment.from_file(audio_path)
#             wav_path = tempfile.mktemp(suffix='.wav')
#             audio.export(wav_path, format='wav')
#             return wav_path
#         return audio_path
#
#     def _merge_consecutive_segments(self, segments):
#         """
#         Merge consecutive segments from same speaker
#         """
#         if not segments:
#             return segments
#
#         merged = [segments[0]]
#         for current in segments[1:]:
#             last = merged[-1]
#             # Merge if same speaker and close in time
#             if (last['speaker'] == current['speaker'] and
#                     current['start'] - last['end'] < 1.0):
#                 last['end'] = current['end']
#                 last['text'] += ' ' + current['text']
#             else:
#                 merged.append(current)
#
#         return merged
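
# --- Active version: chunked Whisper transcription without speaker diarization ---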
import torch
import torchaudio
from transformers import (
WhisperProcessor,
WhisperForConditionalGeneration,
pipeline
)
import librosa
import numpy as np
from pydub import AudioSegment
import tempfile
import os

class SpeechProcessor:
    def __init__(self):
        # Load Whisper for ASR
        print("Loading Whisper model...")
        self.whisper_processor = WhisperProcessor.from_pretrained(
            "openai/whisper-small"  # Use small for HF Spaces
        )
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-small"
        )

        # No diarization in this version
        self.diarization_pipeline = None
        print("Speech processor initialized (without speaker diarization)")

    def process_audio(self, audio_path, language="id"):
        """
        Process audio file for ASR (without speaker diarization)
        """
        # Convert to WAV if needed
        audio_path = self._ensure_wav_format(audio_path)

        # Load audio
        waveform, sample_rate = torchaudio.load(audio_path)

        # Process audio in chunks
        return self._process_audio_chunks(waveform, sample_rate, language)

    def _process_audio_chunks(self, waveform, sample_rate, language):
        """Process audio in manageable chunks"""
        chunk_length = 30 * sample_rate  # 30-second chunks
        segments = []
        total_chunks = (waveform.shape[1] + chunk_length - 1) // chunk_length

        for i in range(0, waveform.shape[1], chunk_length):
            chunk_num = i // chunk_length + 1
            print(f"Processing chunk {chunk_num}/{total_chunks}...")

            chunk = waveform[:, i:i + chunk_length]

            # Skip very short chunks
            if chunk.shape[1] < sample_rate * 0.5:
                continue

            text = self._transcribe_segment(chunk, sample_rate, language)
            if text.strip():
                segments.append({
                    "start": round(i / sample_rate, 2),
                    "end": round(min((i + chunk_length) / sample_rate,
                                     waveform.shape[1] / sample_rate), 2),
                    "speaker": "SPEAKER_01",
                    "text": text
                })

        return segments

    def _transcribe_segment(self, waveform, sample_rate, language):
        """
        Transcribe audio segment using Whisper
        """
        # Resample if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        # Prepare input
        input_features = self.whisper_processor(
            waveform.squeeze().numpy(),
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features

        # Generate transcription
        forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(
            language=language,
            task="transcribe"
        )
        with torch.no_grad():
            predicted_ids = self.whisper_model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids,
                max_length=448
            )
        transcription = self.whisper_processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]

        return transcription.strip()

    def _ensure_wav_format(self, audio_path):
        """
        Convert audio to WAV format if needed
        """
        if not audio_path.endswith('.wav'):
            print("Converting audio to WAV format...")
            audio = AudioSegment.from_file(audio_path)
            wav_path = tempfile.mktemp(suffix='.wav')
            audio.export(wav_path, format='wav')
            return wav_path
        return audio_path
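
# --- Usage sketch (illustration only, not part of the original module) ---
# A minimal, hedged example of how SpeechProcessor might be driven from a
# script. The file name "meeting.mp3" and the printed output format are
# assumptions for illustration; the repository's actual entry point may differ.
if __name__ == "__main__":
    processor = SpeechProcessor()
    # "id" = Indonesian (the default above); pass e.g. language="en" for English audio.
    segments = processor.process_audio("meeting.mp3", language="id")
    for seg in segments:
        print(f"[{seg['start']:7.2f}s - {seg['end']:7.2f}s] {seg['speaker']}: {seg['text']}")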