vumichien committed
Commit 532a2ea · 1 Parent(s): aba3792

Create utils.py

Files changed (1): utils.py +130 -0
utils.py ADDED
@@ -0,0 +1,130 @@
+ import subprocess
+ import numpy as np
+ import requests
+ import json
+ from typing import Dict, List
+ import random
+ import torch
+ from joblib import Parallel, delayed
+ import os
+
+
+ def random_runner(target_prob, size):
+     # Sample one candidate index per frame, then score the sampled path by
+     # the sum of the corresponding logits.
+     indice = random.choices(range(0, size[1]), k=size[0])
+     value = target_prob[range(len(indice)), indice].sum().detach().numpy().item()
+     return indice, value
+
+
+ def query(data, model_id, api_token) -> Dict:
+     """
+     Helper function to transcribe raw audio bytes with the Hugging Face
+     Inference API.
+     """
+     headers = {"Authorization": f"Bearer {api_token}"}
+     api_url = f"https://api-inference.huggingface.co/models/{model_id}"
+     response = requests.request("POST", api_url, headers=headers, data=data)
+     return json.loads(response.content.decode("utf-8"))
+
+
+ def query_process(filename, model_id, api_token) -> Dict:
+     """
+     Helper function to read an audio file from disk and transcribe it with
+     the Hugging Face Inference API.
+     """
+     headers = {"Authorization": f"Bearer {api_token}"}
+     api_url = f"https://api-inference.huggingface.co/models/{model_id}"
+     with open(filename, "rb") as f:
+         data = f.read()
+     response = requests.request("POST", api_url, headers=headers, data=data)
+     return json.loads(response.content.decode("utf-8"))
+
+ def query_dummy(raw_data, processor, model):
+     # Greedy (argmax) CTC decoding with a local model instead of the API.
+     inputs = processor(raw_data, sampling_rate=16000, return_tensors="pt")
+     with torch.no_grad():
+         logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
+     predicted_ids = torch.argmax(logits, dim=-1)
+     transcription = processor.batch_decode(predicted_ids)
+     return transcription[0]
+
+ def query_raw(raw_data, word, processor, processor_with_lm, model, temperature=15) -> List:
+     """
+     Helper function to transcribe raw audio locally. If the LM-boosted top-1
+     transcription does not match `word`, randomly sample alternative decodings
+     from the per-frame top-k logits and return the best-scoring candidates.
+     """
+     input_values = processor(raw_data, sampling_rate=16000, return_tensors="pt").input_values
+     with torch.no_grad():
+         logits = model(input_values).logits
+     predicted_ids = torch.argmax(logits, dim=-1)
+     top1_prediction = processor_with_lm.decode(logits[0].cpu().numpy())['text']
+     if word != top1_prediction.replace(" ", ""):
+         pad_token_id = processor.tokenizer.pad_token_id
+         word_delimiter_token_id = processor.tokenizer.word_delimiter_token_id
+         # Keep the top-3 candidates per frame, dropping pad and word-delimiter frames.
+         value_top5, ind_top5 = torch.topk(logits, 3)
+         target_index = ind_top5[(predicted_ids != word_delimiter_token_id) & (predicted_ids != pad_token_id)]
+         target_prob = value_top5[(predicted_ids != word_delimiter_token_id) & (predicted_ids != pad_token_id)]
+         size = target_index.size()
+         trial = size[1]**4 // 2
+         prediction_list = Parallel(n_jobs=1, backend="multiprocessing")(
+             delayed(random_runner)(target_prob, size) for _ in range(trial)
+         )
+         # Map path score -> sampled indices, then keep the `temperature` best paths.
+         target_dict = {i[1]: i[0] for i in prediction_list}
+         target_dict = sorted(target_dict.items(), reverse=True)
+         results = {}
+         for top_pred in target_dict[:temperature]:
+             indices = top_pred[1]
+             output_sentence = processor.decode(target_index[range(size[0]), indices]).lower()
+             results[output_sentence] = top_pred[0]
+         results = sorted(results.items(), key=lambda x: x[1], reverse=True)
+         return results
+     else:
+         return [(word, 100)]
+
+
+ def find_different(target, prediction):
+     # Compare the two strings character by character (up to the shorter length)
+     # and return a bitstring: "1" where the characters differ, "0" where they match.
+     if len(target) != len(prediction):
+         target = target[:len(prediction)]
+     wrong_words = [str(1) if target[index] != prediction[index] else str(0) for index in range(len(target))]
+     return "".join(wrong_words)
+
+
+ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.ndarray:
+     """
+     Helper function to read an audio file through ffmpeg, returning mono
+     float32 samples at the requested sampling rate.
+     """
+     ar = f"{sampling_rate}"
+     ac = "1"
+     format_for_conversion = "f32le"
+     ffmpeg_command = [
+         "ffmpeg",
+         "-i",
+         "pipe:0",
+         "-ac",
+         ac,
+         "-ar",
+         ar,
+         "-f",
+         format_for_conversion,
+         "-hide_banner",
+         "-loglevel",
+         "quiet",
+         "pipe:1",
+     ]
+
+     try:
+         ffmpeg_process = subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
+     except FileNotFoundError:
+         raise ValueError("ffmpeg was not found but is required to load audio files from filename")
+     output_stream = ffmpeg_process.communicate(bpayload)
+     out_bytes = output_stream[0]
+     audio = np.frombuffer(out_bytes, np.float32)
+     # if audio.shape[0] == 0:
+     #     raise ValueError("Malformed soundfile")
+     return audio
+
+
+ def get_model_size(model):
+     # Serialize the state dict to a temporary file and report its size in MB.
+     torch.save(model.state_dict(), 'temp_saved_model.pt')
+     model_size_in_mb = os.path.getsize('temp_saved_model.pt') >> 20
+     os.remove('temp_saved_model.pt')
+     return model_size_in_mb
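
For context, a minimal sketch of how these helpers might be wired together. The checkpoint name, token, and file path below are illustrative assumptions, not part of this commit:

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

MODEL_ID = "facebook/wav2vec2-base-960h"  # assumed checkpoint; any CTC model works
API_TOKEN = "hf_..."                      # your Hugging Face API token (placeholder)

processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

# Local path: decode an uploaded file to 16 kHz mono float32, then transcribe.
with open("sample.wav", "rb") as f:
    audio = ffmpeg_read(f.read(), sampling_rate=16000)
print(query_dummy(audio, processor, model))

# Hosted path: send the same file to the Inference API instead.
print(query_process("sample.wav", MODEL_ID, API_TOKEN))

Note that query_raw additionally expects a Wav2Vec2ProcessorWithLM (a checkpoint packaged with a language model) for its top-1 decode, and returns a list of (sentence, score) candidates when the greedy transcription does not match the target word.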