Marvin M. Agüero-Torales
Adding multilingual, Spanish and English base models
848129c
raw
history blame
4.48 kB
import streamlit as st
import pandas as pd
from streamlit import cli as stcli
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
import sys
# Boost factor used when re-ranking predictions against the user's history:
# a candidate's score is multiplied by this value when its token is found in
# the keyword history, and its summed cosine similarity to the semantic
# history is scaled by this value before being added to the score.
HISTORY_WEIGHT = 100
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def get_model(model):
    """Build a fill-mask pipeline for ``model``.

    Args:
        model: Model identifier forwarded to ``transformers.pipeline``.

    Returns:
        The ``"fill-mask"`` pipeline created for ``model``.
    """
    # top_k sets the maximum number of candidate tokens retrieved per inference.
    fill_mask = pipeline("fill-mask", model=model, top_k=10)
    return fill_mask
def hash_func(inp):
    """Return the same constant for any input.

    Registered in the ``hash_funcs`` mappings of the cached functions in this
    file for the ``tokenizers.Tokenizer`` and ``tokenizers.AddedToken`` types,
    so every object of those types hashes to the same value.

    Args:
        inp: The object to hash; ignored.

    Returns:
        Always ``True``.
    """
    return True
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def loading_models(model='roberta-base'):
    """Load the two models the app needs.

    Args:
        model: Identifier of the fill-mask model, passed to ``get_model``.

    Returns:
        A ``(fill_mask_pipeline, sentence_encoder)`` tuple: the pipeline
        returned by ``get_model(model)`` and a ``SentenceTransformer``
        instance for ``'all-mpnet-base-v2'``.
    """
    fill_mask_pipeline = get_model(model)
    # 'all-MiniLM-L6-v2' is the alternative encoder name left commented out
    # by the author.
    sentence_encoder = SentenceTransformer('all-mpnet-base-v2')
    return fill_mask_pipeline, sentence_encoder
@st.cache(
    allow_output_mutation=True,
    suppress_st_warning=True,
    hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func},
)
def infer(text):
    """Run the fill-mask pipeline on ``text`` followed by a mask token.

    Reads the module-level ``nlp`` pipeline (assigned in the ``__main__``
    section before this function is called).

    Args:
        text: The text typed so far.

    Returns:
        Whatever ``nlp`` returns for ``text``, a single space and the
        tokenizer's mask token joined together.
    """
    masked_prompt = ' '.join((text, nlp.tokenizer.mask_token))
    return nlp(masked_prompt)
@st.cache(
    allow_output_mutation=True,
    suppress_st_warning=True,
    hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func},
)
def sim(predicted_seq, sem_list):
    """Encode the predicted sequences and the semantic history as tensors.

    Uses the module-level ``semantic_model`` (assigned in the ``__main__``
    section before this function is called).

    Args:
        predicted_seq: Sequences produced by the fill-mask model.
        sem_list: Semantic-history texts to compare against.

    Returns:
        A ``(predicted_embeddings, history_embeddings)`` tuple, each obtained
        from ``semantic_model.encode(..., convert_to_tensor=True)``.
    """
    predicted_embeddings = semantic_model.encode(predicted_seq, convert_to_tensor=True)
    history_embeddings = semantic_model.encode(sem_list, convert_to_tensor=True)
    return predicted_embeddings, history_embeddings
@st.cache(allow_output_mutation=True,
          suppress_st_warning=True,
          hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func})
