import logging
import os
import uuid
from datetime import datetime, timezone
from urllib.parse import quote_plus

import gradio as gr
import pandas as pd
import pymongo
from pymongo import MongoClient

from buster.completers import Completion, UserInputs
from buster.tokenizers import Tokenizer

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class WordTokenizer(Tokenizer):
    """Naive word-level tokenizer

    The original tokenizer from OpenAI uses far too much RAM.
    This naive word-count tokenizer is used instead."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode(self, string):
        return string.split()

    def decode(self, encoded):
        return " ".join(encoded)
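
# Example (illustrative; assumes the buster base Tokenizer simply stores the
# model name passed to it):
#   tokenizer = WordTokenizer("gpt-3.5-turbo")
#   tokenizer.encode("hello world")       # -> ["hello", "world"]
#   tokenizer.decode(["hello", "world"])  # -> "hello world"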


def get_logging_db_name(instance_type: str) -> str:
    assert instance_type in ["dev", "prod", "local", "test"], "Invalid instance_type declared."
    return f"ai4h-databank-{instance_type}"
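
# e.g. get_logging_db_name("dev") -> "ai4h-databank-dev"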


def get_session_id() -> str:
    """Generate a uuid for each user."""
    return str(uuid.uuid1())


def verify_required_env_vars(required_vars: list[str]) -> None:
    """Log a warning if any of the required environment variables are not set."""
    unset_vars = [var for var in required_vars if os.getenv(var) is None]
    if len(unset_vars) > 0:
        logger.warning(f"List of env. variables that weren't set: {unset_vars}")
    else:
        logger.info("All environment variables are set appropriately.")
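
# Example: verify_required_env_vars(["AI4H_APP_USERNAME", "AI4H_APP_PASSWORD"])
# only logs a warning for any missing variables; it does not raise.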


def make_uri(username: str, password: str, cluster: str) -> str:
    """Create a MongoDB connection URI from the given credentials and cluster address."""
    return (
        f"mongodb+srv://{quote_plus(username)}:{quote_plus(password)}"
        f"@{cluster}/?retryWrites=true&w=majority"
    )
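
# Example (illustrative values): special characters in credentials are percent-encoded, e.g.
#   make_uri("db_user", "p@ss/word", "cluster0.example.mongodb.net")
#   -> "mongodb+srv://db_user:p%40ss%2Fword@cluster0.example.mongodb.net/?retryWrites=true&w=majority"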


def init_db(mongo_uri: str, db_name: str) -> pymongo.database.Database:
    """
    Initialize and return a connection to the specified MongoDB database.

    Parameters:
    - mongo_uri (str): The connection string for the MongoDB cluster. It can be formed using the `make_uri` function.
    - db_name (str): The name of the MongoDB database to connect to.

    Returns:
    pymongo.database.Database: The connected database object.

    Note:
    If there's a problem with the connection, an exception will be logged and the program will terminate.
    """

    try:
        mongodb_client = MongoClient(mongo_uri)
        # Ping the database to make sure authentication is good
        mongodb_client.admin.command("ping")
        database = mongodb_client[db_name]
        logger.info("Successfully connected to the MongoDB database")
        return database
    except Exception:
        logger.exception("Something went wrong connecting to MongoDB")
        raise
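
# Example wiring (hypothetical env var names, shown for illustration only):
#   uri = make_uri(os.environ["MONGO_USERNAME"], os.environ["MONGO_PASSWORD"], os.environ["MONGO_CLUSTER"])
#   db = init_db(uri, get_logging_db_name("dev"))  # connects to the "ai4h-databank-dev" database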


def get_utc_time() -> str:
    """Return the current UTC time as a string."""
    return str(datetime.now(timezone.utc))


def check_auth(username: str, password: str) -> bool:
    """Check if authentication succeeds or not.

    Authentication leverages the built-in gradio authentication and uses a password shared among users.
    This is temporary while developing the PoC; proper authentication needs to be implemented in the future.
    A valid username is any username beginning with 'databank-', which lets us easily differentiate between users.
    """

    # Get auth information from env. vars; these must be set
    USERNAME = os.environ["AI4H_APP_USERNAME"]
    PASSWORD = os.environ["AI4H_APP_PASSWORD"]

    valid_user = username.startswith(USERNAME)
    valid_password = password == PASSWORD
    is_auth = valid_user and valid_password
    logger.info(f"Log-in attempted by {username=}. {is_auth=}")
    return is_auth
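
# Example wiring (illustrative): gradio calls the auth callable with (username, password),
# so this can be plugged in directly, e.g. demo.launch(auth=check_auth) on a gr.Blocks demo.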


def format_sources(matched_documents: pd.DataFrame) -> list[str]:
    """Format the matched documents into one markdown block per publication."""
    formatted_sources = []

    # Group on the document title, so that two chunks from the same document get lumped together
    grouped_df = matched_documents.groupby("title")

    # Rank the titles from highest to lowest similarity score
    ranked_titles = (
        grouped_df.apply(lambda x: x.similarity_to_answer.max()).sort_values(ascending=False).index.to_list()
    )

    for title in ranked_titles:
        df = grouped_df.get_group(title)

        # Add a line break between chunks from the same document
        chunks = "<br><br>".join(["πŸ”— " + chunk for chunk in df.content.to_list()])

        url = df.url.to_list()[0]
        source = df.source.to_list()[0]
        year = df.year.to_list()[0]
        country = df.country.to_list()[0]

        formatted_sources.append(
            f"""

### Publication: [{title}]({url})
**Year of publication:** {year}
**Source:** {source}
**Country:** {country}

**Identified sections**:
{chunks}
"""
        )

    return formatted_sources
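
# Example (toy data): matched_documents must contain at least the columns used above
# (title, content, url, source, year, country, similarity_to_answer), e.g.
#   df = pd.DataFrame(
#       {
#           "title": ["Report A", "Report A"],
#           "content": ["first chunk", "second chunk"],
#           "url": ["https://example.org/report-a"] * 2,
#           "source": ["Example Org"] * 2,
#           "year": [2021] * 2,
#           "country": ["Canada"] * 2,
#           "similarity_to_answer": [0.9, 0.8],
#       }
#   )
#   format_sources(df)  # -> a single markdown block for "Report A" containing both chunks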


def pad_sources(sources: list[str], max_sources: int) -> list[str]:
    """Pad sources with empty strings to ensure that the number of sources is always max_sources."""
    k = len(sources)
    return sources + [""] * (max_sources - k)
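
# e.g. pad_sources(["only source"], max_sources=3) -> ["only source", "", ""]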


def add_sources(completion, max_sources: int):
    """Create one gr.Markdown component per formatted source, hiding the padded (empty) ones."""
    if not completion.question_relevant:
        # Question was not relevant, don't bother formatting any sources
        formatted_sources = [""]
    else:
        formatted_sources = format_sources(completion.matched_documents)

    formatted_sources = pad_sources(formatted_sources, max_sources)

    sources_textboxes = []
    for source in formatted_sources:
        visible = source != ""
        t = gr.Markdown(source, latex_delimiters=[], elem_classes="source", visible=visible)
        sources_textboxes.append(t)
    return sources_textboxes


def debug_completion(user_input, reformulate_question):
    """Generate a debug completion."""
    user_inputs = UserInputs(original_input=user_input)
    if reformulate_question:
        user_inputs.reformulated_input = "This is your reformulated question?"

    completion = Completion(
        user_inputs=user_inputs,
        error=False,
        matched_documents=[],
        answer_generator="This is the answer you'd expect a User to see.",
        question_relevant=True,
        answer_relevant=True,
    )
    return completion