File size: 4,615 Bytes
2b49fe2
efeee8a
2b49fe2
efeee8a
 
 
2b49fe2
d82123b
dd4548e
2b49fe2
e87e116
 
efeee8a
e87e116
 
 
 
efeee8a
c6dd7aa
efeee8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6dd7aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779710f
 
 
c6dd7aa
 
 
779710f
c6dd7aa
 
 
 
10ced5b
c6dd7aa
 
 
 
 
 
 
 
 
10ced5b
1567a0f
10ced5b
c6dd7aa
10ced5b
c6dd7aa
 
10ced5b
2b404c9
1567a0f
 
10ced5b
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
import pandas as pd
import time
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp

import torch
import torch.nn.functional as F

from transformers import AlbertTokenizer, AlbertForMaskedLM

#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM

def wide_setup():
    """Inject page-level CSS into the Streamlit app.

    Two ``<style>`` snippets are rendered with ``st.markdown`` (HTML allowed):

    * the main block container is widened to a 1500px max width with 2rem
      left/right padding and no top/bottom padding;
    * table header cells in ``tbody`` (the row index) and ``.blank`` cells are
      hidden.
    """
    width_px = 1500
    pad_top = 0
    pad_right = 2
    pad_bottom = 0
    pad_left = 2

    margins_css = f"""
    <style>
        .appview-container .main .block-container{{
            max-width: {width_px}px;
            padding-top: {pad_top}rem;
            padding-right: {pad_right}rem;
            padding-left: {pad_left}rem;
            padding-bottom: {pad_bottom}rem;
        }}
    </style>
    """
    hide_index_css = """
                <style>
                tbody th {display:none}
                .blank {display:none}
                </style>
                """
    # Emit in the same order: layout first, then the row-index hiding rules.
    for css in (margins_css, hide_index_css):
        st.markdown(css, unsafe_allow_html=True)

@st.cache(show_spinner=True,allow_output_mutation=True)
def load_model():
    """Load the ``albert-xxlarge-v2`` tokenizer and masked-LM model once.

    Wrapped in ``st.cache`` so Streamlit reruns reuse the same objects instead
    of reloading them. ``allow_output_mutation=True`` tells ``st.cache`` not to
    check the returned objects for mutation between reruns.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = 'albert-xxlarge-v2'
    tok = AlbertTokenizer.from_pretrained(checkpoint)
    # Flax alternative kept for reference:
    #   CustomFlaxAlbertForMaskedLM.from_pretrained(checkpoint, from_pt=True)
    mlm = AlbertForMaskedLM.from_pretrained(checkpoint)
    return tok, mlm

def clear_data(state=None):
    """Remove every key from the Streamlit session state.

    Used as the ``on_change`` callback of the analysis-page sentence inputs:
    once ``page_status`` is gone, the next rerun starts again at the first
    ("type_in") step.

    Args:
        state: Mapping to clear. Defaults to ``st.session_state``; exposed as a
            parameter so the reset logic can be exercised on any mutable mapping.
    """
    if state is None:
        state = st.session_state
    # Snapshot the keys first: deleting from a mapping while iterating over it
    # raises RuntimeError for a plain dict and is fragile for mapping proxies.
    for key in list(state.keys()):
        del state[key]

if __name__=='__main__':
    wide_setup()

    # Streamlit reruns this whole script on every interaction; the
    # 'page_status' session key records which step of the workflow to show.
    if 'page_status' not in st.session_state:
        st.session_state['page_status'] = 'type_in'

    # Step 1: collect the two sentences.
    if st.session_state['page_status']=='type_in':
        # Not used on this page; calling it here populates the st.cache entry
        # before the user moves on.
        load_model()

        st.write('1. Type in the sentences and click "Tokenize"')
        sent_1 = st.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
        sent_2 = st.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
        if st.button('Tokenize'):
            st.session_state['page_status'] = 'tokenized'
            st.session_state['sent_1'] = sent_1
            st.session_state['sent_2'] = sent_2
            st.experimental_rerun()

    # Step 2: let the user toggle which tokens of sentence 1 are masked.
    if st.session_state['page_status']=='tokenized':
        tokenizer, _ = load_model()

        sent_1 = st.session_state['sent_1']
        if 'masked_pos_1' not in st.session_state:
            st.session_state['masked_pos_1'] = []
        if 'masked_pos_2' not in st.session_state:
            st.session_state['masked_pos_2'] = []

        # NOTE(review): no "Confirm" button exists yet and nothing sets
        # page_status to 'analysis'; this step is still work in progress.
        st.write('2. Select sites to mask out and click "Confirm"')
        input_sent = tokenizer(sent_1).input_ids
        # Drop the first/last ids (special tokens added by the tokenizer) so
        # each button corresponds to one token of the sentence itself.
        decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
        # Column widths proportional to token length, plus some padding.
        char_nums = [len(word)+3 for word in decoded_sent]
        st.write(char_nums)  # NOTE(review): looks like leftover debug output
        cols = st.columns(char_nums)
        # Same list object as in session state, so in-place edits persist.
        masked_pos_1 = st.session_state['masked_pos_1']
        for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
            with col:
                if st.button(word,key=f'word_{word_id}'):
                    # Toggle by value: list.pop would remove by *index*.
                    if word_id in masked_pos_1:
                        masked_pos_1.remove(word_id)
                    else:
                        masked_pos_1.append(word_id)
        st.write(f'Masked words: {", ".join([decoded_sent[word_id] for word_id in sorted(masked_pos_1)])}')


    # Step 3: run the intervened model on both sentences and sample tokens.
    if st.session_state['page_status']=='analysis':
        # The script reruns from the top, so the cached tokenizer/model must be
        # fetched in this branch as well (they are not bound here otherwise).
        tokenizer,model = load_model()
        sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.',on_change=clear_data)
        sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
        input_ids_1 = tokenizer(sent_1).input_ids
        input_ids_2 = tokenizer(sent_2).input_ids
        # torch.tensor needs a rectangular nested list, so both sentences must
        # tokenize to the same length.
        if len(input_ids_1) != len(input_ids_2):
            st.error('Sentence 1 and Sentence 2 must have the same number of tokens.')
            st.stop()
        input_ids = torch.tensor([input_ids_1,input_ids_2])

        # Intervention spec format is defined by SkeletonAlbertForMaskedLM
        # (TODO: document what key 0 / 'lay' / (8,1,[0,1]) select).
        outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
        logprobs = F.log_softmax(outputs['logits'], dim = -1)
        # Sample one token per position of the first sentence; multinomial on
        # a 2-D input draws independently for each row.
        preds = torch.multinomial(torch.exp(logprobs[0]), num_samples=1).squeeze(dim=-1)
        st.write([tokenizer.decode([token]) for token in preds.tolist()])