def main(text, semantic_text, history_keyword_text):
    """Predict next-token candidates for ``text`` and re-rank them by history.

    Each candidate returned by ``infer`` starts with the model's own score.
    That score is then adjusted in two ways:

    * if a semantic history is given, the candidate's cosine similarity to it
      (scaled by ``HISTORY_WEIGHT``) is added, for tokens longer than two
      characters;
    * if the candidate token occurs in the keyword history, the score is
      multiplied by ``HISTORY_WEIGHT``.

    Progress messages are written to the module-level ``data_load_state``
    placeholder.

    Args:
        text: The text to auto-complete.
        semantic_text: Free-text semantic history; empty or whitespace-only
            disables the similarity boost.
        history_keyword_text: Keyword history; matching is a case-insensitive
            substring test of the token against this whole string.

    Returns:
        A ``pandas.DataFrame`` built from the candidate records, sorted by
        ``score`` in descending order.
    """
    data_load_state.text('Inference from model...')
    # infer() is cached with allow_output_mutation=True, so a cache hit hands
    # back the very same list object. Work on copies of the records so the
    # score adjustments below never leak into (and compound inside) the cache.
    result = [dict(rec) for rec in infer(text)]

    # Normalise the history inputs once instead of on every loop iteration.
    semantic_query = semantic_text.strip()
    history_keywords = history_keyword_text.lower().strip()

    data_load_state.text('Checking similarity...')
    cosine_scores = None
    if semantic_query:
        predicted_seq = [rec['sequence'] for rec in result]
        predicted_embeddings, semantic_history_embeddings = sim(predicted_seq, [semantic_query])
        cosine_scores = util.cos_sim(predicted_embeddings, semantic_history_embeddings)
    data_load_state.text('similarity check completed...')

    for index, rec in enumerate(result):
        # Skip special chars such as "?" for the semantic boost.
        if cosine_scores is not None and len(rec['token_str']) > 2:
            rec['score'] += float(sum(cosine_scores[index])) * HISTORY_WEIGHT
        token = rec['token_str'].lower().strip()
        if len(token) > 1 and token in history_keywords:
            # Found in the keyword history: increase the token's score.
            rec['score'] *= HISTORY_WEIGHT
    data_load_state.text('Score updated...')

    # Sort the results, best candidate first.
    return pd.DataFrame(result).sort_values(by='score', ascending=False)
if __name__ == '__main__':
    if not st._is_running_with_streamlit:
        # Started with plain `python <file>`: hand over to the Streamlit CLI
        # as if `streamlit run <file>` had been typed.
        sys.argv = ['streamlit', 'run', sys.argv[0]]
        sys.exit(stcli.main())

    # --- Running under Streamlit: build the page top to bottom. ---
    st.markdown("""
# Auto-Complete
This is an example of an auto-complete approach where the next token suggested based on users's history Keyword match & Semantic similarity of users's history (log).
The next token is predicted per probability and a weight if it is appeared in keyword user's history or there is a similarity to semantic user's history.
**Forked from [mbahrami/Auto-Complete_Semantic](https://huggingface.co/spaces/mbahrami/Auto-Complete_Semantic).**
""")

    # User inputs: keyword history, semantic history and the text to complete.
    history_keyword_text = st.text_input(
        "Enter users's history <Keywords Match> (optional, i.e., 'Premio Cervantes')",
        value="",
    )
    semantic_text = st.text_input(
        "Enter users's history <Semantic> (optional, i.e., 'hai')",
        value="hai",
    )
    text = st.text_input(
        "Enter a text for auto completion...",
        value="Augusto Roa Bastos ha'e kuimba'e arandu",
    )

    # Fill-mask models offered in the selector, in display order.
    model_choices = [
        "mmaguero/gn-bert-tiny-cased",
        "mmaguero/gn-bert-small-cased",
        "mmaguero/gn-bert-base-cased",
        "mmaguero/gn-bert-large-cased",
        "mmaguero/multilingual-bert-gn-base-cased",
        "mmaguero/beto-gn-base-cased",
        "bert-base-multilingual-cased",
        "xlm-roberta-base",
        "dccuchile/bert-base-spanish-wwm-cased",
        "PlanTL-GOB-ES/roberta-base-bne",
        "bert-base-cased",
        "roberta-base",
    ]
    model = st.selectbox("Choose a model", model_choices)

    # Status placeholder; main() also writes its progress messages here.
    data_load_state = st.text('1.Loading model ...')
    # nlp and semantic_model are module-level names read by infer() and sim().
    nlp, semantic_model = loading_models(model)

    ranked_df = main(text, semantic_text, history_keyword_text)
    # Show the results as a table, then clear the status line.
    st.table(ranked_df)
    data_load_state.text('')