Spaces:
Runtime error
Runtime error
Marc-Antoine Rondeau
committed on
Commit
·
fb83544
1
Parent(s):
97aefb5
Moved schema to replace previous implementation
Browse files
buster/documents/sqlite.py
DELETED
@@ -1,122 +0,0 @@
|
|
1 |
-
import sqlite3
|
2 |
-
import warnings
|
3 |
-
import zlib
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import pandas as pd
|
7 |
-
|
8 |
-
from buster.documents.base import DocumentsManager
|
9 |
-
|
10 |
-
documents_table = """CREATE TABLE IF NOT EXISTS documents (
|
11 |
-
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
12 |
-
source TEXT NOT NULL,
|
13 |
-
title TEXT NOT NULL,
|
14 |
-
url TEXT NOT NULL,
|
15 |
-
content TEXT NOT NULL,
|
16 |
-
n_tokens INTEGER,
|
17 |
-
embedding BLOB,
|
18 |
-
current INTEGER
|
19 |
-
)"""
|
20 |
-
|
21 |
-
qa_table = """CREATE TABLE IF NOT EXISTS qa (
|
22 |
-
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
23 |
-
source TEXT NOT NULL,
|
24 |
-
prompt TEXT NOT NULL,
|
25 |
-
answer TEXT NOT NULL,
|
26 |
-
document_id_1 INTEGER,
|
27 |
-
document_id_2 INTEGER,
|
28 |
-
document_id_3 INTEGER,
|
29 |
-
label_question INTEGER,
|
30 |
-
label_answer INTEGER,
|
31 |
-
testset INTEGER,
|
32 |
-
FOREIGN KEY (document_id_1) REFERENCES documents (id),
|
33 |
-
FOREIGN KEY (document_id_2) REFERENCES documents (id),
|
34 |
-
FOREIGN KEY (document_id_3) REFERENCES documents (id)
|
35 |
-
)"""
|
36 |
-
|
37 |
-
|
38 |
-
class DocumentsDB(DocumentsManager):
|
39 |
-
"""Simple SQLite database for storing documents and questions/answers.
|
40 |
-
|
41 |
-
The database is just a file on disk. It can store documents from different sources, and it can store multiple versions of the same document (e.g. if the document is updated).
|
42 |
-
Questions/answers refer to the version of the document that was used at the time.
|
43 |
-
|
44 |
-
Example:
|
45 |
-
>>> db = DocumentsDB("/path/to/the/db.db")
|
46 |
-
>>> db.add("source", df) # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
|
47 |
-
>>> df = db.get_documents("source")
|
48 |
-
"""
|
49 |
-
|
50 |
-
def __init__(self, filepath: str):
|
51 |
-
self.db_path = filepath
|
52 |
-
self.conn = sqlite3.connect(filepath)
|
53 |
-
self.cursor = self.conn.cursor()
|
54 |
-
|
55 |
-
self.__initialize()
|
56 |
-
|
57 |
-
def __del__(self):
|
58 |
-
self.conn.close()
|
59 |
-
|
60 |
-
def __initialize(self):
|
61 |
-
"""Initialize the database."""
|
62 |
-
self.cursor.execute(documents_table)
|
63 |
-
self.cursor.execute(qa_table)
|
64 |
-
self.conn.commit()
|
65 |
-
|
66 |
-
def add(self, source: str, df: pd.DataFrame):
|
67 |
-
"""Write all documents from the dataframe into the db. All previous documents from that source will be set to `current = 0`."""
|
68 |
-
df = df.copy()
|
69 |
-
|
70 |
-
# Prepare the rows
|
71 |
-
df["source"] = source
|
72 |
-
df["current"] = 1
|
73 |
-
columns = ["source", "title", "url", "content", "current"]
|
74 |
-
if "embedding" in df.columns:
|
75 |
-
columns.extend(
|
76 |
-
[
|
77 |
-
"n_tokens",
|
78 |
-
"embedding",
|
79 |
-
]
|
80 |
-
)
|
81 |
-
|
82 |
-
# Check that the embeddings are float32
|
83 |
-
if not df["embedding"].iloc[0].dtype == np.float32:
|
84 |
-
warnings.warn(
|
85 |
-
f"Embeddings are not float32, converting them to float32 from {df['embedding'].iloc[0].dtype}.",
|
86 |
-
RuntimeWarning,
|
87 |
-
)
|
88 |
-
df["embedding"] = df["embedding"].apply(lambda x: x.astype(np.float32))
|
89 |
-
|
90 |
-
# ZLIB compress the embeddings
|
91 |
-
df["embedding"] = df["embedding"].apply(lambda x: sqlite3.Binary(zlib.compress(x.tobytes())))
|
92 |
-
|
93 |
-
data = df[columns].values.tolist()
|
94 |
-
|
95 |
-
# Set `current` to 0 for all previous documents from that source
|
96 |
-
self.cursor.execute("UPDATE documents SET current = 0 WHERE source = ?", (source,))
|
97 |
-
|
98 |
-
# Insert the new documents
|
99 |
-
insert_statement = f"INSERT INTO documents ({', '.join(columns)}) VALUES ({', '.join(['?']*len(columns))})"
|
100 |
-
self.cursor.executemany(insert_statement, data)
|
101 |
-
|
102 |
-
self.conn.commit()
|
103 |
-
|
104 |
-
def get_documents(self, source: str) -> pd.DataFrame:
|
105 |
-
"""Get all current documents from a given source."""
|
106 |
-
# Execute the SQL statement and fetch the results
|
107 |
-
if source is not None:
|
108 |
-
results = self.cursor.execute("SELECT * FROM documents WHERE source = ? AND current = 1", (source,))
|
109 |
-
else:
|
110 |
-
results = self.cursor.execute("SELECT * FROM documents WHERE current = 1")
|
111 |
-
rows = results.fetchall()
|
112 |
-
|
113 |
-
# Convert the results to a pandas DataFrame
|
114 |
-
df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
|
115 |
-
|
116 |
-
# ZLIB decompress the embeddings
|
117 |
-
df["embedding"] = df["embedding"].apply(lambda x: np.frombuffer(zlib.decompress(x), dtype=np.float32).tolist())
|
118 |
-
|
119 |
-
# Drop the `current` column
|
120 |
-
df.drop(columns=["current"], inplace=True)
|
121 |
-
|
122 |
-
return df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
buster/{db → documents/sqlite}/__init__.py
RENAMED
File without changes
|
buster/{db → documents/sqlite}/backward.py
RENAMED
@@ -2,15 +2,13 @@
|
|
2 |
|
3 |
import argparse
|
4 |
import itertools
|
|
|
5 |
from typing import Iterable, NamedTuple
|
6 |
|
7 |
import numpy as np
|
8 |
-
import sqlite3
|
9 |
-
|
10 |
-
from buster.db import DocumentsDB
|
11 |
-
|
12 |
-
import buster.db.documents as dest
|
13 |
|
|
|
|
|
14 |
|
15 |
IMPORT_QUERY = (
|
16 |
r"""SELECT source, url, title, content FROM documents WHERE current = 1 ORDER BY source, url, title, id"""
|
@@ -90,15 +88,14 @@ def main():
|
|
90 |
db = DocumentsDB(args.destination)
|
91 |
|
92 |
for source, content in get_documents(org):
|
93 |
-
sid, vid = db.start_version(source)
|
94 |
sections = (dest.Section(section.title, section.url, section.content) for section in content)
|
95 |
-
db.
|
96 |
|
97 |
size = max(args.size, get_max_size(org))
|
98 |
for source, chunks in get_chunks(org):
|
99 |
sid, vid = db.get_current_version(source)
|
100 |
-
|
101 |
-
db.add_chunks(sid, vid, cid, chunks)
|
102 |
db.conn.commit()
|
103 |
|
104 |
return
|
|
|
2 |
|
3 |
import argparse
|
4 |
import itertools
|
5 |
+
import sqlite3
|
6 |
from typing import Iterable, NamedTuple
|
7 |
|
8 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
import buster.documents.sqlite.documents as dest
|
11 |
+
from buster.documents.sqlite import DocumentsDB
|
12 |
|
13 |
IMPORT_QUERY = (
|
14 |
r"""SELECT source, url, title, content FROM documents WHERE current = 1 ORDER BY source, url, title, id"""
|
|
|
88 |
db = DocumentsDB(args.destination)
|
89 |
|
90 |
for source, content in get_documents(org):
|
91 |
+
# sid, vid = db.start_version(source)
|
92 |
sections = (dest.Section(section.title, section.url, section.content) for section in content)
|
93 |
+
db.add_parse(source, sections)
|
94 |
|
95 |
size = max(args.size, get_max_size(org))
|
96 |
for source, chunks in get_chunks(org):
|
97 |
sid, vid = db.get_current_version(source)
|
98 |
+
db.add_chunking(sid, vid, size, chunks)
|
|
|
99 |
db.conn.commit()
|
100 |
|
101 |
return
|
buster/{db → documents/sqlite}/documents.py
RENAMED
@@ -1,12 +1,15 @@
|
|
|
|
1 |
import sqlite3
|
2 |
-
from typing import Iterable, NamedTuple
|
3 |
import warnings
|
4 |
import zlib
|
|
|
|
|
5 |
|
6 |
import numpy as np
|
7 |
import pandas as pd
|
8 |
|
9 |
-
import buster.
|
|
|
10 |
|
11 |
|
12 |
class Section(NamedTuple):
|
@@ -23,7 +26,7 @@ class Chunk(NamedTuple):
|
|
23 |
emb: np.ndarray
|
24 |
|
25 |
|
26 |
-
class DocumentsDB:
|
27 |
"""Simple SQLite database for storing documents and questions/answers.
|
28 |
|
29 |
The database is just a file on disk. It can store documents from different sources, and it can store multiple versions of the same document (e.g. if the document is updated).
|
@@ -36,13 +39,12 @@ class DocumentsDB:
|
|
36 |
"""
|
37 |
|
38 |
def __init__(self, db_path: sqlite3.Connection | str):
|
39 |
-
if isinstance(db_path, str):
|
40 |
self.db_path = db_path
|
41 |
-
self.conn = sqlite3.connect(db_path)
|
42 |
else:
|
43 |
self.db_path = None
|
44 |
self.conn = db_path
|
45 |
-
self.cursor = self.conn.cursor()
|
46 |
schema.initialize_db(self.conn)
|
47 |
schema.setup_db(self.conn)
|
48 |
|
@@ -51,6 +53,7 @@ class DocumentsDB:
|
|
51 |
self.conn.close()
|
52 |
|
53 |
def get_current_version(self, source: str) -> tuple[int, int]:
|
|
|
54 |
cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
|
55 |
row = cur.fetchone()
|
56 |
if row is None:
|
@@ -59,6 +62,7 @@ class DocumentsDB:
|
|
59 |
return sid, vid
|
60 |
|
61 |
def get_source(self, source: str) -> int:
|
|
|
62 |
cur = self.conn.execute("SELECT id FROM sources WHERE name = ?", (source,))
|
63 |
row = cur.fetchone()
|
64 |
if row is not None:
|
@@ -71,7 +75,8 @@ class DocumentsDB:
|
|
71 |
|
72 |
return sid
|
73 |
|
74 |
-
def
|
|
|
75 |
cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
|
76 |
row = cur.fetchone()
|
77 |
if row is None:
|
@@ -83,7 +88,9 @@ class DocumentsDB:
|
|
83 |
self.conn.execute("INSERT INTO versions (source, version) VALUES (?, ?)", (sid, vid))
|
84 |
return sid, vid
|
85 |
|
86 |
-
def
|
|
|
|
|
87 |
values = (
|
88 |
(sid, vid, ind, section.title, section.url, section.content, section.parent, section.type)
|
89 |
for ind, section in enumerate(sections)
|
@@ -94,9 +101,10 @@ class DocumentsDB:
|
|
94 |
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
95 |
values,
|
96 |
)
|
97 |
-
return
|
98 |
|
99 |
-
def
|
|
|
100 |
self.conn.execute(
|
101 |
"INSERT INTO chunkings (size, overlap, strategy, source, version) VALUES (?, ?, ?, ?, ?)",
|
102 |
(size, overlap, strategy, sid, vid),
|
@@ -109,7 +117,9 @@ class DocumentsDB:
|
|
109 |
(id,) = (id for id, in cur)
|
110 |
return id
|
111 |
|
112 |
-
def
|
|
|
|
|
113 |
chunks = ((ind, jnd, chunk) for ind, section in enumerate(sections) for jnd, chunk in enumerate(section))
|
114 |
values = ((sid, vid, ind, cid, jnd, chunk.content, chunk.n_tokens, chunk.emb) for ind, jnd, chunk in chunks)
|
115 |
self.conn.executemany(
|
@@ -118,51 +128,29 @@ class DocumentsDB:
|
|
118 |
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
119 |
values,
|
120 |
)
|
121 |
-
return
|
122 |
-
|
123 |
-
def
|
124 |
-
"""Write all documents from the dataframe into the db
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
)
|
138 |
-
|
139 |
-
# Check that the embeddings are float32
|
140 |
-
if not df["embedding"].iloc[0].dtype == np.float32:
|
141 |
-
warnings.warn(
|
142 |
-
f"Embeddings are not float32, converting them to float32 from {df['embedding'].iloc[0].dtype}.",
|
143 |
-
RuntimeWarning,
|
144 |
-
)
|
145 |
-
df["embedding"] = df["embedding"].apply(lambda x: x.astype(np.float32))
|
146 |
-
|
147 |
-
# ZLIB compress the embeddings
|
148 |
-
df["embedding"] = df["embedding"].apply(lambda x: sqlite3.Binary(zlib.compress(x.tobytes())))
|
149 |
-
|
150 |
-
data = df[columns].values.tolist()
|
151 |
-
|
152 |
-
# Set `current` to 0 for all previous documents from that source
|
153 |
-
self.cursor.execute("UPDATE documents SET current = 0 WHERE source = ?", (source,))
|
154 |
-
|
155 |
-
# Insert the new documents
|
156 |
-
insert_statement = f"INSERT INTO documents ({', '.join(columns)}) VALUES ({', '.join(['?']*len(columns))})"
|
157 |
-
self.cursor.executemany(insert_statement, data)
|
158 |
-
|
159 |
-
self.conn.commit()
|
160 |
|
161 |
def get_documents(self, source: str) -> pd.DataFrame:
|
162 |
"""Get all current documents from a given source."""
|
163 |
# Execute the SQL statement and fetch the results
|
164 |
-
results = self.
|
165 |
rows = results.fetchall()
|
|
|
166 |
|
167 |
# Convert the results to a pandas DataFrame
|
168 |
df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
|
|
|
1 |
+
import itertools
|
2 |
import sqlite3
|
|
|
3 |
import warnings
|
4 |
import zlib
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Iterable, NamedTuple
|
7 |
|
8 |
import numpy as np
|
9 |
import pandas as pd
|
10 |
|
11 |
+
import buster.documents.sqlite.schema as schema
|
12 |
+
from buster.documents.base import DocumentsManager
|
13 |
|
14 |
|
15 |
class Section(NamedTuple):
|
|
|
26 |
emb: np.ndarray
|
27 |
|
28 |
|
29 |
+
class DocumentsDB(DocumentsManager):
|
30 |
"""Simple SQLite database for storing documents and questions/answers.
|
31 |
|
32 |
The database is just a file on disk. It can store documents from different sources, and it can store multiple versions of the same document (e.g. if the document is updated).
|
|
|
39 |
"""
|
40 |
|
41 |
def __init__(self, db_path: sqlite3.Connection | str):
|
42 |
+
if isinstance(db_path, (str, Path)):
|
43 |
self.db_path = db_path
|
44 |
+
self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES)
|
45 |
else:
|
46 |
self.db_path = None
|
47 |
self.conn = db_path
|
|
|
48 |
schema.initialize_db(self.conn)
|
49 |
schema.setup_db(self.conn)
|
50 |
|
|
|
53 |
self.conn.close()
|
54 |
|
55 |
def get_current_version(self, source: str) -> tuple[int, int]:
|
56 |
+
"""Get the current version of a source."""
|
57 |
cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
|
58 |
row = cur.fetchone()
|
59 |
if row is None:
|
|
|
62 |
return sid, vid
|
63 |
|
64 |
def get_source(self, source: str) -> int:
|
65 |
+
"""Get the id of a source."""
|
66 |
cur = self.conn.execute("SELECT id FROM sources WHERE name = ?", (source,))
|
67 |
row = cur.fetchone()
|
68 |
if row is not None:
|
|
|
75 |
|
76 |
return sid
|
77 |
|
78 |
+
def new_version(self, source: str) -> tuple[int, int]:
|
79 |
+
"""Create a new version for a source."""
|
80 |
cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
|
81 |
row = cur.fetchone()
|
82 |
if row is None:
|
|
|
88 |
self.conn.execute("INSERT INTO versions (source, version) VALUES (?, ?)", (sid, vid))
|
89 |
return sid, vid
|
90 |
|
91 |
+
def add_parse(self, source: str, sections: Iterable[Section]) -> tuple[int, int]:
|
92 |
+
"""Create a new version of a source filled with parsed sections."""
|
93 |
+
sid, vid = self.new_version(source)
|
94 |
values = (
|
95 |
(sid, vid, ind, section.title, section.url, section.content, section.parent, section.type)
|
96 |
for ind, section in enumerate(sections)
|
|
|
101 |
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
102 |
values,
|
103 |
)
|
104 |
+
return sid, vid
|
105 |
|
106 |
+
def new_chunking(self, sid: int, vid: int, size: int, overlap: int = 0, strategy: str = "simple") -> int:
|
107 |
+
"""Create a new chunking for a source."""
|
108 |
self.conn.execute(
|
109 |
"INSERT INTO chunkings (size, overlap, strategy, source, version) VALUES (?, ?, ?, ?, ?)",
|
110 |
(size, overlap, strategy, sid, vid),
|
|
|
117 |
(id,) = (id for id, in cur)
|
118 |
return id
|
119 |
|
120 |
+
def add_chunking(self, sid: int, vid: int, size: int, sections: Iterable[Iterable[Chunk]]) -> int:
|
121 |
+
"""Create a new chunking for a source, filled with chunks organized by section."""
|
122 |
+
cid = self.new_chunking(sid, vid, size)
|
123 |
chunks = ((ind, jnd, chunk) for ind, section in enumerate(sections) for jnd, chunk in enumerate(section))
|
124 |
values = ((sid, vid, ind, cid, jnd, chunk.content, chunk.n_tokens, chunk.emb) for ind, jnd, chunk in chunks)
|
125 |
self.conn.executemany(
|
|
|
128 |
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
129 |
values,
|
130 |
)
|
131 |
+
return cid
|
132 |
+
|
133 |
+
def add(self, source: str, df: pd.DataFrame):
|
134 |
+
"""Write all documents from the dataframe into the db as a new version."""
|
135 |
+
data = sorted(df.itertuples(), key=lambda chunk: (chunk.url, chunk.title))
|
136 |
+
sections = []
|
137 |
+
size = None
|
138 |
+
for (url, title), chunks in itertools.groupby(data, lambda chunk: (chunk.url, chunk.title)):
|
139 |
+
chunks = [Chunk(chunk.content, chunk.n_tokens, chunk.embedding) for chunk in chunks]
|
140 |
+
_size = max(len(chunk.content) for chunk in chunks)
|
141 |
+
size = max(_size, size or 0)
|
142 |
+
content = "".join(chunk.content for chunk in chunks)
|
143 |
+
sections.append((Section(title, url, content), chunks))
|
144 |
+
|
145 |
+
sid, vid = self.add_parse(source, (section for section, _ in sections))
|
146 |
+
self.add_chunking(sid, vid, size, (chunks for _, chunks in sections))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
def get_documents(self, source: str) -> pd.DataFrame:
|
149 |
"""Get all current documents from a given source."""
|
150 |
# Execute the SQL statement and fetch the results
|
151 |
+
results = self.conn.execute("SELECT * FROM documents WHERE source = ?", (source,))
|
152 |
rows = results.fetchall()
|
153 |
+
print(rows[0])
|
154 |
|
155 |
# Convert the results to a pandas DataFrame
|
156 |
df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
|
buster/{db → documents/sqlite}/schema.py
RENAMED
@@ -1,9 +1,7 @@
|
|
|
|
1 |
import zlib
|
2 |
|
3 |
-
|
4 |
import numpy as np
|
5 |
-
import sqlite3
|
6 |
-
|
7 |
|
8 |
SOURCE_TABLE = r"""CREATE TABLE IF NOT EXISTS sources (
|
9 |
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
@@ -131,5 +129,5 @@ def cosine_similarity(a: bytes, b: bytes) -> float:
|
|
131 |
|
132 |
def setup_db(connection: sqlite3.Connection):
|
133 |
sqlite3.register_adapter(np.ndarray, adapt_vector)
|
134 |
-
sqlite3.register_converter("
|
135 |
connection.create_function("sim", 2, cosine_similarity, deterministic=True)
|
|
|
1 |
+
import sqlite3
|
2 |
import zlib
|
3 |
|
|
|
4 |
import numpy as np
|
|
|
|
|
5 |
|
6 |
SOURCE_TABLE = r"""CREATE TABLE IF NOT EXISTS sources (
|
7 |
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
|
129 |
|
130 |
def setup_db(connection: sqlite3.Connection):
|
131 |
sqlite3.register_adapter(np.ndarray, adapt_vector)
|
132 |
+
sqlite3.register_converter("vector", convert_vector)
|
133 |
connection.create_function("sim", 2, cosine_similarity, deterministic=True)
|