File size: 4,136 Bytes
cd6bb4b
5bf3195
d631df4
5bf3195
 
c8b8b02
 
5bf3195
75a550e
5bf3195
c8b8b02
 
5bf3195
75a550e
5bf3195
c8b8b02
045b802
d631df4
5bf3195
d631df4
 
5bf3195
 
d631df4
5bf3195
 
d631df4
5bf3195
 
 
 
d631df4
5bf3195
 
d631df4
5bf3195
 
 
 
 
 
 
 
 
 
 
 
d631df4
5bf3195
 
c8b8b02
ac8d16a
d631df4
5bf3195
 
 
d631df4
 
 
 
5bf3195
 
 
 
 
d631df4
 
5bf3195
d631df4
5bf3195
cd6bb4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bf3195
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import hmac

import pinecone
import streamlit as st
from transformers import pipeline, AutoTokenizer
from sentence_transformers import SentenceTransformer

# Pinecone credentials, read from Streamlit secrets (values are issued at app.pinecone.io).
PINECONE_KEY = st.secrets["PINECONE_API_KEY"]  # API key passed to pinecone.init
PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"]  # environment passed to pinecone.init

@st.cache_resource
def init_pinecone():
    """Initialise the Pinecone client and return a handle to the company index.

    The function is decorated with ``st.cache_resource``, so Streamlit reuses
    the returned handle instead of re-running the body on every script rerun.

    Returns:
        The ``pinecone.Index`` object for the ``company-description`` index.
    """
    # Get a free API key from app.pinecone.io.
    pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
    # The name must match the index that run_query reconnects to; the previous
    # spelling ("dompany-description") pointed at a different index name.
    return pinecone.Index("company-description")
    
@st.cache_resource
def init_models():
    """Load the models used by the app from a single checkpoint name.

    The function is decorated with ``st.cache_resource``, so Streamlit reuses
    the loaded objects instead of reloading them on every script rerun.

    Returns:
        A ``(retriever, reader, tokenizer)`` tuple: the sentence-embedding
        model, a ``question-answering`` pipeline and the matching tokenizer.
    """
    # Alternative retriever checkpoint tried earlier: multi-qa-MiniLM-L6-cos-v1
    model_name = "sentence-transformers/all-MiniLM-L6-v2"

    retriever = SentenceTransformer(model_name)
    reader = pipeline(
        task='question-answering',
        model=model_name,
        tokenizer=model_name,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    return retriever, reader, tokenizer

# Keep the Pinecone index handle in session state; run_query reads it from
# there and replaces it when a query raises.
st.session_state.index = init_pinecone()
# Load the embedding model, QA pipeline and tokenizer. In the code visible
# here only `retriever` is used afterwards.
retriever, reader, tokenizer = init_models()


def card(name, description, score):
    """Render one search hit as a Bootstrap-styled HTML snippet.

    Args:
        name: Text shown in bold as the card heading.
        description: Body text shown in small grey type. It is inserted as-is
            and rendered with ``unsafe_allow_html=True``, so any markup it
            contains is passed through to the page.
        score: Value shown after the ``Score:`` label.

    Returns:
        Whatever ``st.markdown`` returns for the rendered element.
    """
    markup = f"""
    <div class="container-fluid">
        <div class="row align-items-start">
             <div  class="col-md-12 col-sm-12">
                 <b>{name}</b>
                 <br>
                 <span style="color: #808080;">
                     <small>{description}</small>
                     [<b>Score: </b>{score}]
                 </span>
             </div>
        </div>
     </div>
        """
    return st.markdown(markup, unsafe_allow_html=True)

# Emits a page title whose text is the empty string.
st.title("")

def run_query(query):
    """Embed *query*, search the Pinecone index and render each match as a card.

    The query text is encoded with the module-level ``retriever`` and sent to
    the index stored in ``st.session_state.index``. If that call raises, the
    Pinecone client is re-initialised, the index handle in session state is
    replaced, and the query is retried once.

    Args:
        query: Free-text search string entered by the user.

    Returns:
        None. Results are rendered through ``card``, highest score first.
    """
    xq = retriever.encode([query]).tolist()
    try:
        xc = st.session_state.index.query(
            xq, top_k=3, include_metadata=True, include_vectors=True
        )
    except Exception:
        # Force a reload of the client and index handle, then retry once.
        # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
        pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
        st.session_state.index = pinecone.Index("company-description")
        # NOTE(review): the retry asks for 10 matches while the first attempt
        # asks for 3; kept as found -- confirm which value is intended.
        xc = st.session_state.index.query(
            xq, top_k=10, include_metadata=True, include_vectors=True
        )

    suffix = "_description"
    results = []
    for match in xc['matches']:
        metadata = match["metadata"]
        name = metadata['company_name']
        # Drop the literal suffix only. str.strip('_description') would remove
        # any of those *characters* from both ends and mangle names such as
        # "tesco_description" (which it reduces to an empty string).
        if name.endswith(suffix):
            name = name[:-len(suffix)]
        results.append({
            'score': match['score'],
            'name': name,
            'description': metadata['description'],
        })

    results.sort(key=lambda item: item['score'], reverse=True)

    for hit in results:
        company_name = hit["name"]
        description = hit["description"]
        # Highlight occurrences of the company name. Skipped for an empty name,
        # because str.replace("", ...) would insert the markup between every
        # character of the description.
        if company_name:
            description = description.replace(
                company_name, f"<mark>{company_name}</mark>"
            )
        card(company_name, description, round(hit["score"], 4))

def check_password():
    """Returns `True` if the user had the correct password.

    While the password has not been entered correctly, a password input is
    rendered (plus an error message after a failed attempt) and `False` is
    returned. The outcome of the last attempt is kept in
    ``st.session_state["password_correct"]``.
    """

    def password_entered():
        """Checks whether a password entered by the user is correct."""
        entered = str(st.session_state["password"]).encode("utf-8")
        # str() also lets a secret that was stored unquoted (and therefore
        # parsed as a non-string value) be compared against the typed text.
        expected = str(st.secrets["password"]).encode("utf-8")
        # Constant-time comparison so the check does not leak, through timing,
        # how many leading characters of the guess were right. Bytes are used
        # because compare_digest rejects non-ASCII str arguments.
        if hmac.compare_digest(entered, expected):
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # don't store password
        else:
            st.session_state["password_correct"] = False

    if st.session_state.get("password_correct", False):
        # Password correct.
        return True

    # First run or wrong password: show the input again.
    st.text_input(
        "Password", type="password", on_change=password_entered, key="password"
    )
    if "password_correct" in st.session_state:
        # An attempt was made and it failed.
        st.error("😕 Password incorrect")
    return False

if check_password():
    # Everything below is rendered only once the password check has passed.
    st.write("""
    Search for a company in free text
    """)

    # Load the Bootstrap stylesheet whose classes the result cards use.
    # The package reference is "bootstrap@4.0.0": the previous text
    # "[email protected]" was an email-obfuscation artefact that broke the URL,
    # and the SRI hash below is the one published for the 4.0.0 CSS.
    st.markdown("""
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
    """, unsafe_allow_html=True)

    query = st.text_input("Search!", "")

    # Only search once the user has typed something.
    if query:
        run_query(query)