taka-yamakoshi committed
Commit c5489ad · 1 Parent(s): 1567a0f
Files changed (2):
  1. app.py +40 -35
  2. style.css +3 -0
app.py CHANGED
@@ -43,6 +43,10 @@ def wide_setup():
     st.markdown(define_margins, unsafe_allow_html=True)
     st.markdown(hide_table_row_index, unsafe_allow_html=True)
 
+def local_css(file_name):
+    with open(file_name) as f:
+        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
+
 @st.cache(show_spinner=True,allow_output_mutation=True)
 def load_model():
     tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
@@ -56,48 +60,49 @@ def clear_data():
 
 if __name__=='__main__':
     wide_setup()
+    local_css('style.css')
+    tokenizer,model = load_model()
+    mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
+
+    main_area = st.empty()
 
     if 'page_status' not in st.session_state:
         st.session_state['page_status'] = 'type_in'
 
     if st.session_state['page_status']=='type_in':
-        tokenizer,model = load_model()
-        mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
-
-        st.write('1. Type in the sentences and click "Tokenize"')
-        sent_1 = st.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
-        sent_2 = st.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
-        if st.button('Tokenize'):
-            st.session_state['page_status'] = 'tokenized'
-            st.session_state['sent_1'] = sent_1
-            st.session_state['sent_2'] = sent_2
-            st.experimental_rerun()
+        with main_area.container():
+            st.write('1. Type in the sentences and click "Tokenize"')
+            sent_1 = st.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
+            sent_2 = st.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
+            if st.button('Tokenize'):
+                st.session_state['page_status'] = 'tokenized'
+                st.session_state['sent_1'] = sent_1
+                st.session_state['sent_2'] = sent_2
+                main_area.empty()
 
     if st.session_state['page_status']=='tokenized':
-        tokenizer,model = load_model()
-        mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
-
-        sent_1 = st.session_state['sent_1']
-        sent_2 = st.session_state['sent_2']
-        if 'masked_pos_1' not in st.session_state:
-            st.session_state['masked_pos_1'] = []
-        if 'masked_pos_2' not in st.session_state:
-            st.session_state['masked_pos_2'] = []
-
-        st.write('2. Select sites to mask out and click "Confirm"')
-        input_sent = tokenizer(sent_1).input_ids
-        decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
-        char_nums = [len(word)+3 for word in decoded_sent]
-        st.write(char_nums)
-        cols = st.columns(char_nums)
-        for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
-            with col:
-                if st.button(word,key=f'word_{word_id}'):
-                    if word_id not in st.session_state['masked_pos_1']:
-                        st.session_state['masked_pos_1'].append(word_id)
-                    else:
-                        _ = st.session_state['masked_pos_1'].pop(word_id)
-        st.write(f'Masked words: {", ".join([decoded_sent[word_id] for word_id in np.sort(st.session_state["masked_pos_1"])])}')
+        with main_area.container():
+            sent_1 = st.session_state['sent_1']
+            sent_2 = st.session_state['sent_2']
+            if 'masked_pos_1' not in st.session_state:
+                st.session_state['masked_pos_1'] = []
+            if 'masked_pos_2' not in st.session_state:
+                st.session_state['masked_pos_2'] = []
+
+            st.write('2. Select sites to mask out and click "Confirm"')
+            input_sent = tokenizer(sent_1).input_ids
+            decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
+            char_nums = [len(word)+3 for word in decoded_sent]
+            st.write(char_nums)
+            cols = st.columns(char_nums)
+            for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
+                with col:
+                    if st.button(word,key=f'word_{word_id}'):
+                        if word_id not in st.session_state['masked_pos_1']:
+                            st.session_state['masked_pos_1'].append(word_id)
+                        else:
+                            st.session_state['masked_pos_1'].remove(word_id)
+            st.write(f'Masked words: {", ".join([decoded_sent[word_id] for word_id in np.sort(st.session_state["masked_pos_1"])])}')
 
 
     if st.session_state['page_status']=='analysis':
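The main change above is dropping the per-page `load_model()` calls and `st.experimental_rerun()` in favour of a single `st.empty()` placeholder that each page draws into and clears. A minimal, self-contained sketch of that pattern follows; the page names, widgets, and sentence text are illustrative stand-ins, not the app's real flow.

```python
# Sketch of the placeholder pattern adopted in this commit: one st.empty() slot is
# filled via .container() by whichever "page" is active, and wiped with .empty()
# before the next page renders within the same script run.
import streamlit as st

main_area = st.empty()  # single slot shared by all pages

if 'page_status' not in st.session_state:
    st.session_state['page_status'] = 'type_in'

if st.session_state['page_status'] == 'type_in':
    with main_area.container():
        sent = st.text_input('Sentence', value='The cat sat on the mat.')  # placeholder text
        if st.button('Tokenize'):
            st.session_state['sent'] = sent
            st.session_state['page_status'] = 'tokenized'
            main_area.empty()  # clear this page so the next block can redraw the slot

if st.session_state['page_status'] == 'tokenized':
    with main_area.container():
        st.write('Tokens:', st.session_state['sent'].split())
```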
style.css ADDED
@@ -0,0 +1,3 @@
+.stButton>button {
+    font-size: 8px;
+}
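The stylesheet only takes effect because app.py now reads it and injects it into the page. A minimal sketch of that wiring, assuming nothing beyond the `local_css` helper and the `.stButton>button` rule added in this commit:

```python
# Sketch of how style.css reaches the rendered page: the file is read once and
# injected as an inline <style> block, so the .stButton>button rule (font-size: 8px)
# applies to every token button drawn afterwards.
import streamlit as st

def local_css(file_name):
    with open(file_name) as f:
        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)

local_css('style.css')   # inject the 3-line stylesheet added in this commit
st.button('example')     # any st.button label is now rendered at 8px
```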