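# Speech-to-text tool for smolagents, built on the Hugging Face Transformers ASR pipeline.
# External dependencies assumed by the imports below: transformers (with a PyTorch backend for Whisper), librosa, and smolagents.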
from transformers import pipeline
import librosa  # Used to load audio and resample it to 16 kHz
import os
from smolagents.tools import Tool

# Initialize the ASR pipeline once
_asr_pipeline_instance = None


def get_asr_pipeline():
    global _asr_pipeline_instance
    if _asr_pipeline_instance is None:
        try:
            # Using a smaller Whisper model for quicker setup, but larger models offer better accuracy
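            # A GPU can be used by passing device=0 to pipeline(); CPU is assumed by default.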
            _asr_pipeline_instance = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-tiny.en",  # Consider making model configurable
            )
            print("ASR pipeline initialized.")  # For feedback
        except Exception as e:
            print(f"Error initializing ASR pipeline: {e}")
            # Handle error appropriately, e.g., raise or log
    return _asr_pipeline_instance


# Internal helper: transcribes a single audio file using the provided ASR pipeline.
def _transcribe_audio_file(audio_filepath: str, asr_pipeline_instance) -> str:
    """
    Converts speech in an audio file to text using the provided ASR pipeline.
    Args:
        audio_filepath (str): Path to the audio file.
        asr_pipeline_instance: The initialized ASR pipeline.
    Returns:
        str: Transcribed text from the audio or an error message.
    """
    if not asr_pipeline_instance:
        return "Error: ASR pipeline is not available."
    if not os.path.exists(audio_filepath):
        return f"Error: Audio file not found at {audio_filepath}"
    try:
        # Ensure the file can be loaded by librosa (or your chosen audio library)
        # This step can help catch corrupted or unsupported audio formats early.
        y, sr = librosa.load(audio_filepath, sr=None)  # Load with original sample rate
        if sr != 16000:  # Whisper models expect 16kHz
            y = librosa.resample(y, orig_sr=sr, target_sr=16000)

        # Pass the numpy array to the pipeline
        transcription_result = asr_pipeline_instance(
            {"raw": y, "sampling_rate": 16000}, return_timestamps=False
        )  # Timestamps are not needed for plain transcription
        return transcription_result["text"]
    except Exception as e:
        return f"Error during transcription of {audio_filepath}: {e}"


class SpeechToTextTool(Tool):
    """
    Transcribes audio from a given audio file path to text.
    """

    name = "speech_to_text_transcriber"
    description = "Converts speech in an audio file (e.g., .mp3, .wav) to text using speech recognition."
    inputs = {
        "audio_filepath": {"type": "string", "description": "Path to the audio file to transcribe."}
    }
    outputs = {
        "transcribed_text": {
            "type": "string",
            "description": "The transcribed text from the audio, or an error message.",
        }
    }
    output_type = "string"
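    # Note: smolagents validates `inputs` and `output_type`; the `outputs` dict above is kept as extra documentation only.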

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.asr_pipeline = get_asr_pipeline()  # Initialize or get the shared pipeline
        self.is_initialized = self.asr_pipeline is not None

    def forward(self, audio_filepath: str) -> str:
        """
        Wrapper for the _transcribe_audio_file function.
        """
        if not self.is_initialized or not self.asr_pipeline:
            return "Error: SpeechToTextTool was not initialized properly (ASR pipeline missing)."
        return _transcribe_audio_file(audio_filepath, self.asr_pipeline)


# Expose the original function name if needed by other parts of the system (optional)
# transcribe_audio = _transcribe_audio_file # This would need adjustment if it expects the pipeline passed in

# Example usage:
if __name__ == "__main__":
    tool_instance = SpeechToTextTool()

    # Generating a dummy audio file here would require extra dependencies (e.g., ffmpeg),
    # so this test assumes a sample file already exists and skips transcription otherwise.

    # Path to a test audio file (replace with an actual .mp3 or .wav file for testing).
    # You may need to download a short sample audio file and place it at the path below.
    test_audio_file = "./data/downloaded_files/1f975693-876d-457b-a649-393859e79bf3.mp3"  # GAIA example
    # test_audio_file_2 = "./data/downloaded_files/99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3.mp3" # GAIA example

    if tool_instance.is_initialized:
        if os.path.exists(test_audio_file):
            print(f"Attempting to transcribe: {test_audio_file}")
            transcribed_text = tool_instance.forward(test_audio_file)
            print(f"Transcription:\n{transcribed_text}")
        else:
            print(
                f"Test audio file not found: {test_audio_file}. Skipping transcription test."
            )
            print("Please place a sample .mp3 or .wav file at that location for testing.")

        # if os.path.exists(test_audio_file_2):
        #     print(f"\nAttempting to transcribe: {test_audio_file_2}")
        #     transcribed_text_2 = tool_instance.forward(test_audio_file_2)
        #     print(f"Transcription 2:\n{transcribed_text_2}")
        # else:
        #     print(f"Test audio file 2 not found: {test_audio_file_2}. Skipping.")

    else:
        print(
            "SpeechToTextTool could not be initialized (ASR pipeline missing). Transcription test skipped."
        )

    # Test with a non-existent file
    non_existent_file = "./non_existent_audio.mp3"
    print(f"\nAttempting to transcribe non-existent file: {non_existent_file}")
    error_text = tool_instance.forward(non_existent_file)
    print(f"Result for non-existent file:\n{error_text}")
    assert "Error:" in error_text  # Expect an error message