Spaces:
Runtime error
Commit · e15c8b9
1 Parent(s): a812db5
improve efficiency
app.py CHANGED
@@ -1,8 +1,7 @@
 import streamlit as st
-from transformers import pipeline
+from transformers import pipeline
 import requests
 from bs4 import BeautifulSoup
-from nltk.corpus import stopwords
 import nltk
 import string
 from streamlit.components.v1 import html
@@ -78,18 +77,19 @@ def find_source(text, docs):
 @st.experimental_singleton
 def init_models():
     nltk.download('stopwords')
+    from nltk.corpus import stopwords
     stop = set(stopwords.words('english') + list(string.punctuation))
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     question_answerer = pipeline(
         "question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
         device=device
     )
-    reranker = CrossEncoder('cross-encoder/ms-marco-
-    queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
-    queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
-    return question_answerer, reranker, stop, device
+    reranker = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2', device=device)
+    # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
+    # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
+    return question_answerer, reranker, stop, device # uqeryexp_model, queryexp_tokenizer
 
-qa_model, reranker, stop, device
+qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
 
 
 def clean_query(query, strict=True, clean=True):
@@ -157,27 +157,27 @@ with st.expander("Settings (strictness, context limit, top hits)"):
     use_reranking = st.radio(
         "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
         ('yes', 'no'))
-    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200
-    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25
+    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200)
+    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
     use_query_exp = st.radio(
         "(Experimental) use query expansion? Right now it just recommends queries",
         ('yes', 'no'))
     suggested_queries = st.slider('Number of suggested queries to use', 0, 10, 5)
 
-def paraphrase(text, max_length=128):
-    input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
-    generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=suggested_queries or 5, num_beams=suggested_queries or 5, max_length=max_length)
-    queries = set([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids])
-    preds = '\n * '.join(queries)
-    return preds
+# def paraphrase(text, max_length=128):
+#     input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
+#     generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=suggested_queries or 5, num_beams=suggested_queries or 5, max_length=max_length)
+#     queries = set([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids])
+#     preds = '\n * '.join(queries)
+#     return preds
 
 def run_query(query):
-    if use_query_exp == 'yes':
-        query_exp = paraphrase(f"question2question: {query}")
-        st.markdown(f"""
-        If you are not getting good results try one of:
-        * {query_exp}
-        """)
+    # if use_query_exp == 'yes':
+    #     query_exp = paraphrase(f"question2question: {query}")
+    #     st.markdown(f"""
+    #     If you are not getting good results try one of:
+    #     * {query_exp}
+    #     """)
     limit = top_hits_limit or 100
     context_limit = context_lim or 10
     contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
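
For context, the part of the app this commit keeps follows a retrieve → rerank → read pattern: init_models() caches the heavy models with @st.experimental_singleton, the CrossEncoder scores each retrieved passage against the query, and the question-answering pipeline answers from the best passages. The sketch below is a simplified illustration rather than the app's own code: load_models(), answer(), and the passages argument are hypothetical stand-ins, and it assumes the sentence-transformers CrossEncoder and Hugging Face transformers pipeline APIs.

    import streamlit as st
    from sentence_transformers import CrossEncoder
    from transformers import pipeline

    @st.experimental_singleton   # cache heavy models across Streamlit reruns, as init_models() does
    def load_models():
        reranker = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2')
        qa = pipeline("question-answering",
                      model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B')
        return reranker, qa

    def answer(query, passages, context_limit=10):
        reranker, qa = load_models()
        # score each (query, passage) pair, keep the highest-scoring passages as context
        scores = reranker.predict([(query, p) for p in passages])
        ranked = [p for _, p in sorted(zip(scores, passages), key=lambda t: t[0], reverse=True)]
        return qa(question=query, context=" ".join(ranked[:context_limit]))

The reranking step is where the top_hits_limit and context_lim sliders trade speed for quality: more candidate passages mean more CrossEncoder forward passes before the reader model runs.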
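
For reference, the query-expansion path this commit comments out used a doc2query T5 model to suggest rephrasings of the user's question. A standalone, hedged reconstruction of that step, assuming the checkpoint named in the diff and the (deprecated) AutoModelWithLMHead class it imported:

    from transformers import AutoTokenizer, AutoModelWithLMHead

    tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
    model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")

    def paraphrase(text, n=5, max_length=128):
        # the "question2question:" prefix asks the model for question rewrites,
        # matching the prompt used in the commented-out run_query() branch
        input_ids = tokenizer.encode(f"question2question: {text}",
                                     return_tensors="pt", add_special_tokens=True)
        generated = model.generate(input_ids=input_ids, num_beams=n,
                                   num_return_sequences=n, max_length=max_length)
        return {tokenizer.decode(g, skip_special_tokens=True) for g in generated}

Loading and running this extra T5 model on every query is a cost the commit drops, which is consistent with the "improve efficiency" message.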