File size: 646 Bytes
82a7fca
7b2c089
 
 
 
 
 
82a7fca
 
7b2c089
 
 
 
 
 
 
 
 
 
 
 
 
41479fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# import torch_directml
from transformers import pipeline

# Default Hugging Face model id used by get_pipe().
MODEL_CHECKPOINT: str = "openai/whisper-small"
# Default value forwarded to the pipeline's ``chunk_length_s`` option.
CHUNK_LENGTH_S: int = 30

def get_device():
    """Return the device identifier handed to the ASR pipeline.

    Always returns the string ``"cpu"``. The module keeps a commented-out
    ``torch_directml`` import; a DirectML device (``torch_directml.device()``)
    could be returned here instead once that import is re-enabled.
    """
    selected_device = "cpu"
    return selected_device

def get_pipe(device, model_checkpoint=MODEL_CHECKPOINT, chunk_length_s=CHUNK_LENGTH_S):
    """Build a Hugging Face automatic-speech-recognition pipeline.

    Args:
        device: Passed unchanged to ``transformers.pipeline`` as ``device``
            (for example the value returned by ``get_device()``).
        model_checkpoint: Model id or path to load; defaults to
            ``MODEL_CHECKPOINT``.
        chunk_length_s: Forwarded as the pipeline's ``chunk_length_s``
            option (seconds, per the transformers documentation); defaults
            to ``CHUNK_LENGTH_S``.

    Returns:
        Whatever ``transformers.pipeline`` returns for the
        ``"automatic-speech-recognition"`` task.
    """
    asr_options = {
        "model": model_checkpoint,
        "chunk_length_s": chunk_length_s,
        "device": device,
    }
    return pipeline("automatic-speech-recognition", **asr_options)

def get_prediction_with_timelines(pipe, sample, batch_size=8):
    """Run *pipe* on *sample* with timestamps enabled and return the chunks.

    Args:
        pipe: A callable ASR pipeline (e.g. from ``get_pipe``). It is called
            as ``pipe(sample, batch_size=batch_size, return_timestamps=True)``
            and must return a mapping containing a ``"chunks"`` key.
        sample: Audio input accepted by *pipe*; its accepted forms (file
            path, raw array, ...) depend on the pipeline — not checked here.
        batch_size: Forwarded to the pipeline's ``batch_size`` option.
            Defaults to 8, matching the previous hard-coded value.

    Returns:
        The ``"chunks"`` entry of the pipeline output. Per the transformers
        ASR pipeline docs this is a list of dicts with ``"text"`` and
        ``"timestamp"`` keys.

    Raises:
        KeyError: If the pipeline output has no ``"chunks"`` entry.
    """
    output = pipe(sample, batch_size=batch_size, return_timestamps=True)
    return output["chunks"]

def get_prediction(pipe, sample, batch_size=8):
    """Run *pipe* on *sample* and return only the transcribed text.

    Args:
        pipe: A callable ASR pipeline (e.g. from ``get_pipe``). It is called
            as ``pipe(sample, batch_size=batch_size)`` and must return a
            mapping containing a ``"text"`` key.
        sample: Audio input accepted by *pipe*; its accepted forms depend on
            the pipeline — not checked here.
        batch_size: Forwarded to the pipeline's ``batch_size`` option.
            Defaults to 8, matching the previous hard-coded value.

    Returns:
        The ``"text"`` entry of the pipeline output.

    Raises:
        KeyError: If the pipeline output has no ``"text"`` entry.
    """
    output = pipe(sample, batch_size=batch_size)
    return output["text"]