File size: 4,214 Bytes
1f22b14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import sqlite3
import warnings
import zlib

import numpy as np
import pandas as pd

# Schema of the `documents` table, one row per stored document version.
#   - `source`, `title`, `url`, `content`: mandatory text fields.
#   - `n_tokens`, `embedding`: optional; `DocumentsDB.write_documents` only fills them
#     when the input DataFrame has an `embedding` column, in which case `embedding`
#     holds the zlib-compressed raw bytes of a float32 array.
#   - `current`: set to 1 for newly written rows and reset to 0 for the older rows of
#     the same source on each write.
documents_table = """CREATE TABLE IF NOT EXISTS documents (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    source TEXT NOT NULL,
    title TEXT NOT NULL,
    url TEXT NOT NULL,
    content TEXT NOT NULL,
    n_tokens INTEGER,
    embedding BLOB,
    current INTEGER
)"""

# Schema of the `qa` table, one row per question/answer pair.
#   - `source`, `prompt`, `answer`: mandatory text fields.
#   - `document_id_1..3`: up to three optional references to `documents.id`.
#   - `label_question`, `label_answer`, `testset`: optional integers.
#     NOTE(review): their meaning is not visible in this file; presumably annotation
#     labels and a test-set flag. Confirm against the code that fills this table.
qa_table = """CREATE TABLE IF NOT EXISTS qa (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    source TEXT NOT NULL,
    prompt TEXT NOT NULL,
    answer TEXT NOT NULL,
    document_id_1 INTEGER,
    document_id_2 INTEGER,
    document_id_3 INTEGER,
    label_question INTEGER,
    label_answer INTEGER,
    testset INTEGER,
    FOREIGN KEY (document_id_1) REFERENCES documents (id),
    FOREIGN KEY (document_id_2) REFERENCES documents (id),
    FOREIGN KEY (document_id_3) REFERENCES documents (id)
)"""


class DocumentsDB:
    """Simple SQLite database for storing documents and questions/answers.

    The database is just a file on disk. It can store documents from different sources, and it can store multiple versions of the same document (e.g. if the document is updated).
    Questions/answers refer to the version of the document that was used at the time.

    Example:
        >>> db = DocumentsDB("/path/to/the/db.db")
        >>> db.write_documents("source", df)  # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
        >>> df = db.get_documents("source")
    """

    def __init__(self, db_path):
        """Open (or create) the SQLite file at `db_path` and make sure the tables exist."""
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()

        self.__initialize()

    def __del__(self):
        """Close the connection, if one was ever opened."""
        # `conn` is missing when `sqlite3.connect` raised in `__init__` (or when the
        # instance was created without `__init__`), so look it up defensively.
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()

    def __initialize(self):
        """Initialize the database."""
        self.cursor.execute(documents_table)
        self.cursor.execute(qa_table)
        self.conn.commit()

    @staticmethod
    def _decode_embedding(blob):
        """Turn a stored embedding blob back into a list of floats (`None` stays `None`)."""
        if blob is None:
            # Rows written without an embedding hold NULL in the BLOB column.
            return None
        return np.frombuffer(zlib.decompress(blob), dtype=np.float32).tolist()

    def write_documents(self, source: str, df: pd.DataFrame):
        """Write all documents from the dataframe into the db. All previous documents from that source will be set to `current = 0`.

        Args:
            source: Name of the source the documents belong to.
            df: Documents to store. Must have the columns `title`, `url` and `content`.
                If it has an `embedding` column (numpy arrays), `n_tokens` is required
                as well and both are stored; embeddings are cast to float32 and
                zlib-compressed before being written.

        The update of the old rows and the insertion of the new ones happen in a single
        transaction: if the insertion fails, the old rows stay current and the error
        is re-raised.
        """
        df = df.copy()

        # Prepare the rows
        df["source"] = source
        df["current"] = 1
        columns = ["source", "title", "url", "content", "current"]
        if "embedding" in df.columns:
            columns.extend(["n_tokens", "embedding"])

            # Nothing to convert or compress for an empty dataframe.
            if not df.empty:
                # Every embedding must be float32, because that is the dtype assumed
                # when the blobs are decoded again. Check all of them, not just the first.
                other_dtypes = [emb.dtype for emb in df["embedding"] if emb.dtype != np.float32]
                if other_dtypes:
                    warnings.warn(
                        f"Embeddings are not float32, converting them to float32 from {other_dtypes[0]}.",
                        RuntimeWarning,
                    )
                    df["embedding"] = df["embedding"].apply(lambda x: x.astype(np.float32))

                # ZLIB compress the embeddings
                df["embedding"] = df["embedding"].apply(lambda x: sqlite3.Binary(zlib.compress(x.tobytes())))

        data = df[columns].values.tolist()
        insert_statement = f"INSERT INTO documents ({', '.join(columns)}) VALUES ({', '.join(['?']*len(columns))})"

        # The connection context manager commits on success and rolls back on error,
        # so a failed insert cannot leave the source without current documents.
        with self.conn:
            # Set `current` to 0 for all previous documents from that source
            self.cursor.execute("UPDATE documents SET current = 0 WHERE source = ?", (source,))

            # Insert the new documents
            self.cursor.executemany(insert_statement, data)

    def get_documents(self, source: str) -> pd.DataFrame:
        """Get all current documents from a given source.

        Returns:
            A DataFrame with every column of the `documents` table except `current`.
            The `embedding` column holds lists of floats, or `None` for documents
            that were stored without an embedding.
        """
        # Execute the SQL statement and fetch the results
        results = self.cursor.execute("SELECT * FROM documents WHERE source = ? AND current = 1", (source,))
        rows = results.fetchall()

        # Convert the results to a pandas DataFrame
        df = pd.DataFrame(rows, columns=[description[0] for description in results.description])

        # ZLIB decompress the embeddings
        df["embedding"] = df["embedding"].apply(self._decode_embedding)

        # Drop the `current` column
        return df.drop(columns=["current"])