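# Source page metadata (file-viewer residue, not part of the module):
#   author: ricardo-lsantos
#   commit: 41479fc "Cleanup commented code"
#   viewer links: raw / history / blame / contribute / delete
#   size: 646 Bytes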
# import torch_directml
from transformers import pipeline
# Default value for get_pipe's `model_checkpoint` argument (passed to
# transformers.pipeline as `model=`).
MODEL_CHECKPOINT = "openai/whisper-small"
# Default value for get_pipe's `chunk_length_s` argument
# (presumably seconds of audio per chunk -- confirm against the pipeline docs).
CHUNK_LENGTH_S = 30
def get_device() -> str:
    """Return the device identifier to run the pipeline on.

    Currently always returns the string ``"cpu"``.

    NOTE: a DirectML variant previously lived here as commented-out code
    (``torch_directml.device()``, which needs ``import torch_directml``).
    It is not enabled.
    """
    return "cpu"
def get_pipe(device, model_checkpoint=MODEL_CHECKPOINT, chunk_length_s=CHUNK_LENGTH_S):
    """Build an automatic-speech-recognition pipeline.

    Args:
        device: Device specifier forwarded unchanged to
            ``transformers.pipeline`` (e.g. the value from ``get_device()``).
        model_checkpoint: Model to load; defaults to ``MODEL_CHECKPOINT``.
        chunk_length_s: Forwarded as the pipeline's ``chunk_length_s``;
            defaults to ``CHUNK_LENGTH_S``.

    Returns:
        The object returned by ``transformers.pipeline`` for the
        ``"automatic-speech-recognition"`` task.
    """
    return pipeline(
        "automatic-speech-recognition",
        model=model_checkpoint,
        chunk_length_s=chunk_length_s,
        device=device,
    )
def get_prediction_with_timelines(pipe, sample, batch_size=8):
    """Run ``pipe`` on ``sample`` with timestamps and return its chunks.

    Args:
        pipe: Callable pipeline (e.g. the result of ``get_pipe``). It is
            called as ``pipe(sample, batch_size=..., return_timestamps=True)``
            and must return a mapping containing a ``"chunks"`` key.
        sample: Input passed to ``pipe`` unchanged.
        batch_size: Batch size forwarded to ``pipe``. Defaults to 8, the
            value that was previously hard-coded.

    Returns:
        The value stored under ``"chunks"`` in the pipeline output.
    """
    output = pipe(sample, batch_size=batch_size, return_timestamps=True)
    return output["chunks"]
def get_prediction(pipe, sample, batch_size=8):
    """Run ``pipe`` on ``sample`` and return the recognized text.

    Args:
        pipe: Callable pipeline (e.g. the result of ``get_pipe``). It is
            called as ``pipe(sample, batch_size=...)`` and must return a
            mapping containing a ``"text"`` key.
        sample: Input passed to ``pipe`` unchanged.
        batch_size: Batch size forwarded to ``pipe``. Defaults to 8, the
            value that was previously hard-coded.

    Returns:
        The value stored under ``"text"`` in the pipeline output.
    """
    output = pipe(sample, batch_size=batch_size)
    return output["text"]