Spaces:
Runtime error
Runtime error
Commit
·
154ca7b
1
Parent(s):
b6436ac
[General] Initial commit
Browse files- .gitignore +4 -0
- .streamlit/config.toml +6 -0
- app.py +62 -0
- modules/prediction/ERCBCM.py +13 -0
- modules/prediction/__init__.py +36 -0
- modules/prediction/model_loader.py +35 -0
- modules/tokenizer/__init__.py +15 -0
- requirements.txt +5 -0
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
| 2 |
+
|
| 3 |
+
venv/
|
| 4 |
+
__pycache__/
|
.streamlit/config.toml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[theme]
|
| 2 |
+
primaryColor="#262730"
|
| 3 |
+
backgroundColor="#ffffff"
|
| 4 |
+
secondaryBackgroundColor="#f6f6f8"
|
| 5 |
+
textColor="#090909"
|
| 6 |
+
font="sans serif"
|
app.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
from modules.prediction import prepare, predict
|
| 4 |
+
|
| 5 |
+
# Values stored in st.session_state['running_status'].
STATUS_STOPPED = 120001
STATUS_SUBMIT = 120002
STATUS_ERROR = 120003

# Streamlit re-executes this script from top to bottom on every user
# interaction, so a plain module-level flag would be reset to False on each
# rerun and the checkpoint would be reloaded every time. Keeping the flag in
# session state makes the load happen once per browser session.
has_prepared = st.session_state.get('has_prepared', False)

# Every rerun starts in the "stopped" state; a form submission below
# switches it to STATUS_SUBMIT for this run only.
st.session_state['running_status'] = STATUS_STOPPED

if not has_prepared:
    print('>>> [PREPARE] Preparing...')
    prepare()
    has_prepared = True
    st.session_state['has_prepared'] = True

st.title('Entity Referring Classifier')
st.caption('It knows exactly when you are calling it. - Version 2.0.1208.01')

st.markdown('---')

# Wide demo column, a thin spacer column, and a narrower help column.
livedemo_col1, livedemo_col2, livedemo_col3 = st.columns([12, 1, 6])

with livedemo_col1:
    st.subheader('Live Demo')

    with st.form("my_form"):
        entity = st.text_input('Entity Name', 'Jimmy')
        sentence = st.text_input('Text Input', 'Hey Jimmy.',
                                 help='The classifier is going to analyze this sentence.')
        if st.form_submit_button('Submit it'):
            st.session_state['running_status'] = STATUS_SUBMIT

    # NOTE(review): the original indentation was lost; the result messages
    # are assumed to be rendered in this column, below the form - confirm.
    if st.session_state['running_status'] == STATUS_STOPPED:
        st.info('Type something and submit to start!')
    elif st.session_state['running_status'] == STATUS_SUBMIT:
        if predict(sentence, entity) == 'CALLING':
            st.success('It is a **calling**!')
        else:
            st.success('It is a **mentioning**!')

with livedemo_col2:
    # Spacer column: intentionally left empty.
    st.empty()

with livedemo_col3:
    st.markdown("""
    #### Get Started
    """)
    st.markdown("""
    Hi! I'm the Entity Referring Classifier.
    I can help you find out when you are calling it.
    """)
    st.markdown("""
    #### Terms
    """)
    st.markdown("""
    ##### `Calling`
    """)
    st.markdown("""
    ##### `Mentioning`
    """)
|
modules/prediction/ERCBCM.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
from transformers import BertForSequenceClassification
|
| 3 |
+
|
| 4 |
+
class ERCBCM(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained BERT sequence classifier.

    The wrapped model is stored as ``self.encoder``; that attribute name is
    part of the parameter names in any saved ``state_dict``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = BertForSequenceClassification.from_pretrained('bert-base-uncased')

    def forward(self, text, label):
        """Run the encoder with labels and return its first two outputs.

        Args:
            text: Token-id input passed positionally to the encoder.
            label: Passed to the encoder as ``labels``.

        Returns:
            A ``(loss, text_fea)`` pair taken from the first two elements of
            the encoder output.
        """
        encoder_output = self.encoder(text, labels=label)
        loss, text_fea = encoder_output[:2]
        return loss, text_fea
|
modules/prediction/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys

# Make the repository root importable so the absolute ``modules.*`` imports
# below resolve regardless of how this package is loaded.
myPath = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, f'{myPath}/../../')

# ==========

import torch

from modules.prediction.model_loader import load_checkpoint
from modules.prediction.ERCBCM import ERCBCM
from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID

# Folder holding the checkpoint; resolved against the working directory.
erc_root_folder = './model'

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ==========

# Single module-level model instance shared by prepare() and predict().
model_for_evaluate = ERCBCM().to(device)
|
| 21 |
+
|
| 22 |
+
def prepare():
    """Load the checkpoint ``model.pt`` from ``erc_root_folder`` into the shared model."""
    checkpoint_path = f'{erc_root_folder}/model.pt'
    load_checkpoint(checkpoint_path, model_for_evaluate, device)
