Marc-Antoine Rondeau committed
Commit 97aefb5 · 1 Parent(s): 71e7dd8

New db schema

buster/db/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .documents import DocumentsDB
+
+ __all__ = ["DocumentsDB"]
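
The re-export above keeps the package's public surface to the single class; downstream code imports it from the package root. A minimal sketch (the file name is illustrative):

    from buster.db import DocumentsDB

    db = DocumentsDB("documents.db")  # opens (or creates) the SQLite file and initializes the schema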
buster/db/backward.py ADDED
@@ -0,0 +1,108 @@
+ """Import an existing db (old schema) into a new db (new schema)."""
+
+ import argparse
+ import itertools
+ import sqlite3
+ from typing import Iterable, NamedTuple
+
+ import numpy as np
+
+ import buster.db.documents as dest
+ from buster.db import DocumentsDB
+
+ IMPORT_QUERY = (
+     r"""SELECT source, url, title, content FROM documents WHERE current = 1 ORDER BY source, url, title, id"""
+ )
+ CHUNK_QUERY = r"""SELECT source, url, title, content, n_tokens, embedding FROM documents WHERE current = 1 ORDER BY source, url, id"""
+
+
+ class Document(NamedTuple):
+     """Document from the original db."""
+
+     source: str
+     url: str
+     title: str
+     content: str
+
+
+ class Section(NamedTuple):
+     """Section reassembled from the original db's chunks."""
+
+     url: str
+     title: str
+     content: str
+
+
+ class Chunk(NamedTuple):
+     """Chunk from the original db."""
+
+     source: str
+     url: str
+     title: str
+     content: str
+     n_tokens: int
+     embedding: np.ndarray
+
+
+ def get_documents(conn: sqlite3.Connection) -> Iterable[tuple[str, Iterable[Section]]]:
+     """Reassemble documents from the source db's chunks."""
+     documents = (Document(*row) for row in conn.execute(IMPORT_QUERY))
+     by_source = itertools.groupby(documents, lambda doc: doc.source)
+     for source, docs in by_source:
+         by_section = itertools.groupby(docs, lambda doc: (doc.url, doc.title))
+         sections = (
+             Section(url, title, "".join(chunk.content for chunk in chunks)) for (url, title), chunks in by_section
+         )
+         yield source, sections
+
+
+ def get_max_size(conn: sqlite3.Connection) -> int:
+     """Get the maximum chunk size from the source db."""
+     sizes = (size for size, in conn.execute("SELECT max(length(content)) FROM documents"))
+     (size,) = sizes
+     return size
+
+
+ def get_chunks(conn: sqlite3.Connection) -> Iterable[tuple[str, Iterable[Iterable[dest.Chunk]]]]:
+     """Retrieve chunks from the source db, grouped by source, then by section."""
+     chunks = (Chunk(*row) for row in conn.execute(CHUNK_QUERY))
+     by_source = itertools.groupby(chunks, lambda chunk: chunk.source)
+     for source, source_chunks in by_source:
+         by_section = itertools.groupby(source_chunks, lambda chunk: (chunk.url, chunk.title))
+
+         sections = (
+             (dest.Chunk(chunk.content, chunk.n_tokens, chunk.embedding) for chunk in section_chunks)
+             for _, section_chunks in by_section
+         )
+
+         yield source, sections
+
+
+ def main():
+     """Import the source db into the destination db."""
+     parser = argparse.ArgumentParser()
+     parser.add_argument("source")
+     parser.add_argument("destination")
+     parser.add_argument("--size", type=int, default=2000)
+     args = parser.parse_args()
+     org = sqlite3.connect(args.source)
+     db = DocumentsDB(args.destination)
+
+     # Rebuild each source's sections under a fresh version.
+     for source, content in get_documents(org):
+         sid, vid = db.start_version(source)
+         sections = (dest.Section(section.title, section.url, section.content) for section in content)
+         db.add_sections(sid, vid, sections)
+
+     # Re-insert the original chunks under a single chunking per source.
+     size = max(args.size, get_max_size(org))
+     for source, chunks in get_chunks(org):
+         sid, vid = db.get_current_version(source)
+         cid = db.add_chunking(sid, vid, size)
+         db.add_chunks(sid, vid, cid, chunks)
+         db.conn.commit()
+
+
+ if __name__ == "__main__":
+     main()
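
For reference, a minimal invocation sketch (file names are illustrative): running `python buster/db/backward.py old_documents.db new_documents.db --size 2000` rebuilds each source's sections from the old db's `current = 1` rows, then re-inserts its chunks under a single chunking whose declared size is at least the largest original chunk.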
buster/db/documents.py ADDED
@@ -0,0 +1,169 @@
+ import sqlite3
+ import warnings
+ import zlib
+ from typing import Iterable, NamedTuple
+
+ import numpy as np
+ import pandas as pd
+
+ import buster.db.schema as schema
+
+
+ class Section(NamedTuple):
+     title: str
+     url: str
+     content: str
+     parent: int | None = None
+     type: str = "section"
+
+
+ class Chunk(NamedTuple):
+     content: str
+     n_tokens: int
+     emb: np.ndarray
+
+
+ class DocumentsDB:
+     """Simple SQLite database for storing documents and questions/answers.
+
+     The database is just a file on disk. It can store documents from different sources, and it can
+     store multiple versions of the same document (e.g. if the document is updated).
+     Questions/answers refer to the version of the document that was used at the time.
+
+     Example:
+     >>> db = DocumentsDB("/path/to/the/db.db")
+     >>> db.write_documents("source", df)  # df contains the documents from a given source, obtained e.g. via buster.docparser.generate_embeddings
+     >>> df = db.get_documents("source")
+     """
+
+     def __init__(self, db_path: sqlite3.Connection | str):
+         if isinstance(db_path, str):
+             self.db_path = db_path
+             # PARSE_DECLTYPES lets the registered VECTOR converter decompress embeddings on read.
+             self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES)
+         else:
+             self.db_path = None
+             self.conn = db_path
+         self.cursor = self.conn.cursor()
+         schema.initialize_db(self.conn)
+         schema.setup_db(self.conn)
+
+     def __del__(self):
+         # Only close connections we opened ourselves.
+         if self.db_path is not None:
+             self.conn.close()
+
+     def get_current_version(self, source: str) -> tuple[int, int]:
+         cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
+         row = cur.fetchone()
+         if row is None:
+             raise KeyError(f'"{source}" is not a known source')
+         sid, vid = row
+         return sid, vid
+
+     def get_source(self, source: str) -> int:
+         cur = self.conn.execute("SELECT id FROM sources WHERE name = ?", (source,))
+         row = cur.fetchone()
+         if row is not None:
+             (sid,) = row
+         else:
+             self.conn.execute("INSERT INTO sources (name) VALUES (?)", (source,))
+             cur = self.conn.execute("SELECT id FROM sources WHERE name = ?", (source,))
+             (sid,) = cur.fetchone()
+
+         return sid
+
+     def start_version(self, source: str) -> tuple[int, int]:
+         cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
+         row = cur.fetchone()
+         if row is None:
+             sid = self.get_source(source)
+             vid = 0
+         else:
+             sid, vid = row
+             vid = vid + 1
+         self.conn.execute("INSERT INTO versions (source, version) VALUES (?, ?)", (sid, vid))
+         return sid, vid
+
+     def add_sections(self, sid: int, vid: int, sections: Iterable[Section]):
+         values = (
+             (sid, vid, ind, section.title, section.url, section.content, section.parent, section.type)
+             for ind, section in enumerate(sections)
+         )
+         self.conn.executemany(
+             "INSERT INTO sections "
+             "(source, version, section, title, url, content, parent, type) "
+             "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
+             values,
+         )
+
+     def add_chunking(self, sid: int, vid: int, size: int, overlap: int = 0, strategy: str = "simple") -> int:
+         self.conn.execute(
+             "INSERT INTO chunkings (size, overlap, strategy, source, version) VALUES (?, ?, ?, ?, ?)",
+             (size, overlap, strategy, sid, vid),
+         )
+         cur = self.conn.execute(
+             "SELECT chunking FROM chunkings "
+             "WHERE size = ? AND overlap = ? AND strategy = ? AND source = ? AND version = ?",
+             (size, overlap, strategy, sid, vid),
+         )
+         (cid,) = (cid for cid, in cur)
+         return cid
+
+     def add_chunks(self, sid: int, vid: int, cid: int, sections: Iterable[Iterable[Chunk]]):
+         chunks = ((ind, jnd, chunk) for ind, section in enumerate(sections) for jnd, chunk in enumerate(section))
+         values = ((sid, vid, ind, cid, jnd, chunk.content, chunk.n_tokens, chunk.emb) for ind, jnd, chunk in chunks)
+         self.conn.executemany(
+             "INSERT INTO chunks "
+             "(source, version, section, chunking, sequence, content, n_tokens, embedding) "
+             "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
+             values,
+         )
+
+     def write_documents(self, source: str, df: pd.DataFrame):
+         """Write all documents from the dataframe into the db. All previous documents from that source are set to `current = 0`."""
+         df = df.copy()
+
+         # Prepare the rows
+         df["source"] = source
+         df["current"] = 1
+         columns = ["source", "title", "url", "content", "current"]
+         if "embedding" in df.columns:
+             columns.extend(["n_tokens", "embedding"])
+
+             # Check that the embeddings are float32
+             if not df["embedding"].iloc[0].dtype == np.float32:
+                 warnings.warn(
+                     f"Embeddings are not float32, converting them to float32 from {df['embedding'].iloc[0].dtype}.",
+                     RuntimeWarning,
+                 )
+                 df["embedding"] = df["embedding"].apply(lambda x: x.astype(np.float32))
+
+             # ZLIB compress the embeddings
+             df["embedding"] = df["embedding"].apply(lambda x: sqlite3.Binary(zlib.compress(x.tobytes())))
+
+         data = df[columns].values.tolist()
+
+         # Set `current` to 0 for all previous documents from that source
+         self.cursor.execute("UPDATE documents SET current = 0 WHERE source = ?", (source,))
+
+         # Insert the new documents
+         insert_statement = f"INSERT INTO documents ({', '.join(columns)}) VALUES ({', '.join(['?'] * len(columns))})"
+         self.cursor.executemany(insert_statement, data)
+
+         self.conn.commit()
+
+     def get_documents(self, source: str) -> pd.DataFrame:
+         """Get all documents from a given source."""
+         # Execute the SQL statement and fetch the results
+         results = self.cursor.execute("SELECT * FROM documents WHERE source = ?", (source,))
+         rows = results.fetchall()
+
+         # Convert the results to a pandas DataFrame
+         df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
+         return df
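
Beyond the `write_documents`/`get_documents` pair, the versioned API above can be driven directly. A minimal sketch, assuming an in-memory database and a single one-chunk section (the source name, URL, text, and embedding are all illustrative):

    import numpy as np
    from buster.db import DocumentsDB
    from buster.db.documents import Chunk, Section

    db = DocumentsDB(":memory:")
    sid, vid = db.start_version("docs")  # open a new version for this source
    db.add_sections(sid, vid, [Section("Intro", "https://example.org/intro", "Hello world")])
    cid = db.add_chunking(sid, vid, size=2000)  # register the chunking parameters
    emb = np.zeros(4, dtype=np.float32)  # toy embedding; the adapter compresses it on insert
    db.add_chunks(sid, vid, cid, [[Chunk("Hello world", 2, emb)]])
    db.conn.commit()

Note that `add_chunks` expects one inner iterable of chunks per section, in the same order the sections were added, since the section index is derived by enumeration.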
buster/db/schema.py ADDED
@@ -0,0 +1,135 @@
+ import sqlite3
+ import zlib
+
+ import numpy as np
+
+ SOURCE_TABLE = r"""CREATE TABLE IF NOT EXISTS sources (
+     id INTEGER PRIMARY KEY AUTOINCREMENT,
+     name TEXT NOT NULL,
+     note TEXT,
+     UNIQUE(name)
+ )"""
+
+ VERSION_TABLE = r"""CREATE TABLE IF NOT EXISTS versions (
+     source INTEGER,
+     version INTEGER,
+     parser TEXT,
+     note TEXT,
+     PRIMARY KEY (version, source, parser),
+     FOREIGN KEY (source) REFERENCES sources (id)
+ )"""
+
+ CHUNKING_TABLE = r"""CREATE TABLE IF NOT EXISTS chunkings (
+     chunking INTEGER PRIMARY KEY AUTOINCREMENT,
+     size INTEGER,
+     overlap INTEGER,
+     strategy TEXT,
+     chunker TEXT,
+     source INTEGER,
+     version INTEGER,
+     UNIQUE (size, overlap, strategy, chunker, source, version),
+     FOREIGN KEY (source, version) REFERENCES versions (source, version)
+ )"""
+
+ SECTION_TABLE = r"""CREATE TABLE IF NOT EXISTS sections (
+     source INTEGER,
+     version INTEGER,
+     section INTEGER,
+     title TEXT NOT NULL,
+     url TEXT NOT NULL,
+     content TEXT NOT NULL,
+     parent INTEGER,
+     type TEXT,
+     PRIMARY KEY (version, source, section),
+     FOREIGN KEY (source, version) REFERENCES versions (source, version)
+ )"""
+
+ CHUNK_TABLE = r"""CREATE TABLE IF NOT EXISTS chunks (
+     source INTEGER,
+     version INTEGER,
+     section INTEGER,
+     chunking INTEGER,
+     sequence INTEGER,
+     content TEXT NOT NULL,
+     n_tokens INTEGER,
+     embedding VECTOR,
+     PRIMARY KEY (source, version, section, chunking, sequence),
+     FOREIGN KEY (source, version, section) REFERENCES sections (source, version, section),
+     FOREIGN KEY (source, version, chunking) REFERENCES chunkings (source, version, chunking)
+ )"""
+
+ VERSION_VIEW = r"""CREATE VIEW IF NOT EXISTS latest_version (
+     name, source, version) AS
+     SELECT sources.name, versions.source, max(versions.version)
+     FROM sources INNER JOIN versions ON sources.id = versions.source
+     GROUP BY sources.id
+ """
+
+ CHUNKING_VIEW = r"""CREATE VIEW IF NOT EXISTS latest_chunking (
+     name, source, version, chunking) AS
+     SELECT name, source, version, max(chunking) FROM
+     chunkings INNER JOIN latest_version USING (source, version)
+     GROUP BY source, version
+ """
+
+ DOCUMENT_VIEW = r"""CREATE VIEW IF NOT EXISTS documents (
+     source, title, url, content, n_tokens, embedding)
+     AS SELECT latest_chunking.name, sections.title, sections.url,
+     chunks.content, chunks.n_tokens, chunks.embedding
+     FROM chunks INNER JOIN sections USING (source, version, section)
+     INNER JOIN latest_chunking USING (source, version, chunking)
+ """
+
+ INIT_STATEMENTS = [
+     SOURCE_TABLE,
+     VERSION_TABLE,
+     CHUNKING_TABLE,
+     SECTION_TABLE,
+     CHUNK_TABLE,
+     VERSION_VIEW,
+     CHUNKING_VIEW,
+     DOCUMENT_VIEW,
+ ]
+
+
+ def initialize_db(connection: sqlite3.Connection):
+     for statement in INIT_STATEMENTS:
+         try:
+             connection.execute(statement)
+         except sqlite3.Error:
+             connection.rollback()
+             raise
+     connection.commit()
+     return connection
+
+
+ def adapt_vector(vector: np.ndarray) -> bytes:
+     """Adapt a numpy array to a zlib-compressed float32 blob for storage."""
+     return sqlite3.Binary(zlib.compress(vector.astype(np.float32).tobytes()))
+
+
+ def convert_vector(buffer: bytes) -> np.ndarray:
+     """Convert a stored blob back into a float32 numpy array."""
+     return np.frombuffer(zlib.decompress(buffer), dtype=np.float32)
+
+
+ def cosine_similarity(a: bytes, b: bytes) -> float:
+     a = convert_vector(a)
+     b = convert_vector(b)
+     a = a / np.linalg.norm(a)
+     b = b / np.linalg.norm(b)
+     # Rescale the cosine from [-1, 1] to [0, 1].
+     dot = 0.5 * np.dot(a, b) + 0.5
+     return float(dot)
+
+
+ def setup_db(connection: sqlite3.Connection):
+     sqlite3.register_adapter(np.ndarray, adapt_vector)
+     sqlite3.register_converter("VECTOR", convert_vector)
+     connection.create_function("sim", 2, cosine_similarity, deterministic=True)
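
Since `setup_db` registers `sim` as a deterministic SQL function over two stored blobs, similarity search can run inside SQLite itself. Because `cosine_similarity` rescales the cosine into [0, 1], a score of 1.0 means identical direction and 0.5 means orthogonal vectors. A minimal query sketch, assuming a populated `documents.db` whose stored embeddings have the same dimension as the query (the file name and dimension are illustrative):

    import numpy as np
    import sqlite3
    from buster.db.schema import adapt_vector, initialize_db, setup_db

    conn = sqlite3.connect("documents.db")
    initialize_db(conn)
    setup_db(conn)

    # Compress the query vector the same way embeddings are stored.
    query = adapt_vector(np.random.rand(4).astype(np.float32))
    rows = conn.execute(
        "SELECT title, url, sim(embedding, ?) AS score FROM documents ORDER BY score DESC LIMIT 5",
        (query,),
    ).fetchall()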