Spaces:
Runtime error
Runtime error
File size: 4,615 Bytes
2b49fe2 efeee8a 2b49fe2 efeee8a 2b49fe2 d82123b dd4548e 2b49fe2 e87e116 efeee8a e87e116 efeee8a c6dd7aa efeee8a c6dd7aa 779710f c6dd7aa 779710f c6dd7aa 10ced5b c6dd7aa 10ced5b 1567a0f 10ced5b c6dd7aa 10ced5b c6dd7aa 10ced5b 2b404c9 1567a0f 10ced5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import numpy as np
import pandas as pd
import time
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer, AlbertForMaskedLM
#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
def wide_setup():
    """Inject page-level CSS into the Streamlit app.

    Two ``<style>`` snippets are written with
    ``st.markdown(..., unsafe_allow_html=True)``, in this order:

    1. The main block container (``.appview-container .main .block-container``)
       gets ``max-width: 1500px`` and padding of 0rem (top), 2rem (right),
       2rem (left) and 0rem (bottom).
    2. ``tbody th`` cells and ``.blank`` elements are hidden. Per the original
       variable name, the intent is to hide the row index of rendered tables.
    """
    width_px = 1500
    pad_top, pad_right, pad_bottom, pad_left = 0, 2, 0, 2

    # Literal braces are plain characters in the non-f-string parts below.
    margin_css = (
        "\n<style>\n"
        ".appview-container .main .block-container{\n"
        f"max-width: {width_px}px;\n"
        f"padding-top: {pad_top}rem;\n"
        f"padding-right: {pad_right}rem;\n"
        f"padding-left: {pad_left}rem;\n"
        f"padding-bottom: {pad_bottom}rem;\n"
        "}\n"
        "</style>\n"
    )
    row_index_css = (
        "\n<style>\n"
        "tbody th {display:none}\n"
        ".blank {display:none}\n"
        "</style>\n"
    )

    for css in (margin_css, row_index_css):
        st.markdown(css, unsafe_allow_html=True)
@st.cache(show_spinner=True, allow_output_mutation=True)
def load_model():
    """Load the ``albert-xxlarge-v2`` tokenizer and masked-LM model.

    The result is memoised by ``st.cache``, so reruns of the script reuse the
    same objects. ``allow_output_mutation=True`` tells Streamlit not to check
    the returned objects for mutation between calls.

    Returns:
        tuple: ``(AlbertTokenizer, AlbertForMaskedLM)``.
    """
    checkpoint = 'albert-xxlarge-v2'
    albert_tokenizer = AlbertTokenizer.from_pretrained(checkpoint)
    # A Flax variant (CustomFlaxAlbertForMaskedLM) was tried previously; the
    # PyTorch model is the one in use.
    albert_mlm = AlbertForMaskedLM.from_pretrained(checkpoint)
    return albert_tokenizer, albert_mlm
def clear_data():
    """Delete every key from ``st.session_state``.

    Used as the ``on_change`` callback of the analysis sidebar inputs. This
    also removes ``page_status``, so the next run starts from the first page.
    """
    stale_keys = list(st.session_state.keys())
    for stale_key in stale_keys:
        del st.session_state[stale_key]
# Default example sentences shared by the input page and the analysis sidebar.
_DEFAULT_SENT_1 = 'It is better to play a prank on Samuel than Craig because he gets angry less often.'
_DEFAULT_SENT_2 = 'It is better to play a prank on Samuel than Craig because he gets angry more often.'


def _type_in_page():
    """Page 'type_in': collect two sentences, then switch to 'tokenized'."""
    # The returned objects are not used on this page; the call only fills the
    # st.cache entry before the next page needs it.
    load_model()
    st.write('1. Type in the sentences and click "Tokenize"')
    sent_1 = st.text_input('Sentence 1', value=_DEFAULT_SENT_1)
    sent_2 = st.text_input('Sentence 2', value=_DEFAULT_SENT_2)
    if st.button('Tokenize'):
        st.session_state['page_status'] = 'tokenized'
        st.session_state['sent_1'] = sent_1
        st.session_state['sent_2'] = sent_2
        st.experimental_rerun()


