File size: 2,084 Bytes
6a91da6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torchaudio
from pathlib import Path
import soundfile as sf
from typing import Any


from config import TARGET_SR, SUPPORTED_EXTS


def transcribe_file(path: str | Path, pipe: Any) -> str:
    """
    Transcribe an audio file to text using a given ASR pipeline.

    Args:
        path: Path or string pointing to an audio file.
        pipe: Callable ASR pipeline (e.g. a Hugging Face transformers
              automatic-speech-recognition pipeline). It is called with the
              audio as a numpy array and must return a mapping with key
              'text'.

    Returns:
        The value stored under 'text' in the pipeline's result.

    Raises:
        ValueError: Propagated from load_resample when the file type is
                    unsupported or the audio cannot be decoded.
    """
    # load_resample is called with its default target rate, so the pipeline
    # receives audio at that rate.
    # NOTE(review): only the raw array is passed (no sample-rate metadata);
    # presumably the pipeline's model expects that same rate - confirm.
    waveform = load_resample(path)
    result = pipe(waveform.numpy())
    return result["text"]  # type: ignore[index]


def load_resample(path: str | Path, target_sr: int = TARGET_SR) -> torch.Tensor:
    """
    Load an audio file as a mono float32 tensor at the target sample rate.

    Args:
        path: Path or string pointing to an audio file.
        target_sr: Desired sample rate (in Hz). Defaults to TARGET_SR from config.

    Returns:
        A torch.Tensor of dtype float32 sampled at target_sr. Two-dimensional
        (multi-channel) input is averaged over its second axis down to 1-D.

    Raises:
        ValueError: If the file extension is not in SUPPORTED_EXTS.
        ValueError: If soundfile raises RuntimeError while reading the file
            (the original exception is chained as the cause).
    """
    # Reject unsupported extensions up front so the caller gets a clear,
    # user-facing message instead of a low-level decoder error.
    ext = Path(path).suffix.lower()
    if ext not in SUPPORTED_EXTS:
        raise ValueError(
            f"Unsupported file-type “{ext or 'unknown'}”. Please upload WAV, FLAC, MP3, OGG/Opus or M4A."
        )

    # Keep the try body to the single call that can fail while decoding.
    try:
        samples, sr = sf.read(str(path))
    except RuntimeError as exc:
        raise ValueError(
            "Couldn't decode the audio file - maybe it's corrupted or in an uncommon codec."
        ) from exc

    # Convert to float32 in a single step rather than copying and then casting.
    speech = torch.as_tensor(samples, dtype=torch.float32)
    if speech.ndim == 2:  # stereo to mono: average across the channel axis
        speech = speech.mean(dim=1)
    if sr != target_sr:
        speech = torchaudio.functional.resample(
            speech, orig_freq=sr, new_freq=target_sr
        )
    return speech