|
| 24 |
+
|
| 25 |
+
def predict(sentence, name):
    """Classify whether *sentence* is calling or merely mentioning *name*.

    The sentence is normalized with ``normalize_v2``, tokenized, brought to a
    fixed length of ``MAX_SENT_LENGTH`` token ids and fed to the shared model.

    Args:
        sentence: The text to analyze.
        name: The entity name looked for in the sentence.

    Returns:
        ``'CALLING'`` when the predicted class index is 1, otherwise
        ``'MENTIONING'``.
    """
    # Local import so this block does not depend on the file-level import
    # list; the constant lives next to the tokenizer it belongs to.
    from modules.tokenizer import MAX_SENT_LENGTH

    token_ids = tokenizer.encode(normalize_v2(sentence, name))
    # Over-long inputs used to skip padding entirely (negative repeat count)
    # and were passed through at full length; clamp them to the fixed size.
    # NOTE(review): truncation keeps the leading tokens - confirm this matches
    # how training inputs were truncated.
    token_ids = token_ids[:MAX_SENT_LENGTH]
    token_ids += [PAD_TOKEN_ID] * (MAX_SENT_LENGTH - len(token_ids))
    text = torch.tensor([token_ids], dtype=torch.long, device=device)

    # The model's forward() requires a label; its loss output is discarded,
    # so a dummy class index of 0 is supplied.
    label = torch.zeros(1, dtype=torch.long, device=device)

    # Inference only: make sure dropout is disabled and skip autograd
    # bookkeeping.
    model_for_evaluate.eval()
    with torch.no_grad():
        _, output = model_for_evaluate(text, label)

    pred = torch.argmax(output, 1).tolist()[0]
    return 'CALLING' if pred == 1 else 'MENTIONING'
|
modules/prediction/model_loader.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
# Save and Load Functions
|
| 4 |
+
|
| 5 |
+
def save_checkpoint(save_path, model, valid_loss):
    """Write a model's weights and its validation loss to *save_path*.

    Args:
        save_path: Destination file. When ``None`` nothing is written.
        model: Object whose ``state_dict()`` is stored under
            ``'model_state_dict'``.
        valid_loss: Value stored under ``'valid_loss'``.

    Returns:
        None.
    """
    if save_path is None:
        return
    state_dict = {
        'model_state_dict': model.state_dict(),
        'valid_loss': valid_loss,
    }
    torch.save(state_dict, save_path)
    print(f"[SAVE] Model has been saved successfully to '{save_path}'")
|
| 12 |
+
|
| 13 |
+
def load_checkpoint(load_path, model, device):
    """Restore weights saved by ``save_checkpoint`` into *model*.

    Args:
        load_path: Checkpoint file. When ``None`` nothing is loaded.
        model: Object whose ``load_state_dict`` receives the stored weights.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The ``'valid_loss'`` value stored in the checkpoint, or ``None`` when
        *load_path* is ``None``.
    """
    if load_path is None:
        return None
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    state_dict = torch.load(load_path, map_location=device)
    model.load_state_dict(state_dict['model_state_dict'])
    # Report success only after the weights were actually applied.
    print(f"[LOAD] Model has been loaded successfully from '{load_path}'")
    return state_dict['valid_loss']
|
| 20 |
+
|
| 21 |
+
def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    """Write the training curves to *save_path*.

    Args:
        save_path: Destination file. When ``None`` nothing is written.
        train_loss_list: Stored under ``'train_loss_list'``.
        valid_loss_list: Stored under ``'valid_loss_list'``.
        global_steps_list: Stored under ``'global_steps_list'``.

    Returns:
        None.
    """
    if save_path is None:
        return
    state_dict = {
        'train_loss_list': train_loss_list,
        'valid_loss_list': valid_loss_list,
        'global_steps_list': global_steps_list,
    }
    torch.save(state_dict, save_path)
    # Only metrics are written here, so the message no longer claims a model
    # was saved (and the "matrics" typo is gone).
    print(f"[SAVE] Metrics have been saved successfully to '{save_path}'")
|
| 29 |
+
|
| 30 |
+
def load_metrics(load_path, device):
    """Read back the training curves written by ``save_metrics``.

    Args:
        load_path: Metrics file. When ``None`` nothing is loaded.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        A ``(train_loss_list, valid_loss_list, global_steps_list)`` tuple, or
        ``None`` when *load_path* is ``None``.
    """
    if load_path is None:
        return None
    # NOTE: torch.load unpickles the file; only load files from a trusted
    # source.
    state_dict = torch.load(load_path, map_location=device)
    # Only metrics are read here, so the message no longer claims a model
    # was loaded (and the "matrics" typo is gone).
    print(f"[LOAD] Metrics have been loaded successfully from '{load_path}'")
    return (
        state_dict['train_loss_list'],
        state_dict['valid_loss_list'],
        state_dict['global_steps_list'],
    )
|
modules/tokenizer/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import BertTokenizer
|
| 2 |
+
|
| 3 |
+
# Tokenizer instance shared by the rest of the project.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Parameters preparation.
MAX_SENT_LENGTH = 128
# Vocabulary id of the tokenizer's padding token.
_pad_token = tokenizer.pad_token
PAD_TOKEN_ID = tokenizer.convert_tokens_to_ids(_pad_token)
|
| 8 |
+
|
| 9 |
+
def normalize_v2(text, entity):
    """Lower-case *text* and mask every occurrence of *entity* in it.

    Args:
        text: The sentence to normalize.
        entity: The entity name; matched case-insensitively as a plain
            substring.

    Returns:
        The lower-cased text with each occurrence of the lower-cased entity
        replaced by the tokenizer's mask token. The lower-cased text is
        returned unchanged when the entity is empty or does not occur.
    """
    text = text.lower()
    entity = entity.lower()
    # An empty string is a substring of every string, and
    # str.replace('', mask) would insert the mask token between every
    # character, so an empty entity must short-circuit here.
    if not entity or entity not in text:
        return text
    # TODO: not sure if this will be decoded by BERT.
    return text.replace(entity, tokenizer.mask_token)
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
|
| 2 |
+
torch
|
| 3 |
+
torchtext
|
| 4 |
+
ipywidgets
|
| 5 |
+
transformers
|