tomaszki commited on
Commit
511e949
·
1 Parent(s): 7bdd4ad

Added conversion to float32 after computation

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -29,7 +29,7 @@ def get_attention_weights_and_tokens(text):
29
  tokens = [tokenizer.decode(token) for token in tokenized.input_ids[0]]
30
  tokenized = tokenized.to(device)
31
  output = model(**tokenized, output_attentions=True)
32
- return output.attentions, tokens
33
 
34
  model = load_model()
35
  tokenizer = load_tokenizer()
 
29
  tokens = [tokenizer.decode(token) for token in tokenized.input_ids[0]]
30
  tokenized = tokenized.to(device)
31
  output = model(**tokenized, output_attentions=True)
32
+ return output.attentions.to(torch.float32), tokens
33
 
34
  model = load_model()
35
  tokenizer = load_tokenizer()