Marc-Antoine Rondeau committed on
Commit
fb83544
·
1 Parent(s): 97aefb5

Moved schema to replace previous implementation

Browse files
buster/documents/sqlite.py DELETED
@@ -1,122 +0,0 @@
1
- import sqlite3
2
- import warnings
3
- import zlib
4
-
5
- import numpy as np
6
- import pandas as pd
7
-
8
- from buster.documents.base import DocumentsManager
9
-
10
- documents_table = """CREATE TABLE IF NOT EXISTS documents (
11
- id INTEGER PRIMARY KEY AUTOINCREMENT,
12
- source TEXT NOT NULL,
13
- title TEXT NOT NULL,
14
- url TEXT NOT NULL,
15
- content TEXT NOT NULL,
16
- n_tokens INTEGER,
17
- embedding BLOB,
18
- current INTEGER
19
- )"""
20
-
21
- qa_table = """CREATE TABLE IF NOT EXISTS qa (
22
- id INTEGER PRIMARY KEY AUTOINCREMENT,
23
- source TEXT NOT NULL,
24
- prompt TEXT NOT NULL,
25
- answer TEXT NOT NULL,
26
- document_id_1 INTEGER,
27
- document_id_2 INTEGER,
28
- document_id_3 INTEGER,
29
- label_question INTEGER,
30
- label_answer INTEGER,
31
- testset INTEGER,
32
- FOREIGN KEY (document_id_1) REFERENCES documents (id),
33
- FOREIGN KEY (document_id_2) REFERENCES documents (id),
34
- FOREIGN KEY (document_id_3) REFERENCES documents (id)
35
- )"""
36
-
37
-
38
- class DocumentsDB(DocumentsManager):
39
- """Simple SQLite database for storing documents and questions/answers.
40
-
41
- The database is just a file on disk. It can store documents from different sources, and it can store multiple versions of the same document (e.g. if the document is updated).
42
- Questions/answers refer to the version of the document that was used at the time.
43
-
44
- Example:
45
- >>> db = DocumentsDB("/path/to/the/db.db")
46
- >>> db.add("source", df) # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
47
- >>> df = db.get_documents("source")
48
- """
49
-
50
- def __init__(self, filepath: str):
51
- self.db_path = filepath
52
- self.conn = sqlite3.connect(filepath)
53
- self.cursor = self.conn.cursor()
54
-
55
- self.__initialize()
56
-
57
- def __del__(self):
58
- self.conn.close()
59
-
60
- def __initialize(self):
61
- """Initialize the database."""
62
- self.cursor.execute(documents_table)
63
- self.cursor.execute(qa_table)
64
- self.conn.commit()
65
-
66
- def add(self, source: str, df: pd.DataFrame):
67
- """Write all documents from the dataframe into the db. All previous documents from that source will be set to `current = 0`."""
68
- df = df.copy()
69
-
70
- # Prepare the rows
71
- df["source"] = source
72
- df["current"] = 1
73
- columns = ["source", "title", "url", "content", "current"]
74
- if "embedding" in df.columns:
75
- columns.extend(
76
- [
77
- "n_tokens",
78
- "embedding",
79
- ]
80
- )
81
-
82
- # Check that the embeddings are float32
83
- if not df["embedding"].iloc[0].dtype == np.float32:
84
- warnings.warn(
85
- f"Embeddings are not float32, converting them to float32 from {df['embedding'].iloc[0].dtype}.",
86
- RuntimeWarning,
87
- )
88
- df["embedding"] = df["embedding"].apply(lambda x: x.astype(np.float32))
89
-
90
- # ZLIB compress the embeddings
91
- df["embedding"] = df["embedding"].apply(lambda x: sqlite3.Binary(zlib.compress(x.tobytes())))
92
-
93
- data = df[columns].values.tolist()
94
-
95
- # Set `current` to 0 for all previous documents from that source
96
- self.cursor.execute("UPDATE documents SET current = 0 WHERE source = ?", (source,))
97
-
98
- # Insert the new documents
99
- insert_statement = f"INSERT INTO documents ({', '.join(columns)}) VALUES ({', '.join(['?']*len(columns))})"
100
- self.cursor.executemany(insert_statement, data)
101
-
102
- self.conn.commit()
103
-
104
- def get_documents(self, source: str) -> pd.DataFrame:
105
- """Get all current documents from a given source."""
106
- # Execute the SQL statement and fetch the results
107
- if source is not None:
108
- results = self.cursor.execute("SELECT * FROM documents WHERE source = ? AND current = 1", (source,))
109
- else:
110
- results = self.cursor.execute("SELECT * FROM documents WHERE current = 1")
111
- rows = results.fetchall()
112
-
113
- # Convert the results to a pandas DataFrame
114
- df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
115
-
116
- # ZLIB decompress the embeddings
117
- df["embedding"] = df["embedding"].apply(lambda x: np.frombuffer(zlib.decompress(x), dtype=np.float32).tolist())
118
-
119
- # Drop the `current` column
120
- df.drop(columns=["current"], inplace=True)
121
-
122
- return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
buster/{db → documents/sqlite}/__init__.py RENAMED
File without changes
buster/{db → documents/sqlite}/backward.py RENAMED
@@ -2,15 +2,13 @@
2
 
