import json
import os
import random
import subprocess
from typing import Dict, List

import numpy as np
import requests
import torch
from joblib import Parallel, delayed


def random_runner(target_prob, size):
    """Sample one of the top-k tokens per frame and return the sampled indices with their summed score."""
    indices = random.choices(range(size[1]), k=size[0])
    value = target_prob[range(len(indices)), indices].sum().item()
    return indices, value


def query(data, model_id, api_token) -> Dict:
    """Send raw audio bytes to the Hugging Face Inference API and return the decoded JSON response."""
    headers = {"Authorization": f"Bearer {api_token}"}
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.request("POST", api_url, headers=headers, data=data)
    return json.loads(response.content.decode("utf-8"))


def query_process(filename, model_id, api_token) -> Dict:
    """Read an audio file from disk and send it to the Hugging Face Inference API."""
    headers = {"Authorization": f"Bearer {api_token}"}
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.request("POST", api_url, headers=headers, data=data)
    return json.loads(response.content.decode("utf-8"))


def query_dummy(raw_data, processor, model):
    """Greedy (argmax) CTC transcription of a raw 16 kHz waveform with a local model.

    The processor is expected to return an attention mask for the input features.
    """
    inputs = processor(raw_data, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    return transcription[0]


def query_raw(raw_data, word, processor, processor_with_lm, model, temperature=15) -> List:
    """Run local CTC inference on a raw 16 kHz waveform and return candidate transcriptions.

    If the LM-rescored top-1 transcription already matches `word` (ignoring spaces),
    return it with a score of 100. Otherwise, randomly sample alternative token
    sequences from the per-frame top-k logits and return the `temperature`
    highest-scoring decodings, sorted by score in descending order.
    """
    input_values = processor(raw_data, sampling_rate=16000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    top1_prediction = processor_with_lm.decode(logits[0].cpu().numpy())['text']
    if word != top1_prediction.replace(" ", ""):
        pad_token_id = processor.tokenizer.pad_token_id
        word_delimiter_token_id = processor.tokenizer.word_delimiter_token_id
        # Keep the top-3 logits per frame, restricted to frames whose argmax is
        # neither padding nor the word delimiter.
        value_topk, ind_topk = torch.topk(logits, 3)
        keep = (predicted_ids != word_delimiter_token_id) & (predicted_ids != pad_token_id)
        target_index = ind_topk[keep]
        target_prob = value_topk[keep]
        size = target_index.size()
        trial = size[1] ** 4 // 2  # number of random samples to draw
        prediction_list = Parallel(n_jobs=1, backend="multiprocessing")(
            delayed(random_runner)(target_prob, size) for _ in range(trial)
        )
        # Map each score to its sampled indices (duplicates collapse), then sort by score.
        target_dict = {score: indices for indices, score in prediction_list}
        target_dict = sorted(target_dict.items(), reverse=True)
        results = {}
        for score, indices in target_dict[:temperature]:
            output_sentence = processor.decode(target_index[range(size[0]), indices]).lower()
            results[output_sentence] = score
        results = sorted(results.items(), key=lambda x: x[1], reverse=True)
        return results
    else:
        return [(word, 100)]


def find_different(target, prediction):
    """Return a 0/1 string marking the character positions where `target` and `prediction` differ."""
    # target_word = set(target)
    # prediction_word = set(prediction)
    # difference = target_word.symmetric_difference(prediction_word)
    # wrong_words = [word for word in target_word if word in list(difference)]
    if len(target) != len(prediction):
        target = target[:len(prediction)]
    wrong_words = [str(1) if target[index] != prediction[index] else str(0) for index in range(len(target))]
    return "".join(wrong_words)
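
# --- Illustrative usage sketch (not part of the original helpers) ---
# A minimal example of how the local-inference helpers above might be combined:
# greedy-decode a waveform, look up the sampled candidates from query_raw, and
# mark the character-level differences against the expected word. The function
# name and the expectation that `processor`, `processor_with_lm`, and `model`
# are Wav2Vec2-style objects are assumptions for illustration only.
def example_candidate_lookup(raw_audio, word, processor, processor_with_lm, model):
    transcription = query_dummy(raw_audio, processor, model)
    candidates = query_raw(raw_audio, word, processor, processor_with_lm, model)
    # 0/1 mask of character positions where the greedy transcription differs from `word`.
    diff_mask = find_different(word, transcription.replace(" ", "").lower())
    return transcription, candidates, diff_mask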
""" ar = f"{sampling_rate}" ac = "1" format_for_conversion = "f32le" ffmpeg_command = [ "ffmpeg", "-i", "pipe:0", "-ac", ac, "-ar", ar, "-f", format_for_conversion, "-hide_banner", "-loglevel", "quiet", "pipe:1", ] try: ffmpeg_process = subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) except FileNotFoundError: raise ValueError("ffmpeg was not found but is required to load audio files from filename") output_stream = ffmpeg_process.communicate(bpayload) out_bytes = output_stream[0] audio = np.frombuffer(out_bytes, np.float32) # if audio.shape[0] == 0: # raise ValueError("Malformed soundfile") return audio def get_model_size(model): torch.save(model.state_dict(), 'temp_saved_model.pt') model_size_in_mb = os.path.getsize('temp_saved_model.pt') >> 20 os.remove('temp_saved_model.pt') return model_size_in_mb