Spaces:
Runtime error
Runtime error
taka-yamakoshi
committed on
Commit
·
c5489ad
1
Parent(s):
1567a0f
test
Browse files
app.py
CHANGED
@@ -43,6 +43,10 @@ def wide_setup():
|
|
43 |
st.markdown(define_margins, unsafe_allow_html=True)
|
44 |
st.markdown(hide_table_row_index, unsafe_allow_html=True)
|
45 |
|
|
|
|
|
|
|
|
|
46 |
@st.cache(show_spinner=True,allow_output_mutation=True)
|
47 |
def load_model():
|
48 |
tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
|
@@ -56,48 +60,49 @@ def clear_data():
|
|
56 |
|
57 |
if __name__=='__main__':
|
58 |
wide_setup()
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
if 'page_status' not in st.session_state:
|
61 |
st.session_state['page_status'] = 'type_in'
|
62 |
|
63 |
if st.session_state['page_status']=='type_in':
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
st.session_state['sent_2'] = sent_2
|
74 |
-
st.experimental_rerun()
|
75 |
|
76 |
if st.session_state['page_status']=='tokenized':
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
st.
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
_ = st.session_state['masked_pos_1'].pop(word_id)
|
100 |
-
st.write(f'Masked words: {", ".join([decoded_sent[word_id] for word_id in np.sort(st.session_state["masked_pos_1"])])}')
|
101 |
|
102 |
|
103 |
if st.session_state['page_status']=='analysis':
|
|
|
43 |
st.markdown(define_margins, unsafe_allow_html=True)
|
44 |
st.markdown(hide_table_row_index, unsafe_allow_html=True)
|
45 |
|
46 |
+
def local_css(file_name):
    """Inject the contents of a CSS file into the Streamlit page.

    Args:
        file_name: Path of the stylesheet to read.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Read with an explicit encoding so the result does not depend on the
    # platform's locale default.
    with open(file_name, encoding='utf-8') as f:
        css = f.read()
    # The stylesheet is wrapped in a <style> tag and emitted as raw HTML.
    st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)
|
49 |
+
|
50 |
@st.cache(show_spinner=True,allow_output_mutation=True)
|
51 |
def load_model():
|
52 |
tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
|
|
|
60 |
|
61 |
if __name__=='__main__':
|
62 |
wide_setup()
|
63 |
+
# Apply the custom stylesheet; the helper defined in this file is local_css.
local_css('style.css')
|
64 |
+
tokenizer,model = load_model()
|
65 |
+
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
66 |
+
|
67 |
+
main_area = st.empty()
|
68 |
|
69 |
if 'page_status' not in st.session_state:
|
70 |
st.session_state['page_status'] = 'type_in'
|
71 |
|
72 |
if st.session_state['page_status']=='type_in':
|
73 |
+
with main_area.container():
|
74 |
+
st.write('1. Type in the sentences and click "Tokenize"')
|
75 |
+
sent_1 = st.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
|
76 |
+
sent_2 = st.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
|
77 |
+
if st.button('Tokenize'):
|
78 |
+
st.session_state['page_status'] = 'tokenized'
|
79 |
+
st.session_state['sent_1'] = sent_1
|
80 |
+
st.session_state['sent_2'] = sent_2
|
81 |
+
main_area.empty()
|
|
|
|
|
82 |
|
83 |
if st.session_state['page_status']=='tokenized':
|
84 |
+
with main_area.container():
|
85 |
+
sent_1 = st.session_state['sent_1']
|
86 |
+
sent_2 = st.session_state['sent_2']
|
87 |
+
if 'masked_pos_1' not in st.session_state:
|
88 |
+
st.session_state['masked_pos_1'] = []
|
89 |
+
if 'masked_pos_2' not in st.session_state:
|
90 |
+
st.session_state['masked_pos_2'] = []
|
91 |
+
|
92 |
+
st.write('2. Select sites to mask out and click "Confirm"')
|
93 |
+
input_sent = tokenizer(sent_1).input_ids
|
94 |
+
decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
|
95 |
+
char_nums = [len(word)+3 for word in decoded_sent]
|
96 |
+
st.write(char_nums)
|
97 |
+
cols = st.columns(char_nums)
|
98 |
+
for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
|
99 |
+
with col:
|
100 |
+
if st.button(word,key=f'word_{word_id}'):
|
101 |
+
if word_id not in st.session_state['masked_pos_1']:
|
102 |
+
st.session_state['masked_pos_1'].append(word_id)
|
103 |
+
else:
|
104 |
+
st.session_state['masked_pos_1'].remove(word_id)
|
105 |
+
st.write(f'Masked words: {", ".join([decoded_sent[word_id] for word_id in np.sort(st.session_state["masked_pos_1"])])}')
|
|
|
|
|
106 |
|
107 |
|
108 |
if st.session_state['page_status']=='analysis':
|
style.css
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
/* Use a small font for <button> elements that are direct children of .stButton. */
.stButton > button {
    font-size: 8px;
}
|