def _tokenized_page():
    """Page 'tokenized': render sentence 1 as one toggle button per token.

    Clicking a token's button adds its position to
    ``st.session_state['masked_pos_1']``. Clicking it again removes it.
    """
    tokenizer, _ = load_model()
    sent_1 = st.session_state['sent_1']
    for pos_key in ('masked_pos_1', 'masked_pos_2'):
        if pos_key not in st.session_state:
            st.session_state[pos_key] = []

    st.write('2. Select sites to mask out and click "Confirm"')
    # NOTE(review): no "Confirm" button exists yet. Nothing visible sets
    # page_status to 'analysis', and masked_pos_2 is never used.
    input_sent = tokenizer(sent_1).input_ids
    # [1:-1] drops the first and last ids, presumably the special tokens the
    # tokenizer wraps around the sentence.
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
    if not decoded_sent:
        # st.columns rejects an empty spec, so stop before building columns.
        st.write('Sentence 1 contains no tokens to mask.')
        return

    # Relative column widths follow token length (+3, presumably slack for
    # the button chrome).
    char_nums = [len(word) + 3 for word in decoded_sent]
    st.write(char_nums)  # NOTE(review): looks like leftover debug output
    cols = st.columns(char_nums)
    # Alias of the session-state list, so in-place edits persist across reruns.
    masked_pos = st.session_state['masked_pos_1']
    for word_id, (col, word) in enumerate(zip(cols, decoded_sent)):
        with col:
            if st.button(word, key=f'word_{word_id}'):
                if word_id in masked_pos:
                    # Remove by value; list.pop(word_id) would remove by index.
                    masked_pos.remove(word_id)
                else:
                    masked_pos.append(word_id)
    masked_words = ", ".join(decoded_sent[pos] for pos in sorted(masked_pos))
    st.write(f'Masked words: {masked_words}')


def _analysis_page():
    """Page 'analysis': run the intervened model and show sampled tokens."""
    # Each page loads its own objects; nothing is shared between reruns.
    tokenizer, model = load_model()
    sent_1 = st.sidebar.text_input('Sentence 1', value=_DEFAULT_SENT_1, on_change=clear_data)
    sent_2 = st.sidebar.text_input('Sentence 2', value=_DEFAULT_SENT_2, on_change=clear_data)
    input_ids_1 = tokenizer(sent_1).input_ids
    input_ids_2 = tokenizer(sent_2).input_ids
    if len(input_ids_1) != len(input_ids_2):
        # torch.tensor cannot build a rectangular batch from ragged id lists.
        st.error('Both sentences must have the same number of tokens.')
        st.stop()
    input_ids = torch.tensor([input_ids_1, input_ids_2])
    # Inference only, so skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        # NOTE(review): the meaning of the intervention spec is defined in
        # skeleton_modeling_albert, which is not visible here.
        outputs = SkeletonAlbertForMaskedLM(model, input_ids, interventions={0: {'lay': [(8, 1, [0, 1])]}})
    logprobs = F.log_softmax(outputs['logits'], dim=-1)
    # Sample one token per position of the first sentence. Assumes the logits
    # are (batch, seq, vocab); TODO confirm.
    preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
    st.write([tokenizer.decode([token]) for token in preds])


if __name__ == '__main__':
    wide_setup()
    if 'page_status' not in st.session_state:
        st.session_state['page_status'] = 'type_in'
    # Pages are mutually exclusive. A page switch goes through
    # st.experimental_rerun, so a single dispatch per run is enough.
    page_status = st.session_state['page_status']
    if page_status == 'type_in':
        _type_in_page()
    elif page_status == 'tokenized':
        _tokenized_page()
    elif page_status == 'analysis':
        _analysis_page()
|