Simon Salmon committed
Commit 0191d40
1 Parent(s): c838cfe

Update app.py

Files changed (1): app.py (+11, -1)
app.py CHANGED
@@ -16,6 +16,7 @@ from transformers import AutoTokenizer, AutoModelForMaskedLM
 artist_name = st.text_input("Model", "roberta-large")
 tokenizer = AutoTokenizer.from_pretrained("roberta-large")
 model = AutoModelForMaskedLM.from_pretrained(artist_name)
+model2 = AutoModelForMaskedLM.from_pretrained("BigSalmon/FormalRobertaa")
 
 
 
@@ -36,4 +37,13 @@ with st.form(key='my_form'):
 mask_hidden_state = last_hidden_state[mask_index]
 idx = torch.topk(mask_hidden_state, k=100, dim=0)[1]
 words = [tokenizer.decode(i.item()).strip() for i in idx]
-st.text_area(label = 'Infill:', value=words)
+a_list.append(words)
+with torch.no_grad():
+    output = model2(token_ids)
+last_hidden_state = output[0].squeeze()
+for mask_index in masked_pos:
+    mask_hidden_state = last_hidden_state[mask_index]
+    idx = torch.topk(mask_hidden_state, k=100, dim=0)[1]
+    words2 = [tokenizer.decode(i.item()).strip() for i in idx]
+    a_list.append(words2)
+st.text_area(label = 'Infill:', value=a_list)
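
The hunks only show the tail of the inference block: token_ids, masked_pos, and a_list are defined earlier in app.py and do not appear in this diff. A minimal sketch of the two-model infill flow the commit ends up with might look like the following; the form contents, the mask-position lookup, and the a_list initialization are assumptions filled in for illustration, not code taken from the repository.

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

artist_name = st.text_input("Model", "roberta-large")
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModelForMaskedLM.from_pretrained(artist_name)
model2 = AutoModelForMaskedLM.from_pretrained("BigSalmon/FormalRobertaa")

with st.form(key='my_form'):
    # Hypothetical prompt widget; the real form contents are not visible in this diff.
    text = st.text_area("Text with <mask> tokens", "The weather is <mask> today.")
    submitted = st.form_submit_button("Submit")
    if submitted:
        token_ids = tokenizer.encode(text, return_tensors="pt")
        # Positions of every <mask> token (assumed setup, defined earlier in the real app.py).
        masked_pos = (token_ids.squeeze() == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].tolist()
        a_list = []  # assumed to be created before the appends shown in the diff

        # Pass 1: the user-selected model.
        with torch.no_grad():
            output = model(token_ids)
        last_hidden_state = output[0].squeeze()
        for mask_index in masked_pos:
            mask_hidden_state = last_hidden_state[mask_index]
            idx = torch.topk(mask_hidden_state, k=100, dim=0)[1]
            words = [tokenizer.decode(i.item()).strip() for i in idx]
            a_list.append(words)

        # Pass 2: the fixed BigSalmon/FormalRobertaa model added by this commit.
        with torch.no_grad():
            output = model2(token_ids)
        last_hidden_state = output[0].squeeze()
        for mask_index in masked_pos:
            mask_hidden_state = last_hidden_state[mask_index]
            idx = torch.topk(mask_hidden_state, k=100, dim=0)[1]
            words2 = [tokenizer.decode(i.item()).strip() for i in idx]
            a_list.append(words2)

        # One top-100 candidate list per (model, mask position) pair.
        st.text_area(label='Infill:', value=str(a_list))

Each pass appends one candidate list per <mask> position, so the output widget shows the user-selected model's suggestions followed by FormalRobertaa's.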