hbertrand and jerpint committed
Commit 1f22b14 · unverified · 1 parent: 8b5fed9

Create SQLite db for documents (#46)


* sqlite db

* isort

* tests

* PR

* tests

* change names

* put default empty string for source

* type warning

* change paths

* Fix tests

* add kwargs

---------

Co-authored-by: Jeremy Pinto <[email protected]>

.github/workflows/tests.yaml CHANGED
@@ -20,4 +20,4 @@ jobs:
       run: |
         python3 -m pip install --upgrade pip
         pip install -e .
-        # pytest
+        pytest
buster/apps/gradio_app.ipynb CHANGED
@@ -14,7 +14,7 @@
     "from buster.chatbot import Chatbot, ChatbotConfig\n",
     "\n",
     "hf_transformers_cfg = ChatbotConfig(\n",
-    "    documents_file=\"../data/document_embeddings_hf_transformers.tar.gz\",\n",
+    "    documents_file=\"../data/document_embeddings_huggingface.tar.gz\",\n",
     "    unknown_prompt=\"This doesn't seem to be related to the huggingface library. I am not sure how to answer.\",\n",
     "    embedding_model=\"text-embedding-ada-002\",\n",
     "    top_k=3,\n",
@@ -123,7 +123,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.12"
+   "version": "3.9.12 (main, Apr 5 2022, 01:52:34) \n[Clang 12.0.0 ]"
   },
   "vscode": {
    "interpreter": {
buster/apps/slackbot.py CHANGED
@@ -15,7 +15,7 @@ PYTORCH_CHANNEL = "C04MEK6N882"
 HF_TRANSFORMERS_CHANNEL = "C04NJNCJWHE"
 
 mila_doc_cfg = ChatbotConfig(
-    documents_file="../data/document_embeddings.csv",
+    documents_file="../data/document_embeddings_mila.tar.gz",
     unknown_prompt="This doesn't seem to be related to cluster usage.",
     embedding_model="text-embedding-ada-002",
     top_k=3,
@@ -51,7 +51,7 @@ mila_doc_cfg = ChatbotConfig(
 mila_doc_chatbot = Chatbot(mila_doc_cfg)
 
 orion_cfg = ChatbotConfig(
-    documents_file="../data/document_embeddings_orion.csv",
+    documents_file="../data/document_embeddings_orion.tar.gz",
     unknown_prompt="This doesn't seem to be related to the orion library. I am not sure how to answer.",
     embedding_model="text-embedding-ada-002",
     top_k=3,
@@ -117,7 +117,7 @@ pytorch_cfg = ChatbotConfig(
 pytorch_chatbot = Chatbot(pytorch_cfg)
 
 hf_transformers_cfg = ChatbotConfig(
-    documents_file="../data/document_embeddings_hf_transformers.tar.gz",
+    documents_file="../data/document_embeddings_huggingface.tar.gz",
     unknown_prompt="This doesn't seem to be related to the huggingface library. I am not sure how to answer.",
     embedding_model="text-embedding-ada-002",
     top_k=3,
buster/chatbot.py CHANGED
@@ -123,7 +123,7 @@ class Chatbot:
 
     def prepare_documents(self, matched_documents: pd.DataFrame, max_words: int) -> str:
         # gather the documents in one large plaintext variable
-        documents_list = matched_documents.text.to_list()
+        documents_list = matched_documents.content.to_list()
         documents_str = " ".join(documents_list)
 
         # truncate the documents to fit
@@ -181,17 +181,17 @@ class Chatbot:
         """
 
         urls = matched_documents.url.to_list()
-        names = matched_documents.name.to_list()
+        titles = matched_documents.title.to_list()
         similarities = matched_documents.similarity.to_list()
 
         response += f"{sep}{sep}📝 Here are the sources I used to answer your question:{sep}{sep}"
-        for url, name, similarity in zip(urls, names, similarities):
+        for url, title, similarity in zip(urls, titles, similarities):
             if format == "markdown":
-                response += f"[🔗 {name}]({url}), relevance: {similarity:2.3f}{sep}"
+                response += f"[🔗 {title}]({url}), relevance: {similarity:2.3f}{sep}"
             elif format == "html":
-                response += f"<a href='{url}'>🔗 {name}</a>{sep}"
+                response += f"<a href='{url}'>🔗 {title}</a>{sep}"
             elif format == "slack":
-                response += f"<{url}|🔗 {name}>, relevance: {similarity:2.3f}{sep}"
+                response += f"<{url}|🔗 {title}>, relevance: {similarity:2.3f}{sep}"
             else:
                 raise ValueError(f"{format} is not a valid URL format.")
 
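For illustration only, a minimal sketch of the renamed columns in use; the frame below is made-up data and the loop mirrors the markdown branch shown above:

import pandas as pd

# Made-up example data; `title` and `content` replace the old `name` and `text` columns
matched_documents = pd.DataFrame(
    {
        "title": ["Installation"],
        "url": ["https://docs.mila.quebec/install.html"],  # illustrative page under the real base_url
        "content": ["How to install ..."],
        "similarity": [0.871],
    }
)

sep = "\n"
response = ""
for url, title, similarity in zip(
    matched_documents.url.to_list(),
    matched_documents.title.to_list(),
    matched_documents.similarity.to_list(),
):
    response += f"[🔗 {title}]({url}), relevance: {similarity:2.3f}{sep}"
print(response)  # [🔗 Installation](https://docs.mila.quebec/install.html), relevance: 0.871
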
buster/db.py ADDED
@@ -0,0 +1,117 @@
+import sqlite3
+import warnings
+import zlib
+
+import numpy as np
+import pandas as pd
+
+documents_table = """CREATE TABLE IF NOT EXISTS documents (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    source TEXT NOT NULL,
+    title TEXT NOT NULL,
+    url TEXT NOT NULL,
+    content TEXT NOT NULL,
+    n_tokens INTEGER,
+    embedding BLOB,
+    current INTEGER
+)"""
+
+qa_table = """CREATE TABLE IF NOT EXISTS qa (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    source TEXT NOT NULL,
+    prompt TEXT NOT NULL,
+    answer TEXT NOT NULL,
+    document_id_1 INTEGER,
+    document_id_2 INTEGER,
+    document_id_3 INTEGER,
+    label_question INTEGER,
+    label_answer INTEGER,
+    testset INTEGER,
+    FOREIGN KEY (document_id_1) REFERENCES documents (id),
+    FOREIGN KEY (document_id_2) REFERENCES documents (id),
+    FOREIGN KEY (document_id_3) REFERENCES documents (id)
+)"""
+
+
+class DocumentsDB:
+    """Simple SQLite database for storing documents and questions/answers.
+
+    The database is just a file on disk. It can store documents from different sources, and it can store multiple versions of the same document (e.g. if the document is updated).
+    Questions/answers refer to the version of the document that was used at the time.
+
+    Example:
+        >>> db = DocumentsDB("/path/to/the/db.db")
+        >>> db.write_documents("source", df)  # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
+        >>> df = db.get_documents("source")
+    """
+
+    def __init__(self, db_path):
+        self.db_path = db_path
+        self.conn = sqlite3.connect(db_path)
+        self.cursor = self.conn.cursor()
+
+        self.__initialize()
+
+    def __del__(self):
+        self.conn.close()
+
+    def __initialize(self):
+        """Initialize the database."""
+        self.cursor.execute(documents_table)
+        self.cursor.execute(qa_table)
+        self.conn.commit()
+
+    def write_documents(self, source: str, df: pd.DataFrame):
+        """Write all documents from the dataframe into the db. All previous documents from that source will be set to `current = 0`."""
+        df = df.copy()
+
+        # Prepare the rows
+        df["source"] = source
+        df["current"] = 1
+        columns = ["source", "title", "url", "content", "current"]
+        if "embedding" in df.columns:
+            columns.extend(
+                [
+                    "n_tokens",
+                    "embedding",
+                ]
+            )
+
+        # Check that the embeddings are float32
+        if not df["embedding"].iloc[0].dtype == np.float32:
+            warnings.warn(
+                f"Embeddings are not float32, converting them to float32 from {df['embedding'].iloc[0].dtype}.",
+                RuntimeWarning,
+            )
+            df["embedding"] = df["embedding"].apply(lambda x: x.astype(np.float32))
+
+        # ZLIB compress the embeddings
+        df["embedding"] = df["embedding"].apply(lambda x: sqlite3.Binary(zlib.compress(x.tobytes())))
+
+        data = df[columns].values.tolist()
+
+        # Set `current` to 0 for all previous documents from that source
+        self.cursor.execute("UPDATE documents SET current = 0 WHERE source = ?", (source,))
+
+        # Insert the new documents
+        insert_statement = f"INSERT INTO documents ({', '.join(columns)}) VALUES ({', '.join(['?']*len(columns))})"
+        self.cursor.executemany(insert_statement, data)
+
+        self.conn.commit()
+
+    def get_documents(self, source: str) -> pd.DataFrame:
+        """Get all current documents from a given source."""
+        # Execute the SQL statement and fetch the results
+        results = self.cursor.execute("SELECT * FROM documents WHERE source = ? AND current = 1", (source,))
+        rows = results.fetchall()
+
+        # Convert the results to a pandas DataFrame
+        df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
+
+        # ZLIB decompress the embeddings
+        df["embedding"] = df["embedding"].apply(lambda x: np.frombuffer(zlib.decompress(x), dtype=np.float32).tolist())
+
+        # Drop the `current` column
+        df.drop(columns=["current"], inplace=True)
+
+        return df
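
For orientation, a minimal usage sketch of the class above. The values are made up, the in-memory path keeps it self-contained, and the frame includes an embedding column since `write_documents` as written compresses float32 embeddings:

import numpy as np
import pandas as pd

from buster.db import DocumentsDB

db = DocumentsDB(":memory:")  # any file path works the same way

docs = pd.DataFrame(
    {
        "title": ["Installation"],
        "url": ["https://docs.mila.quebec/install.html"],  # made-up page for the example
        "embedding": [np.random.rand(1536).astype(np.float32)],  # ada-002-sized vector, random here
        "content": ["How to install ..."],
        "n_tokens": [5],
    }
)

# Older "mila" rows are flagged current = 0; these rows are inserted with current = 1
db.write_documents("mila", docs)

# Only rows with current = 1 come back; embeddings are decompressed into lists
current_docs = db.get_documents("mila")
print(current_docs[["source", "title", "url", "n_tokens"]])
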
buster/docparser.py CHANGED
@@ -7,6 +7,7 @@ import tiktoken
 from bs4 import BeautifulSoup
 from openai.embeddings_utils import get_embedding
 
+from buster.db import DocumentsDB
 from buster.parser import HuggingfaceParser, Parser, SphinxParser
 
 EMBEDDING_MODEL = "text-embedding-ada-002"
@@ -19,22 +20,22 @@ PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".
 supported_docs = {
     "mila": {
         "base_url": "https://docs.mila.quebec/",
-        "filename": "documents_mila.tar.gz",
+        "filename": "documents_mila.csv",
         "parser": SphinxParser,
     },
     "orion": {
         "base_url": "https://orion.readthedocs.io/en/stable/",
-        "filename": "documents_orion.tar.gz",
+        "filename": "documents_orion.csv",
         "parser": SphinxParser,
     },
     "pytorch": {
         "base_url": "https://pytorch.org/docs/stable/",
-        "filename": "documents_pytorch.tar.gz",
+        "filename": "documents_pytorch.csv",
         "parser": SphinxParser,
     },
     "huggingface": {
         "base_url": "https://huggingface.co/docs/transformers/",
-        "filename": "documents_huggingface.tar.gz",
+        "filename": "documents_huggingface.csv",
         "parser": HuggingfaceParser,
     },
 }
@@ -66,7 +67,7 @@ def get_all_documents(
     urls.extend(urls_file)
     names.extend(names_file)
 
-    documents_df = pd.DataFrame.from_dict({"name": names, "url": urls, "text": sections})
+    documents_df = pd.DataFrame.from_dict({"title": names, "url": urls, "content": sections})
 
     return documents_df
 
@@ -75,46 +76,58 @@ def get_file_extension(filepath: str) -> str:
     return os.path.splitext(filepath)[1]
 
 
-def write_documents(filepath: str, documents_df: pd.DataFrame):
+def write_documents(filepath: str, documents_df: pd.DataFrame, source: str = ""):
     ext = get_file_extension(filepath)
 
     if ext == ".csv":
         documents_df.to_csv(filepath, index=False)
     elif ext in PICKLE_EXTENSIONS:
         documents_df.to_pickle(filepath)
+    elif ext == ".db":
+        db = DocumentsDB(filepath)
+        db.write_documents(source, documents_df)
     else:
         raise ValueError(f"Unsupported format: {ext}.")
 
 
-def read_documents(filepath: str) -> pd.DataFrame:
+def read_documents(filepath: str, source: str = "") -> pd.DataFrame:
     ext = get_file_extension(filepath)
 
     if ext == ".csv":
         df = pd.read_csv(filepath)
-        df["embedding"] = df.embedding.apply(eval).apply(np.array)
-        return df
+
+        if "embedding" in df.columns:
+            df["embedding"] = df.embedding.apply(eval).apply(np.array)
     elif ext in PICKLE_EXTENSIONS:
-        return pd.read_pickle(filepath)
+        df = pd.read_pickle(filepath)
+
+        if "embedding" in df.columns:
+            df["embedding"] = df.embedding.apply(np.array)
+    elif ext == ".db":
+        db = DocumentsDB(filepath)
+        df = db.get_documents(source)
     else:
         raise ValueError(f"Unsupported format: {ext}.")
 
+    return df
+
 
 def compute_n_tokens(df: pd.DataFrame) -> pd.DataFrame:
     encoding = tiktoken.get_encoding(EMBEDDING_ENCODING)
     # TODO are there unexpected consequences of allowing endoftext?
-    df["n_tokens"] = df.text.apply(lambda x: len(encoding.encode(x, allowed_special={"<|endoftext|>"})))
+    df["n_tokens"] = df.content.apply(lambda x: len(encoding.encode(x, allowed_special={"<|endoftext|>"})))
    return df
 
 
 def precompute_embeddings(df: pd.DataFrame) -> pd.DataFrame:
-    df["embedding"] = df.text.apply(lambda x: get_embedding(x, engine=EMBEDDING_MODEL))
+    df["embedding"] = df.content.apply(lambda x: np.asarray(get_embedding(x, engine=EMBEDDING_MODEL), dtype=np.float32))
     return df
 
 
-def generate_embeddings(filepath: str, output_file: str) -> pd.DataFrame:
+def generate_embeddings(filepath: str, output_file: str, source: str) -> pd.DataFrame:
     # Get all documents and precompute their embeddings
-    df = read_documents(filepath)
+    df = read_documents(filepath, source)
     df = compute_n_tokens(df)
     df = precompute_embeddings(df)
-    write_documents(output_file, df)
+    write_documents(filepath=output_file, documents_df=df, source=source)
     return df
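
A short round-trip sketch of the reworked helpers with the new `.db` extension (toy values; the `.csv` and pickle branches behave as before):

import numpy as np
import pandas as pd

from buster.docparser import read_documents, write_documents

# Toy frame using the new column names (title / url / content);
# writing to a .db currently expects float32 embeddings alongside n_tokens
docs = pd.DataFrame(
    {
        "title": ["test"],
        "url": ["http://url.com"],
        "content": ["cool text"],
        "embedding": [np.arange(4, dtype=np.float32)],
        "n_tokens": [3],
    }
)

# A ".db" path dispatches to DocumentsDB; `source` tags the rows
write_documents(filepath="documents.db", documents_df=docs, source="test")

# Reading the same source back returns the rows currently marked as current
df = read_documents("documents.db", source="test")
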
db_to_csv.ipynb ADDED
@@ -0,0 +1,60 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Example notebook on how to extract a source from the database and save it in another format"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "from buster.docparser import read_documents, write_documents"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Path to the database\n",
+    "db_path = \"documents.db\"\n",
+    "\n",
+    "# Source to extract\n",
+    "target = \"pytorch\"\n",
+    "df = read_documents(db_path, target)\n",
+    "\n",
+    "# If you want to save it as tar.gz\n",
+    "filepath = os.path.join('buster/data/', f'document_embeddings_{target}.tar.gz')\n",
+    "write_documents(filepath, df, target)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "milabot",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.10.9"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "9db6f4b791ef587fd310257e87896b12053c9010399595f881592a25a8a29679"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
requirements.txt CHANGED
@@ -6,6 +6,7 @@ tabulate
 tenacity
 tiktoken
 promptlayer
+pytest
 openai
 
 # all openai[embeddings] deps, their list breaks our CI, see: https://github.com/openai/openai-python/issues/210
tests/test_db.py ADDED
@@ -0,0 +1,62 @@
+import numpy as np
+import pandas as pd
+
+from buster.db import DocumentsDB
+
+
+def test_write_read():
+    db = DocumentsDB(":memory:")
+
+    data = pd.DataFrame.from_dict(
+        {
+            "title": ["test"],
+            "url": ["http://url.com"],
+            "content": ["cool text"],
+            "embedding": [np.arange(10, dtype=np.float32) - 0.3],
+            "n_tokens": [10],
+        }
+    )
+    db.write_documents(source="test", df=data)
+
+    db_data = db.get_documents("test")
+
+    assert db_data["title"].iloc[0] == data["title"].iloc[0]
+    assert db_data["url"].iloc[0] == data["url"].iloc[0]
+    assert db_data["content"].iloc[0] == data["content"].iloc[0]
+    assert np.allclose(db_data["embedding"].iloc[0], data["embedding"].iloc[0])
+    assert db_data["n_tokens"].iloc[0] == data["n_tokens"].iloc[0]
+
+
+def test_write_write_read():
+    db = DocumentsDB(":memory:")
+
+    data_1 = pd.DataFrame.from_dict(
+        {
+            "title": ["test"],
+            "url": ["http://url.com"],
+            "content": ["cool text"],
+            "embedding": [np.arange(10, dtype=np.float32) - 0.3],
+            "n_tokens": [10],
+        }
+    )
+    db.write_documents(source="test", df=data_1)
+
+    data_2 = pd.DataFrame.from_dict(
+        {
+            "title": ["other"],
+            "url": ["http://url.com/page.html"],
+            "content": ["lorem ipsum"],
+            "embedding": [np.arange(20, dtype=np.float32) / 10 - 2.3],
+            "n_tokens": [20],
+        }
+    )
+    db.write_documents(source="test", df=data_2)
+
+    db_data = db.get_documents("test")
+
+    assert len(db_data) == len(data_2)
+    assert db_data["title"].iloc[0] == data_2["title"].iloc[0]
+    assert db_data["url"].iloc[0] == data_2["url"].iloc[0]
+    assert db_data["content"].iloc[0] == data_2["content"].iloc[0]
+    assert np.allclose(db_data["embedding"].iloc[0], data_2["embedding"].iloc[0])
+    assert db_data["n_tokens"].iloc[0] == data_2["n_tokens"].iloc[0]
tests/test_docparser.py ADDED
@@ -0,0 +1,30 @@
+import numpy as np
+import pandas as pd
+
+from buster.docparser import generate_embeddings, read_documents, write_documents
+
+
+def test_generate_embeddings(tmp_path, monkeypatch):
+    # Patch the get_embedding function to return a fixed embedding
+    monkeypatch.setattr("buster.docparser.get_embedding", lambda x, engine: [-0.005, 0.0018])
+
+    # Create fake data
+    data = pd.DataFrame.from_dict({"title": ["test"], "url": ["http://url.com"], "content": ["cool text"]})
+
+    # Write the data to a file
+    filepath = tmp_path / "test_document.csv"
+    write_documents(filepath=filepath, documents_df=data, source="test")
+
+    # Generate embeddings, store in a file
+    output_file = tmp_path / "test_document_embeddings.tar.gz"
+    df = generate_embeddings(filepath=filepath, output_file=output_file, source="test")
+
+    # Read the embeddings from the file
+    read_df = read_documents(output_file, "test")
+
+    # Check all the values are correct across the files
+    assert df["title"].iloc[0] == data["title"].iloc[0] == read_df["title"].iloc[0]
+    assert df["url"].iloc[0] == data["url"].iloc[0] == read_df["url"].iloc[0]
+    assert df["content"].iloc[0] == data["content"].iloc[0] == read_df["content"].iloc[0]
+    assert np.allclose(df["embedding"].iloc[0], read_df["embedding"].iloc[0])
+    assert df["n_tokens"].iloc[0] == read_df["n_tokens"].iloc[0]