Spaces: mskov / Runtime error
mskov committed · Commit ec1d635 · 1 Parent(s): 306f4a4

Update app.py

Files changed (1): app.py (+35, -32)
app.py CHANGED
@@ -9,40 +9,12 @@ from transformers import AutoModelForCausalLM
  from transformers import AutoTokenizer
  # from next_word_prediction import GPT2
 
- ### code
+ ### code snippet
  gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
- input_ids = tokenizer("Today is a nice day", return_tensors="pt").input_ids
-
- generated_outputs = gpt2.generate(input_ids, do_sample=True, num_return_sequences=3, output_scores=True)
-
- # only use id's that were generated
- # gen_sequences has shape [3, 15]
- gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]
-
- # let's stack the logits generated at each step to a tensor and transform
- # logits to probs
- probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1) # -> shape [3, 15, vocab_size]
-
- # now we need to collect the probability of the generated token
- # we need to add a dummy dim in the end to make gather work
- gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
-
- # now we can do all kinds of things with the probs
-
- # 1) the probs that exactly those sequences are generated again
- # those are normally going to be very small
- unique_prob_per_sequence = gen_probs.prod(-1)
-
- # 2) normalize the probs over the three sequences
- normed_gen_probs = gen_probs / gen_probs.sum(0)
- assert normed_gen_probs[:, 0].sum() == 1.0, "probs should be normalized"
-
- # 3) compare normalized probs to each other like in 1)
- unique_normed_prob_per_sequence = normed_gen_probs.prod(-1)
-
- ### end code
+ ### /code snippet
  from share_btn import community_icon_html, loading_icon_html, share_js
 
  # get gpt2 model
@@ -66,14 +38,45 @@ def inference(audio):
  _, probs = model.detect_language(mel)
 
  # decode audio data
- options = whisper.DecodingOptions(fp16 = False)
+ options = whisper.DecodingOptions(fp16 = True)
  # transcribe speech to text
  result = whisper.decode(model, mel, options)
 
+ ### code
+ input_ids = tokenizer(result, return_tensors="pt").input_ids
+
+ generated_outputs = gpt2.generate(input_ids, do_sample=True, num_return_sequences=3, output_scores=True)
+
+ # only use id's that were generated
+ # gen_sequences has shape [3, 15]
+ gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]
+
+ # let's stack the logits generated at each step to a tensor and transform
+ # logits to probs
+ probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1) # -> shape [3, 15, vocab_size]
+
+ # now we need to collect the probability of the generated token
+ # we need to add a dummy dim in the end to make gather work
+ gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
+
+ # now we can do all kinds of things with the probs
+
+ # 1) the probs that exactly those sequences are generated again
+ # those are normally going to be very small
+ # unique_prob_per_sequence = gen_probs.prod(-1)
+
+ # 2) normalize the probs over the three sequences
+ # normed_gen_probs = gen_probs / gen_probs.sum(0)
+ # assert normed_gen_probs[:, 0].sum() == 1.0, "probs should be normalized"
+
+ # 3) compare normalized probs to each other like in 1)
+ # unique_normed_prob_per_sequence = normed_gen_probs.prod(-1)
+
+ ### end code
  # print audio data as text
  # print(result.text)
  getText = generator(result.text, max_length=3, num_return_sequences=5)
- pprint(getText)
+ pprint(getText, gen_probs)
  return getText, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
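
Notes on the added block: whisper.decode returns a DecodingResult object, so tokenizer(result, return_tensors="pt") should tokenize result.text instead (the later generator(result.text, ...) call already does this); pprint(getText, gen_probs) passes gen_probs as pprint's stream argument rather than printing it; and torch is used without an import appearing in either hunk. Below is a minimal standalone sketch of the intended flow with those points addressed. The Whisper checkpoint name and the transcribe_and_score wrapper are illustrative assumptions, not part of the commit.

# Sketch: transcribe audio with Whisper, then score sampled GPT-2
# continuations of the transcript. Assumptions: "base" checkpoint,
# transcribe_and_score name, and audio_path input are illustrative.
import torch
import whisper
from pprint import pprint
from transformers import AutoModelForCausalLM, AutoTokenizer

gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = whisper.load_model("base")

def transcribe_and_score(audio_path):
    # load audio and build the log-mel spectrogram, as in inference()
    audio = whisper.pad_or_trim(whisper.load_audio(audio_path))
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # decode audio data; fp16 decoding assumes GPU execution
    options = whisper.DecodingOptions(fp16=torch.cuda.is_available())
    result = whisper.decode(model, mel, options)

    # tokenize the transcript string, not the DecodingResult object
    input_ids = tokenizer(result.text, return_tensors="pt").input_ids

    generated_outputs = gpt2.generate(
        input_ids, do_sample=True, num_return_sequences=3, output_scores=True
    )
    # keep only the ids that were generated, not the prompt
    gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]
    # scores is a tuple of per-step logits; stack to [3, steps, vocab_size] probs
    probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1)
    # gather the probability of each sampled token (dummy trailing dim for gather)
    gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)

    pprint(result.text)
    pprint(gen_probs.tolist())  # pprint takes one object; a second positional arg is a stream
    return result.text, gen_probs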
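
For intuition about the gather step: at each generation step it picks out the probability that the model assigned to the token that was actually sampled. A toy check with one sequence, two steps, and a four-token vocabulary (values made up for illustration):

import torch

# probs: [1 sequence, 2 steps, vocab of 4], one softmax distribution per step
probs = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                       [0.7, 0.1, 0.1, 0.1]]])
gen_sequences = torch.tensor([[3, 0]])  # token ids sampled at each step

gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
print(gen_probs)           # tensor([[0.4000, 0.7000]])
print(gen_probs.prod(-1))  # tensor([0.2800]): prob of sampling exactly this continuation

That final product is what unique_prob_per_sequence computed before this commit; as the original comment notes, it is normally very small. If the normalization check is ever re-enabled, the exact comparison normed_gen_probs[:, 0].sum() == 1.0 can fail from floating-point rounding; torch.allclose against 1.0 is the safer assertion.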