mskov committed · Commit 306f4a4 · 1 parent: a564048

Update app.py

Files changed (1):
  1. app.py +34 -0
app.py CHANGED
@@ -9,6 +9,40 @@ from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer
 # from next_word_prediction import GPT2

+### code
+gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+input_ids = tokenizer("Today is a nice day", return_tensors="pt").input_ids
+
+generated_outputs = gpt2.generate(input_ids, do_sample=True, num_return_sequences=3, output_scores=True)
+
+# only use id's that were generated
+# gen_sequences has shape [3, 15]
+gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]
+
+# let's stack the logits generated at each step to a tensor and transform
+# logits to probs
+probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1)  # -> shape [3, 15, vocab_size]
+
+# now we need to collect the probability of the generated token
+# we need to add a dummy dim in the end to make gather work
+gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
+
+# now we can do all kinds of things with the probs
+
+# 1) the probs that exactly those sequences are generated again
+# those are normally going to be very small
+unique_prob_per_sequence = gen_probs.prod(-1)
+
+# 2) normalize the probs over the three sequences
+normed_gen_probs = gen_probs / gen_probs.sum(0)
+assert normed_gen_probs[:, 0].sum() == 1.0, "probs should be normalized"
+
+# 3) compare normalized probs to each other like in 1)
+unique_normed_prob_per_sequence = normed_gen_probs.prod(-1)
+
+### end code
 from share_btn import community_icon_html, loading_icon_html, share_js

 # get gpt2 model
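The block added in this commit samples three continuations of "Today is a nice day" with GPT-2 and recovers the probability of each generated token. As a minimal sketch, not part of the commit, assuming torch is imported earlier in app.py and reusing the tokenizer, gen_sequences, and gen_probs defined in the added block, the per-token probabilities could be inspected like this:

# Hypothetical inspection snippet (not in the commit): decode each sampled
# continuation and print every generated token next to its probability.
for seq, seq_probs in zip(gen_sequences, gen_probs):
    print("continuation:", tokenizer.decode(seq, skip_special_tokens=True))
    for token_id, p in zip(seq.tolist(), seq_probs.tolist()):
        print(f"  {tokenizer.decode([token_id])!r}: {p:.4f}")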