Spaces:
Runtime error
Runtime error
Trent
commited on
Commit
·
a41bdbc
1
Parent(s):
49438d6
Multi model select and local model loading
Browse files- __init__.py +0 -0
- app.py +12 -30
- backend/__init__.py +0 -0
- backend/config.py +1 -0
- backend/inference.py +9 -20
- backend/main.py +0 -19
- backend/utils.py +11 -0
- requirements.txt +1 -1
__init__.py
ADDED
|
File without changes
|
app.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
-
|
| 4 |
-
import
|
|
|
|
| 5 |
|
| 6 |
st.title('Demo using Flax-Sentence-Tranformers')
|
| 7 |
|
|
@@ -20,12 +21,12 @@ For more cool information on sentence embeddings, see the [sBert project](https:
|
|
| 20 |
Please enjoy!!
|
| 21 |
''')
|
| 22 |
|
| 23 |
-
|
| 24 |
anchor = st.text_input(
|
| 25 |
'Please enter here the main text you want to compare:'
|
| 26 |
)
|
| 27 |
|
| 28 |
if anchor:
|
|
|
|
| 29 |
n_texts = st.sidebar.number_input(
|
| 30 |
f'''How many texts you want to compare with: '{anchor}'?''',
|
| 31 |
value=2,
|
|
@@ -34,40 +35,21 @@ if anchor:
|
|
| 34 |
inputs = []
|
| 35 |
|
| 36 |
for i in range(n_texts):
|
| 37 |
-
|
| 38 |
-
input = st.sidebar.text_input(f'Text {i+1}:')
|
| 39 |
|
| 40 |
inputs.append(input)
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
api_base_url = 'http://127.0.0.1:8000/similarity'
|
| 45 |
-
|
| 46 |
if anchor:
|
| 47 |
if st.sidebar.button('Tell me the similarity.'):
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
inputs = inputs,
|
| 53 |
-
model = 'mpnet'))
|
| 54 |
-
res_minilm_l6 = requests.get(url = api_base_url, params = dict(anchor = anchor,
|
| 55 |
-
inputs = inputs,
|
| 56 |
-
model = 'minilm_l6'))
|
| 57 |
-
|
| 58 |
-
d_distilroberta = res_distilroberta.json()['dataframe']
|
| 59 |
-
d_mpnet = res_mpnet.json()['dataframe']
|
| 60 |
-
d_minilm_l6 = res_minilm_l6.json()['dataframe']
|
| 61 |
-
|
| 62 |
-
index = list(d_distilroberta['inputs'].values())
|
| 63 |
df_total = pd.DataFrame(index=index)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
df_total['minilm_l6'] = list(d_minilm_l6['score'].values())
|
| 67 |
|
| 68 |
-
st.write('Here are the results for
|
| 69 |
st.write(df_total)
|
| 70 |
st.write('Visualize the results of each model:')
|
| 71 |
st.area_chart(df_total)
|
| 72 |
-
|
| 73 |
-
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
+
|
| 4 |
+
from backend import inference
|
| 5 |
+
from backend.config import MODELS_ID
|
| 6 |
|
| 7 |
st.title('Demo using Flax-Sentence-Tranformers')
|
| 8 |
|
|
|
|
| 21 |
Please enjoy!!
|
| 22 |
''')
|
| 23 |
|
|
|
|
| 24 |
anchor = st.text_input(
|
| 25 |
'Please enter here the main text you want to compare:'
|
| 26 |
)
|
| 27 |
|
| 28 |
if anchor:
|
| 29 |
+
select_models = st.sidebar.multiselect("Choose models", options=MODELS_ID.keys())
|
| 30 |
n_texts = st.sidebar.number_input(
|
| 31 |
f'''How many texts you want to compare with: '{anchor}'?''',
|
| 32 |
value=2,
|
|
|
|
| 35 |
inputs = []
|
| 36 |
|
| 37 |
for i in range(n_texts):
|
| 38 |
+
input = st.sidebar.text_input(f'Text {i + 1}:')
|
|
|
|
| 39 |
|
| 40 |
inputs.append(input)
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
if anchor:
|
| 43 |
if st.sidebar.button('Tell me the similarity.'):
|
| 44 |
+
results = {model: inference.text_similarity(anchor, inputs, model) for model in select_models}
|
| 45 |
+
df_results = {model: results[model] for model in results}
|
| 46 |
+
|
| 47 |
+
index = inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
df_total = pd.DataFrame(index=index)
|
| 49 |
+
for key, value in df_results.items():
|
| 50 |
+
df_total[key] = list(value['score'].values)
|
|
|
|
| 51 |
|
| 52 |
+
st.write('Here are the results for selected models:')
|
| 53 |
st.write(df_total)
|
| 54 |
st.write('Visualize the results of each model:')
|
| 55 |
st.area_chart(df_total)
|
|
|
|
|
|
backend/__init__.py
ADDED
|
File without changes
|
backend/config.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
|
| 2 |
mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
|
|
|
|
| 3 |
minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
|
|
|
|
| 1 |
MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
|
| 2 |
mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
|
| 3 |
+
mpnet_qa = 'flax-sentence-embeddings/mpnet_stackexchange_v1',
|
| 4 |
minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
|
backend/inference.py
CHANGED
|
@@ -1,41 +1,30 @@
|
|
| 1 |
-
from sentence_transformers import SentenceTransformer
|
| 2 |
import pandas as pd
|
| 3 |
import jax.numpy as jnp
|
| 4 |
|
| 5 |
from typing import List
|
| 6 |
-
import config
|
| 7 |
-
|
| 8 |
-
# We download the models we will be using.
|
| 9 |
-
# If you do not want to use all, you can comment the unused ones.
|
| 10 |
-
distilroberta_model = SentenceTransformer(config.MODELS_ID['distilroberta'])
|
| 11 |
-
mpnet_model = SentenceTransformer(config.MODELS_ID['mpnet'])
|
| 12 |
-
minilm_l6_model = SentenceTransformer(config.MODELS_ID['minilm_l6'])
|
| 13 |
|
| 14 |
# Defining cosine similarity using flax.
|
|
|
|
|
|
|
|
|
|
| 15 |
def cos_sim(a, b):
|
| 16 |
-
return jnp.matmul(a, jnp.transpose(b))/(jnp.linalg.norm(a)*jnp.linalg.norm(b))
|
| 17 |
|
| 18 |
|
| 19 |
# We get similarity between embeddings.
|
| 20 |
-
def text_similarity(anchor: str, inputs: List[str],
|
|
|
|
| 21 |
|
| 22 |
# Creating embeddings
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
inputs_emb = distilroberta_model.encode([input for input in inputs])
|
| 26 |
-
elif model == 'mpnet':
|
| 27 |
-
anchor_emb = mpnet_model.encode(anchor)[None, :]
|
| 28 |
-
inputs_emb = mpnet_model.encode([input for input in inputs])
|
| 29 |
-
elif model == 'minilm_l6':
|
| 30 |
-
anchor_emb = minilm_l6_model.encode(anchor)[None, :]
|
| 31 |
-
inputs_emb = minilm_l6_model.encode([input for input in inputs])
|
| 32 |
|
| 33 |
# Obtaining similarity
|
| 34 |
similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb)))
|
| 35 |
|
| 36 |
# Returning a Pandas' dataframe
|
| 37 |
d = {'inputs': [input for input in inputs],
|
| 38 |
-
'score': [round(similarity[i],3) for i in range(len(similarity))]}
|
| 39 |
df = pd.DataFrame(d, columns=['inputs', 'score'])
|
| 40 |
|
| 41 |
return df.sort_values('score', ascending=False)
|
|
|
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
import jax.numpy as jnp
|
| 3 |
|
| 4 |
from typing import List
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
# Defining cosine similarity using flax.
|
| 7 |
+
from backend.utils import load_model
|
| 8 |
+
|
| 9 |
+
|
| 10 |
def cos_sim(a, b):
|
| 11 |
+
return jnp.matmul(a, jnp.transpose(b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
|
| 12 |
|
| 13 |
|
| 14 |
# We get similarity between embeddings.
|
| 15 |
+
def text_similarity(anchor: str, inputs: List[str], model_name: str):
|
| 16 |
+
model = load_model(model_name)
|
| 17 |
|
| 18 |
# Creating embeddings
|
| 19 |
+
anchor_emb = model.encode(anchor)[None, :]
|
| 20 |
+
inputs_emb = model.encode([input for input in inputs])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
# Obtaining similarity
|
| 23 |
similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb)))
|
| 24 |
|
| 25 |
# Returning a Pandas' dataframe
|
| 26 |
d = {'inputs': [input for input in inputs],
|
| 27 |
+
'score': [round(similarity[i], 3) for i in range(len(similarity))]}
|
| 28 |
df = pd.DataFrame(d, columns=['inputs', 'score'])
|
| 29 |
|
| 30 |
return df.sort_values('score', ascending=False)
|
backend/main.py
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
from fastapi import Query, FastAPI
|
| 2 |
-
|
| 3 |
-
import config
|
| 4 |
-
import inference
|
| 5 |
-
from typing import List
|
| 6 |
-
|
| 7 |
-
app = FastAPI()
|
| 8 |
-
|
| 9 |
-
@app.get("/")
|
| 10 |
-
def read_root():
|
| 11 |
-
return {"message": "Welcome to the API of flax-sentence-embeddings."}
|
| 12 |
-
|
| 13 |
-
@app.get('/similarity')
|
| 14 |
-
def get_similarity(anchor: str, inputs: List[str] = Query([]), model: str = 'distilroberta'):
|
| 15 |
-
return {'dataframe': inference.text_similarity(anchor, inputs, model)}
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
#if __name__ == "__main__":
|
| 19 |
-
# uvicorn.run("main:app", host="0.0.0.0", port=8080)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/utils.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from sentence_transformers import SentenceTransformer
|
| 3 |
+
from .config import MODELS_ID
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@st.cache(allow_output_mutation=True)
|
| 7 |
+
def load_model(model_name):
|
| 8 |
+
assert model_name in MODELS_ID.keys()
|
| 9 |
+
# Lazy downloading
|
| 10 |
+
model = SentenceTransformer(MODELS_ID[model_name])
|
| 11 |
+
return model
|
requirements.txt
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
fastapi
|
| 2 |
sentence_transformers
|
| 3 |
pandas
|
| 4 |
jax
|
|
|
|
| 5 |
streamlit
|
|
|
|
|
|
|
| 1 |
sentence_transformers
|
| 2 |
pandas
|
| 3 |
jax
|
| 4 |
+
jaxlib
|
| 5 |
streamlit
|