File size: 5,539 Bytes
5e5793b
 
 
 
 
 
 
 
 
8a204f8
 
 
 
 
 
 
 
 
 
 
5e5793b
 
 
8a204f8
 
 
21c2f11
8a204f8
21c2f11
 
 
 
 
 
8a204f8
 
 
 
 
 
 
a999c8e
 
 
 
8a204f8
 
 
 
21c2f11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e5793b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a999c8e
21c2f11
8a204f8
 
a999c8e
 
 
 
 
8a204f8
5e5793b
a999c8e
21c2f11
8a204f8
 
 
 
 
 
 
 
21c2f11
 
 
 
5e5793b
8a204f8
 
 
 
 
 
228552f
8a204f8
 
21c2f11
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import pandas as pd
import streamlit as st
import numpy as np
import torch
import io
import time

@st.cache(show_spinner=True,allow_output_mutation=True)
def load_model(model_name):
    """Load and cache the HuggingFace tokenizer matching *model_name*.

    Args:
        model_name: tokenizer identifier such as 'bert-base-uncased' or
            'gpt2-large'. Must start with one of: 'bert', 'gpt2',
            'roberta', 'albert'.

    Returns:
        The instantiated tokenizer.

    Raises:
        ValueError: if *model_name* belongs to no supported family.
            (The original fell through and raised UnboundLocalError
            at the `return` instead.)
    """
    # Imports are done lazily per-branch so only the needed tokenizer
    # class (and its weights/vocab download) is pulled in.
    if model_name.startswith('bert'):
        from transformers import BertTokenizer
        tokenizer = BertTokenizer.from_pretrained(model_name)
    elif model_name.startswith('gpt2'):
        from transformers import GPT2Tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    elif model_name.startswith('roberta'):
        from transformers import RobertaTokenizer
        tokenizer = RobertaTokenizer.from_pretrained(model_name)
    elif model_name.startswith('albert'):
        from transformers import AlbertTokenizer
        tokenizer = AlbertTokenizer.from_pretrained(model_name)
    else:
        raise ValueError(f'Unsupported tokenizer family: {model_name!r}')
    return tokenizer

def generate_markdown(text,color='black',font='Arial',size=20):
    """Wrap *text* in a centered, styled HTML <p> for st.markdown rendering."""
    style = f"text-align:center; color:{color}; font-family:{font}; font-size:{size}px;"
    return f"<p style='{style}'>{text}</p>"

def TokenizeText(sentence,tokenizer_name):
    """Tokenize *sentence*, render ids and tokens via Streamlit, return the count.

    NOTE(review): reads the module-level ``tokenizer`` global (assigned in
    ``__main__``); *tokenizer_name* is only used to decide whether special
    tokens must be stripped.

    Args:
        sentence: raw input text.
        tokenizer_name: the selected model identifier.

    Returns:
        Number of tokens, or None (implicitly) when *sentence* is empty.
    """
    if len(sentence)>0:
        if tokenizer_name.startswith('gpt2'):
            # GPT-2 adds no special tokens; keep the full id list.
            input_sent = tokenizer(sentence)['input_ids']
        else:
            # BERT/RoBERTa/ALBERT wrap the text in [CLS]...[SEP]-style
            # specials; drop the first and last id.
            input_sent = tokenizer(sentence)['input_ids'][1:-1]
        encoded_sent = [str(token) for token in input_sent]
        decoded_sent = [tokenizer.decode([token]) for token in input_sent]
        num_tokens = len(decoded_sent)

        # Render ids, then token strings, then the total count.
        st.markdown(generate_markdown('   '.join(encoded_sent),size=16), unsafe_allow_html=True)
        st.markdown(generate_markdown('   '.join(decoded_sent),size=16), unsafe_allow_html=True)
        st.markdown(generate_markdown(f'{num_tokens} tokens'), unsafe_allow_html=True)

        return num_tokens