3
  import argparse
4
  import itertools
 
5
  from typing import Iterable, NamedTuple
6
 
7
  import numpy as np
8
- import sqlite3
9
-
10
- from buster.db import DocumentsDB
11
-
12
- import buster.db.documents as dest
13
 
 
 
14
 
15
  IMPORT_QUERY = (
16
  r"""SELECT source, url, title, content FROM documents WHERE current = 1 ORDER BY source, url, title, id"""
@@ -90,15 +88,14 @@ def main():
90
  db = DocumentsDB(args.destination)
91
 
92
  for source, content in get_documents(org):
93
- sid, vid = db.start_version(source)
94
  sections = (dest.Section(section.title, section.url, section.content) for section in content)
95
- db.add_sections(sid, vid, sections)
96
 
97
  size = max(args.size, get_max_size(org))
98
  for source, chunks in get_chunks(org):
99
  sid, vid = db.get_current_version(source)
100
- cid = db.add_chunking(sid, vid, size)
101
- db.add_chunks(sid, vid, cid, chunks)
102
  db.conn.commit()
103
 
104
  return
 
2
 
3
  import argparse
4
  import itertools
5
+ import sqlite3
6
  from typing import Iterable, NamedTuple
7
 
8
  import numpy as np
 
 
 
 
 
9
 
10
+ import buster.documents.sqlite.documents as dest
11
+ from buster.documents.sqlite import DocumentsDB
12
 
13
  IMPORT_QUERY = (
14
  r"""SELECT source, url, title, content FROM documents WHERE current = 1 ORDER BY source, url, title, id"""
 
88
  db = DocumentsDB(args.destination)
89
 
90
  for source, content in get_documents(org):
91
+ # sid, vid = db.start_version(source)
92
  sections = (dest.Section(section.title, section.url, section.content) for section in content)
93
+ db.add_parse(source, sections)
94
 
95
  size = max(args.size, get_max_size(org))
96
  for source, chunks in get_chunks(org):
97
  sid, vid = db.get_current_version(source)
98
+ db.add_chunking(sid, vid, size, chunks)
 
99
  db.conn.commit()
100
 
101
  return
buster/{db → documents/sqlite}/documents.py RENAMED
@@ -1,12 +1,15 @@
 
1
  import sqlite3
2
- from typing import Iterable, NamedTuple
3
  import warnings
4
  import zlib
 
 
5
 
6
  import numpy as np
7
  import pandas as pd
8
 
9
- import buster.db.schema as schema
 
10
 
11
 
12
  class Section(NamedTuple):
@@ -23,7 +26,7 @@ class Chunk(NamedTuple):
23
  emb: np.ndarray
24
 
25
 
26
- class DocumentsDB:
27
  """Simple SQLite database for storing documents and questions/answers.
28
 
29
  The database is just a file on disk. It can store documents from different sources, and it can store multiple versions of the same document (e.g. if the document is updated).
@@ -36,13 +39,12 @@ class DocumentsDB:
36
  """
37
 
38
  def __init__(self, db_path: sqlite3.Connection | str):
39
- if isinstance(db_path, str):
40
  self.db_path = db_path
41
- self.conn = sqlite3.connect(db_path)
42
  else:
43
  self.db_path = None
44
  self.conn = db_path
45
- self.cursor = self.conn.cursor()
46
  schema.initialize_db(self.conn)
47
  schema.setup_db(self.conn)
48
 
@@ -51,6 +53,7 @@ class DocumentsDB:
51
  self.conn.close()
52
 
53
  def get_current_version(self, source: str) -> tuple[int, int]:
 
54
  cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
55
  row = cur.fetchone()
56
  if row is None:
@@ -59,6 +62,7 @@ class DocumentsDB:
59
  return sid, vid
60
 
61
  def get_source(self, source: str) -> int:
 
62
  cur = self.conn.execute("SELECT id FROM sources WHERE name = ?", (source,))
63
  row = cur.fetchone()
64
  if row is not None:
@@ -71,7 +75,8 @@ class DocumentsDB:
71
 
72
  return sid
73
 
74
- def start_version(self, source: str) -> tuple[int, int]:
 
75
  cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
76
  row = cur.fetchone()
77
  if row is None:
@@ -83,7 +88,9 @@ class DocumentsDB:
83
  self.conn.execute("INSERT INTO versions (source, version) VALUES (?, ?)", (sid, vid))
84
  return sid, vid
85
 
86
- def add_sections(self, sid: int, vid: int, sections: Iterable[Section]):
 
 
87
  values = (
88
  (sid, vid, ind, section.title, section.url, section.content, section.parent, section.type)
89
  for ind, section in enumerate(sections)
@@ -94,9 +101,10 @@ class DocumentsDB:
94
  "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
95
  values,
96
  )
97
- return
98
 
99
- def add_chunking(self, sid: int, vid: int, size: int, overlap: int = 0, strategy: str = "simple") -> int:
 
100
  self.conn.execute(
101
  "INSERT INTO chunkings (size, overlap, strategy, source, version) VALUES (?, ?, ?, ?, ?)",
102
  (size, overlap, strategy, sid, vid),
@@ -109,7 +117,9 @@ class DocumentsDB:
109
  (id,) = (id for id, in cur)
110
  return id
111
 
112
- def add_chunks(self, sid: int, vid: int, cid: int, sections: Iterable[Iterable[Chunk]]):
 
 
113
  chunks = ((ind, jnd, chunk) for ind, section in enumerate(sections) for jnd, chunk in enumerate(section))
114
  values = ((sid, vid, ind, cid, jnd, chunk.content, chunk.n_tokens, chunk.emb) for ind, jnd, chunk in chunks)
115
  self.conn.executemany(
@@ -118,51 +128,29 @@ class DocumentsDB:
118
  "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
119
  values,
120
  )
121
- return
122
-
123
- def write_documents(self, source: str, df: pd.DataFrame):
124
- """Write all documents from the dataframe into the db. All previous documents from that source will be set to `current = 0`."""
125
- df = df.copy()
126
-
127
- # Prepare the rows
128
- df["source"] = source
129
- df["current"] = 1
130
- columns = ["source", "title", "url", "content", "current"]
131
- if "embedding" in df.columns:
132
- columns.extend(
133
- [
134
- "n_tokens",
135
- "embedding",
136
- ]
137
- )
138
-
139
- # Check that the embeddings are float32
140
- if not df["embedding"].iloc[0].dtype == np.float32:
141
- warnings.warn(
142
- f"Embeddings are not float32, converting them to float32 from {df['embedding'].iloc[0].dtype}.",
143
- RuntimeWarning,
144
- )
145
- df["embedding"] = df["embedding"].apply(lambda x: x.astype(np.float32))
146
-
147
- # ZLIB compress the embeddings
148
- df["embedding"] = df["embedding"].apply(lambda x: sqlite3.Binary(zlib.compress(x.tobytes())))
149
-
150
- data = df[columns].values.tolist()
151
-
152
- # Set `current` to 0 for all previous documents from that source
153
- self.cursor.execute("UPDATE documents SET current = 0 WHERE source = ?", (source,))
154
-
155
- # Insert the new documents
156
- insert_statement = f"INSERT INTO documents ({', '.join(columns)}) VALUES ({', '.join(['?']*len(columns))})"
157
- self.cursor.executemany(insert_statement, data)
158
-
159
- self.conn.commit()
160
 
161
  def get_documents(self, source: str) -> pd.DataFrame:
162
  """Get all current documents from a given source."""
163
  # Execute the SQL statement and fetch the results
164
- results = self.cursor.execute("SELECT * FROM documents WHERE source = ?", (source,))
165
  rows = results.fetchall()
 
166
 
167
  # Convert the results to a pandas DataFrame
168
  df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
 
1
+ import itertools
2
  import sqlite3
 
3
  import warnings
4
  import zlib
5
+ from pathlib import Path
6
+ from typing import Iterable, NamedTuple
7
 
8
  import numpy as np
9
  import pandas as pd
10
 
11
+ import buster.documents.sqlite.schema as schema
12
+ from buster.documents.base import DocumentsManager
13
 
14
 
15
  class Section(NamedTuple):
 
26
  emb: np.ndarray
27
 
28
 
29
+ class DocumentsDB(DocumentsManager):
30
  """Simple SQLite database for storing documents and questions/answers.
31
 
32
  The database is just a file on disk. It can store documents from different sources, and it can store multiple versions of the same document (e.g. if the document is updated).
 
39
  """
40
 
41
  def __init__(self, db_path: sqlite3.Connection | str):
42
+ if isinstance(db_path, (str, Path)):
43
  self.db_path = db_path
44
+ self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES)
45
  else:
46
  self.db_path = None
47
  self.conn = db_path
 
48
  schema.initialize_db(self.conn)
49
  schema.setup_db(self.conn)
50
 
 
53
  self.conn.close()
54
 
55
  def get_current_version(self, source: str) -> tuple[int, int]:
56
+ """Get the current version of a source."""
57
  cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
58
  row = cur.fetchone()
59
  if row is None:
 
62
  return sid, vid
63
 
64
  def get_source(self, source: str) -> int:
65
+ """Get the id of a source."""
66
  cur = self.conn.execute("SELECT id FROM sources WHERE name = ?", (source,))
67
  row = cur.fetchone()
68
  if row is not None:
 
75
 
76
  return sid
77
 
78
+ def new_version(self, source: str) -> tuple[int, int]:
79
+ """Create a new version for a source."""
80
  cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
81
  row = cur.fetchone()
82
  if row is None:
 
88
  self.conn.execute("INSERT INTO versions (source, version) VALUES (?, ?)", (sid, vid))
89
  return sid, vid
90
 
91
+ def add_parse(self, source: str, sections: Iterable[Section]) -> tuple[int, int]:
92
+ """Create a new version of a source filled with parsed sections."""
93
+ sid, vid = self.new_version(source)
94
  values = (
95
  (sid, vid, ind, section.title, section.url, section.content, section.parent, section.type)
96
  for ind, section in enumerate(sections)
 
101
  "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
102
  values,
103
  )
104
+ return sid, vid
105
 
106
+ def new_chunking(self, sid: int, vid: int, size: int, overlap: int = 0, strategy: str = "simple") -> int:
107
+ """Create a new chunking for a source."""
108
  self.conn.execute(
109
  "INSERT INTO chunkings (size, overlap, strategy, source, version) VALUES (?, ?, ?, ?, ?)",
110
  (size, overlap, strategy, sid, vid),
 
117
  (id,) = (id for id, in cur)
118
  return id
119
 
120
+ def add_chunking(self, sid: int, vid: int, size: int, sections: Iterable[Iterable[Chunk]]) -> int:
121
+ """Create a new chunking for a source, filled with chunks organized by section."""
122
+ cid = self.new_chunking(sid, vid, size)
123
  chunks = ((ind, jnd, chunk) for ind, section in enumerate(sections) for jnd, chunk in enumerate(section))
124
  values = ((sid, vid, ind, cid, jnd, chunk.content, chunk.n_tokens, chunk.emb) for ind, jnd, chunk in chunks)
125
  self.conn.executemany(
 
128
  "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
129
  values,
130
  )
131
+ return cid
132
+
133
+ def add(self, source: str, df: pd.DataFrame):
134
+ """Write all documents from the dataframe into the db as a new version."""
135
+ data = sorted(df.itertuples(), key=lambda chunk: (chunk.url, chunk.title))
136
+ sections = []
137
+ size = None
138
+ for (url, title), chunks in itertools.groupby(data, lambda chunk: (chunk.url, chunk.title)):
139
+ chunks = [Chunk(chunk.content, chunk.n_tokens, chunk.embedding) for chunk in chunks]
140
+ _size = max(len(chunk.content) for chunk in chunks)
141
+ size = max(_size, size or 0)
142
+ content = "".join(chunk.content for chunk in chunks)
143
+ sections.append((Section(title, url, content), chunks))
144
+
145
+ sid, vid = self.add_parse(source, (section for section, _ in sections))
146
+ self.add_chunking(sid, vid, size, (chunks for _, chunks in sections))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  def get_documents(self, source: str) -> pd.DataFrame:
149
  """Get all current documents from a given source."""
150
  # Execute the SQL statement and fetch the results
151
+ results = self.conn.execute("SELECT * FROM documents WHERE source = ?", (source,))
152
  rows = results.fetchall()
153
+ print(rows[0])
154
 
155
  # Convert the results to a pandas DataFrame
156
  df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
buster/{db → documents/sqlite}/schema.py RENAMED
@@ -1,9 +1,7 @@
 
1
  import zlib
2
 
3
-
4
  import numpy as np
5
- import sqlite3
6
-
7
 
8
  SOURCE_TABLE = r"""CREATE TABLE IF NOT EXISTS sources (
9
  id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -131,5 +129,5 @@ def cosine_similarity(a: bytes, b: bytes) -> float:
131
 
132
  def setup_db(connection: sqlite3.Connection):
133
  sqlite3.register_adapter(np.ndarray, adapt_vector)
134
- sqlite3.register_converter("VECTOR", convert_vector)
135
  connection.create_function("sim", 2, cosine_similarity, deterministic=True)
 
1
+ import sqlite3
2
  import zlib
3
 
 
4
  import numpy as np
 
 
5
 
6
  SOURCE_TABLE = r"""CREATE TABLE IF NOT EXISTS sources (
7
  id INTEGER PRIMARY KEY AUTOINCREMENT,
 
129
 
130
  def setup_db(connection: sqlite3.Connection):
131
  sqlite3.register_adapter(np.ndarray, adapt_vector)
132
+ sqlite3.register_converter("vector", convert_vector)
133
  connection.create_function("sim", 2, cosine_similarity, deterministic=True)