def DeTokenizeText(input_str):
    """Decode a whitespace-separated list of token ids and render via Streamlit.

    NOTE(review): reads the module-level ``tokenizer`` global (assigned in
    ``__main__``).

    Args:
        input_str: e.g. "101 7592 102" — integer token ids separated by
            whitespace.

    Returns:
        Number of tokens, or None (implicitly) when *input_str* is empty.

    Raises:
        ValueError: if any field is not an integer.
    """
    if len(input_str)>0:
        # split() with no argument collapses runs of whitespace; the original
        # split(' ') yielded empty fields on consecutive spaces and crashed
        # int('') — notably, TokenizeText's own output joins ids with THREE
        # spaces, so round-tripping its output used to fail.
        input_sent = [int(element) for element in input_str.strip().split()]
        decoded_sent = [tokenizer.decode([token]) for token in input_sent]
        num_tokens = len(decoded_sent)

        st.markdown(generate_markdown('   '.join(decoded_sent)), unsafe_allow_html=True)
        return num_tokens

if __name__=='__main__':

    # Config
    # Page layout: widen the main container and trim Streamlit's default padding.
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2

    define_margins = f"""
    <style>
        .appview-container .main .block-container{{
            max-width: {max_width}px;
            padding-top: {padding_top}rem;
            padding-right: {padding_right}rem;
            padding-left: {padding_left}rem;
            padding-bottom: {padding_bottom}rem;
        }}
    </style>
    """
    # Hide the index column Streamlit renders in tables.
    hide_table_row_index = """
                <style>
                tbody th {display:none}
                .blank {display:none}
                </style>
                """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

    # Title
    st.markdown(generate_markdown('Tokenizer Demo:',size=32), unsafe_allow_html=True)
    st.markdown(generate_markdown('quick and easy way to explore how tokenizers work',size=24), unsafe_allow_html=True)

    # Select and load the tokenizer
    # NOTE: `tokenizer` is read as a module-level global by
    # TokenizeText/DeTokenizeText, not passed explicitly.
    tokenizer_name = st.sidebar.selectbox('Choose the tokenizer from below',
                                            ('bert-base-uncased','bert-large-cased',
                                            'gpt2','gpt2-large',
                                            'roberta-base','roberta-large',
                                            'albert-base-v2','albert-xxlarge-v2'),index=7)  # index=7 -> default 'albert-xxlarge-v2'
    tokenizer = load_model(tokenizer_name)

    # Sidebar toggles: compare two texts side by side, and/or treat input
    # as whitespace-separated token ids to decode instead of raw text.
    comparison_mode = st.sidebar.checkbox('Compare two texts')
    detokenize = st.sidebar.checkbox('de-tokenize')
    if comparison_mode:
        # Two equal-width columns, one text input each.
        sent_cols = st.columns(2)
        num_tokens = {}
        sents = {}
        for sent_id, sent_col in enumerate(sent_cols):
            with sent_col:
                sentence = st.text_input(f'Text {sent_id+1}')
                sents[f'sent_{sent_id+1}'] = sentence
                if detokenize:
                    num_tokens[f'sent_{sent_id+1}'] = DeTokenizeText(sentence)
                else:
                    num_tokens[f'sent_{sent_id+1}'] = TokenizeText(sentence,tokenizer_name)

        # Only compare once both inputs are non-empty (the helpers return
        # None for empty input, so this guard avoids a spurious None == None
        # "Matched!").
        if len(sents['sent_1'])>0 and len(sents['sent_2'])>0:
            st.markdown(generate_markdown('Result&colon; ',size=16), unsafe_allow_html=True)
            if num_tokens[f'sent_1']==num_tokens[f'sent_2']:
                st.markdown(generate_markdown('Matched! ',color='MediumAquamarine'), unsafe_allow_html=True)
            else:
                st.markdown(generate_markdown('Not Matched... ',color='Salmon'), unsafe_allow_html=True)

    else:
        # Single-input mode.
        sentence = st.text_input(f'Text')
        if detokenize:
            num_tokens = DeTokenizeText(sentence)
        else:
            num_tokens = TokenizeText(sentence,tokenizer_name)