Create SQLite db for documents (#46)
* sqlite db
* isort
* tests
* PR
* tests
* change names
* put default empty string for source
* type warning
* change paths
* Fix tests
* add kwargs
---------
Co-authored-by: Jeremy Pinto <[email protected]>
- .github/workflows/tests.yaml +1 -1
- buster/apps/gradio_app.ipynb +2 -2
- buster/apps/slackbot.py +3 -3
- buster/chatbot.py +6 -6
- buster/db.py +117 -0
- buster/docparser.py +28 -15
- db_to_csv.ipynb +60 -0
- requirements.txt +1 -0
- tests/test_db.py +62 -0
- tests/test_docparser.py +30 -0
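
A minimal sketch of the end-to-end flow this change enables, assuming a parsed-documents CSV with title/url/content columns (the file names and source label below are hypothetical, and the embedding step calls the OpenAI API, so it needs a valid API key):

import pandas as pd

from buster.db import DocumentsDB
from buster.docparser import generate_embeddings

# Read the raw documents, compute n_tokens and embeddings, and write the
# result into the SQLite db under the given source.
df = generate_embeddings("documents_mila.csv", "documents.db", source="mila")

# Read the current documents for that source back out of the db.
db = DocumentsDB("documents.db")
mila_docs: pd.DataFrame = db.get_documents("mila")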
.github/workflows/tests.yaml
CHANGED
@@ -20,4 +20,4 @@ jobs:
         run: |
           python3 -m pip install --upgrade pip
           pip install -e .
-
+          pytest
buster/apps/gradio_app.ipynb
CHANGED
@@ -14,7 +14,7 @@
 "from buster.chatbot import Chatbot, ChatbotConfig\n",
 "\n",
 "hf_transformers_cfg = ChatbotConfig(\n",
-"    documents_file=\"../data/
+"    documents_file=\"../data/document_embeddings_huggingface.tar.gz\",\n",
 "    unknown_prompt=\"This doesn't seem to be related to the huggingface library. I am not sure how to answer.\",\n",
 "    embedding_model=\"text-embedding-ada-002\",\n",
 "    top_k=3,\n",
@@ -123,7 +123,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.12"
+"version": "3.9.12 (main, Apr 5 2022, 01:52:34) \n[Clang 12.0.0 ]"
 },
 "vscode": {
 "interpreter": {
buster/apps/slackbot.py
CHANGED
@@ -15,7 +15,7 @@ PYTORCH_CHANNEL = "C04MEK6N882"
 HF_TRANSFORMERS_CHANNEL = "C04NJNCJWHE"
 
 mila_doc_cfg = ChatbotConfig(
-    documents_file="../data/
+    documents_file="../data/document_embeddings_mila.tar.gz",
     unknown_prompt="This doesn't seem to be related to cluster usage.",
     embedding_model="text-embedding-ada-002",
     top_k=3,
@@ -51,7 +51,7 @@ mila_doc_cfg = ChatbotConfig(
 mila_doc_chatbot = Chatbot(mila_doc_cfg)
 
 orion_cfg = ChatbotConfig(
-    documents_file="../data/document_embeddings_orion.
+    documents_file="../data/document_embeddings_orion.tar.gz",
     unknown_prompt="This doesn't seem to be related to the orion library. I am not sure how to answer.",
     embedding_model="text-embedding-ada-002",
     top_k=3,
@@ -117,7 +117,7 @@ pytorch_cfg = ChatbotConfig(
 pytorch_chatbot = Chatbot(pytorch_cfg)
 
 hf_transformers_cfg = ChatbotConfig(
-    documents_file="../data/
+    documents_file="../data/document_embeddings_huggingface.tar.gz",
     unknown_prompt="This doesn't seem to be related to the huggingface library. I am not sure how to answer.",
     embedding_model="text-embedding-ada-002",
     top_k=3,
buster/chatbot.py
CHANGED
@@ -123,7 +123,7 @@ class Chatbot:
 
     def prepare_documents(self, matched_documents: pd.DataFrame, max_words: int) -> str:
         # gather the documents in one large plaintext variable
-        documents_list = matched_documents.
+        documents_list = matched_documents.content.to_list()
         documents_str = " ".join(documents_list)
 
         # truncate the documents to fit
@@ -181,17 +181,17 @@
         """
 
         urls = matched_documents.url.to_list()
-
+        titles = matched_documents.title.to_list()
         similarities = matched_documents.similarity.to_list()
 
         response += f"{sep}{sep}📝 Here are the sources I used to answer your question:{sep}{sep}"
-        for url,
+        for url, title, similarity in zip(urls, titles, similarities):
             if format == "markdown":
-                response += f"[🔗 {
+                response += f"[🔗 {title}]({url}), relevance: {similarity:2.3f}{sep}"
             elif format == "html":
-                response += f"<a href='{url}'>🔗 {
+                response += f"<a href='{url}'>🔗 {title}</a>{sep}"
             elif format == "slack":
-                response += f"<{url}|🔗 {
+                response += f"<{url}|🔗 {title}>, relevance: {similarity:2.3f}{sep}"
             else:
                 raise ValueError(f"{format} is not a valid URL format.")
 
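
For illustration, a standalone sketch of what the updated source-formatting loop produces for the markdown branch (the url, title and similarity values are made up):

# Sketch of the markdown branch of the loop above; the data is invented.
sep = "\n"
urls = ["https://docs.mila.quebec/Userguide.html"]
titles = ["Userguide"]
similarities = [0.8234]

response = "Some answer."
response += f"{sep}{sep}📝 Here are the sources I used to answer your question:{sep}{sep}"
for url, title, similarity in zip(urls, titles, similarities):
    response += f"[🔗 {title}]({url}), relevance: {similarity:2.3f}{sep}"

# response now ends with a line per source, e.g.:
# [🔗 Userguide](https://docs.mila.quebec/Userguide.html), relevance: 0.823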
buster/db.py
ADDED
@@ -0,0 +1,117 @@
import sqlite3
import warnings
import zlib

import numpy as np
import pandas as pd

documents_table = """CREATE TABLE IF NOT EXISTS documents (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    source TEXT NOT NULL,
    title TEXT NOT NULL,
    url TEXT NOT NULL,
    content TEXT NOT NULL,
    n_tokens INTEGER,
    embedding BLOB,
    current INTEGER
)"""

qa_table = """CREATE TABLE IF NOT EXISTS qa (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    source TEXT NOT NULL,
    prompt TEXT NOT NULL,
    answer TEXT NOT NULL,
    document_id_1 INTEGER,
    document_id_2 INTEGER,
    document_id_3 INTEGER,
    label_question INTEGER,
    label_answer INTEGER,
    testset INTEGER,
    FOREIGN KEY (document_id_1) REFERENCES documents (id),
    FOREIGN KEY (document_id_2) REFERENCES documents (id),
    FOREIGN KEY (document_id_3) REFERENCES documents (id)
)"""


class DocumentsDB:
    """Simple SQLite database for storing documents and questions/answers.

    The database is just a file on disk. It can store documents from different sources, and it can store multiple versions of the same document (e.g. if the document is updated).
    Questions/answers refer to the version of the document that was used at the time.

    Example:
        >>> db = DocumentsDB("/path/to/the/db.db")
        >>> db.write_documents("source", df)  # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
        >>> df = db.get_documents("source")
    """

    def __init__(self, db_path):
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()

        self.__initialize()

    def __del__(self):
        self.conn.close()

    def __initialize(self):
        """Initialize the database."""
        self.cursor.execute(documents_table)
        self.cursor.execute(qa_table)
        self.conn.commit()

    def write_documents(self, source: str, df: pd.DataFrame):
        """Write all documents from the dataframe into the db. All previous documents from that source will be set to `current = 0`."""
        df = df.copy()

        # Prepare the rows
        df["source"] = source
        df["current"] = 1
        columns = ["source", "title", "url", "content", "current"]
        if "embedding" in df.columns:
            columns.extend(
                [
                    "n_tokens",
                    "embedding",
                ]
            )

            # Check that the embeddings are float32
            if not df["embedding"].iloc[0].dtype == np.float32:
                warnings.warn(
                    f"Embeddings are not float32, converting them to float32 from {df['embedding'].iloc[0].dtype}.",
                    RuntimeWarning,
                )
                df["embedding"] = df["embedding"].apply(lambda x: x.astype(np.float32))

            # ZLIB compress the embeddings
            df["embedding"] = df["embedding"].apply(lambda x: sqlite3.Binary(zlib.compress(x.tobytes())))

        data = df[columns].values.tolist()

        # Set `current` to 0 for all previous documents from that source
        self.cursor.execute("UPDATE documents SET current = 0 WHERE source = ?", (source,))

        # Insert the new documents
        insert_statement = f"INSERT INTO documents ({', '.join(columns)}) VALUES ({', '.join(['?']*len(columns))})"
        self.cursor.executemany(insert_statement, data)

        self.conn.commit()

    def get_documents(self, source: str) -> pd.DataFrame:
        """Get all current documents from a given source."""
        # Execute the SQL statement and fetch the results
        results = self.cursor.execute("SELECT * FROM documents WHERE source = ? AND current = 1", (source,))
        rows = results.fetchall()

        # Convert the results to a pandas DataFrame
        df = pd.DataFrame(rows, columns=[description[0] for description in results.description])

        # ZLIB decompress the embeddings
        df["embedding"] = df["embedding"].apply(lambda x: np.frombuffer(zlib.decompress(x), dtype=np.float32).tolist())

        # Drop the `current` column
        df.drop(columns=["current"], inplace=True)

        return df
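
The float32 check and RuntimeWarning in write_documents exist because embeddings are stored as zlib-compressed raw bytes and read back with a hard-coded dtype; a minimal sketch of that round trip using only numpy and zlib:

import zlib

import numpy as np

# Write path: embeddings are cast to float32, serialized to raw bytes and
# zlib-compressed before being stored in the `embedding` BLOB column.
embedding = np.arange(10, dtype=np.float32) - 0.3
blob = zlib.compress(embedding.tobytes())

# Read path: get_documents decompresses the blob and reinterprets the bytes
# with dtype=np.float32, so anything stored with a different dtype would be
# silently garbled -- hence the conversion and warning in write_documents.
restored = np.frombuffer(zlib.decompress(blob), dtype=np.float32)
assert np.allclose(restored, embedding)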
buster/docparser.py
CHANGED
@@ -7,6 +7,7 @@ import tiktoken
 from bs4 import BeautifulSoup
 from openai.embeddings_utils import get_embedding
 
+from buster.db import DocumentsDB
 from buster.parser import HuggingfaceParser, Parser, SphinxParser
 
 EMBEDDING_MODEL = "text-embedding-ada-002"
@@ -19,22 +20,22 @@ PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".
 supported_docs = {
     "mila": {
         "base_url": "https://docs.mila.quebec/",
-        "filename": "documents_mila.
+        "filename": "documents_mila.csv",
         "parser": SphinxParser,
     },
     "orion": {
         "base_url": "https://orion.readthedocs.io/en/stable/",
-        "filename": "documents_orion.
+        "filename": "documents_orion.csv",
         "parser": SphinxParser,
     },
     "pytorch": {
         "base_url": "https://pytorch.org/docs/stable/",
-        "filename": "documents_pytorch.
+        "filename": "documents_pytorch.csv",
         "parser": SphinxParser,
     },
     "huggingface": {
         "base_url": "https://huggingface.co/docs/transformers/",
-        "filename": "documents_huggingface.
+        "filename": "documents_huggingface.csv",
         "parser": HuggingfaceParser,
     },
 }
@@ -66,7 +67,7 @@ def get_all_documents(
         urls.extend(urls_file)
         names.extend(names_file)
 
-    documents_df = pd.DataFrame.from_dict({"
+    documents_df = pd.DataFrame.from_dict({"title": names, "url": urls, "content": sections})
 
     return documents_df
 
@@ -75,46 +76,58 @@ def get_file_extension(filepath: str) -> str:
     return os.path.splitext(filepath)[1]
 
 
-def write_documents(filepath: str, documents_df: pd.DataFrame):
+def write_documents(filepath: str, documents_df: pd.DataFrame, source: str = ""):
     ext = get_file_extension(filepath)
 
     if ext == ".csv":
         documents_df.to_csv(filepath, index=False)
     elif ext in PICKLE_EXTENSIONS:
         documents_df.to_pickle(filepath)
+    elif ext == ".db":
+        db = DocumentsDB(filepath)
+        db.write_documents(source, documents_df)
     else:
         raise ValueError(f"Unsupported format: {ext}.")
 
 
-def read_documents(filepath: str) -> pd.DataFrame:
+def read_documents(filepath: str, source: str = "") -> pd.DataFrame:
     ext = get_file_extension(filepath)
 
     if ext == ".csv":
         df = pd.read_csv(filepath)
-
-
+
+        if "embedding" in df.columns:
+            df["embedding"] = df.embedding.apply(eval).apply(np.array)
     elif ext in PICKLE_EXTENSIONS:
-
+        df = pd.read_pickle(filepath)
+
+        if "embedding" in df.columns:
+            df["embedding"] = df.embedding.apply(np.array)
+    elif ext == ".db":
+        db = DocumentsDB(filepath)
+        df = db.get_documents(source)
     else:
         raise ValueError(f"Unsupported format: {ext}.")
 
+    return df
+
 
 def compute_n_tokens(df: pd.DataFrame) -> pd.DataFrame:
     encoding = tiktoken.get_encoding(EMBEDDING_ENCODING)
     # TODO are there unexpected consequences of allowing endoftext?
-    df["n_tokens"] = df.
+    df["n_tokens"] = df.content.apply(lambda x: len(encoding.encode(x, allowed_special={"<|endoftext|>"})))
     return df
 
 
 def precompute_embeddings(df: pd.DataFrame) -> pd.DataFrame:
-    df["embedding"] = df.
+    df["embedding"] = df.content.apply(lambda x: np.asarray(get_embedding(x, engine=EMBEDDING_MODEL), dtype=np.float32))
     return df
 
 
-def generate_embeddings(filepath: str, output_file: str) -> pd.DataFrame:
+def generate_embeddings(filepath: str, output_file: str, source: str) -> pd.DataFrame:
     # Get all documents and precompute their embeddings
-    df = read_documents(filepath)
+    df = read_documents(filepath, source)
     df = compute_n_tokens(df)
     df = precompute_embeddings(df)
-    write_documents(output_file, df)
+    write_documents(filepath=output_file, documents_df=df, source=source)
     return df
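
With the new source argument, the same helpers also cover the SQLite backend; a small sketch (hypothetical file names) of converting an existing pickled document set into the db and reading it back:

from buster.docparser import read_documents, write_documents

# Hypothetical file names; the source label is required for the .db backend
# and defaults to an empty string for csv/pickle, where it is ignored.
df = read_documents("document_embeddings_mila.tar.gz")
write_documents("documents.db", df, source="mila")

# Reading it back goes through DocumentsDB.get_documents under the hood.
df_db = read_documents("documents.db", "mila")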
db_to_csv.ipynb
ADDED
@@ -0,0 +1,60 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example notebook on how to extract a source from the database and save it in another format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from buster.docparser import read_documents, write_documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the database\n",
    "db_path = \"documents.db\"\n",
    "\n",
    "# Source to extract\n",
    "target = \"pytorch\"\n",
    "df = read_documents(db_path, target)\n",
    "\n",
    "# If you want to save it as tar.gz\n",
    "filepath = os.path.join('buster/data/', f'document_embeddings_{target}.tar.gz')\n",
    "write_documents(filepath, target, df)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "milabot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "9db6f4b791ef587fd310257e87896b12053c9010399595f881592a25a8a29679"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
requirements.txt
CHANGED
@@ -6,6 +6,7 @@ tabulate
 tenacity
 tiktoken
 promptlayer
+pytest
 openai
 
 # all openai[embeddings] deps, their list breaks our CI, see: https://github.com/openai/openai-python/issues/210
tests/test_db.py
ADDED
@@ -0,0 +1,62 @@
import numpy as np
import pandas as pd

from buster.db import DocumentsDB


def test_write_read():
    db = DocumentsDB(":memory:")

    data = pd.DataFrame.from_dict(
        {
            "title": ["test"],
            "url": ["http://url.com"],
            "content": ["cool text"],
            "embedding": [np.arange(10, dtype=np.float32) - 0.3],
            "n_tokens": [10],
        }
    )
    db.write_documents(source="test", df=data)

    db_data = db.get_documents("test")

    assert db_data["title"].iloc[0] == data["title"].iloc[0]
    assert db_data["url"].iloc[0] == data["url"].iloc[0]
    assert db_data["content"].iloc[0] == data["content"].iloc[0]
    assert np.allclose(db_data["embedding"].iloc[0], data["embedding"].iloc[0])
    assert db_data["n_tokens"].iloc[0] == data["n_tokens"].iloc[0]


def test_write_write_read():
    db = DocumentsDB(":memory:")

    data_1 = pd.DataFrame.from_dict(
        {
            "title": ["test"],
            "url": ["http://url.com"],
            "content": ["cool text"],
            "embedding": [np.arange(10, dtype=np.float32) - 0.3],
            "n_tokens": [10],
        }
    )
    db.write_documents(source="test", df=data_1)

    data_2 = pd.DataFrame.from_dict(
        {
            "title": ["other"],
            "url": ["http://url.com/page.html"],
            "content": ["lorem ipsum"],
            "embedding": [np.arange(20, dtype=np.float32) / 10 - 2.3],
            "n_tokens": [20],
        }
    )
    db.write_documents(source="test", df=data_2)

    db_data = db.get_documents("test")

    assert len(db_data) == len(data_2)
    assert db_data["title"].iloc[0] == data_2["title"].iloc[0]
    assert db_data["url"].iloc[0] == data_2["url"].iloc[0]
    assert db_data["content"].iloc[0] == data_2["content"].iloc[0]
    assert np.allclose(db_data["embedding"].iloc[0], data_2["embedding"].iloc[0])
    assert db_data["n_tokens"].iloc[0] == data_2["n_tokens"].iloc[0]
tests/test_docparser.py
ADDED
@@ -0,0 +1,30 @@
import numpy as np
import pandas as pd

from buster.docparser import generate_embeddings, read_documents, write_documents


def test_generate_embeddings(tmp_path, monkeypatch):
    # Patch the get_embedding function to return a fixed embedding
    monkeypatch.setattr("buster.docparser.get_embedding", lambda x, engine: [-0.005, 0.0018])

    # Create fake data
    data = pd.DataFrame.from_dict({"title": ["test"], "url": ["http://url.com"], "content": ["cool text"]})

    # Write the data to a file
    filepath = tmp_path / "test_document.csv"
    write_documents(filepath=filepath, documents_df=data, source="test")

    # Generate embeddings, store in a file
    output_file = tmp_path / "test_document_embeddings.tar.gz"
    df = generate_embeddings(filepath=filepath, output_file=output_file, source="test")

    # Read the embeddings from the file
    read_df = read_documents(output_file, "test")

    # Check all the values are correct across the files
    assert df["title"].iloc[0] == data["title"].iloc[0] == read_df["title"].iloc[0]
    assert df["url"].iloc[0] == data["url"].iloc[0] == read_df["url"].iloc[0]
    assert df["content"].iloc[0] == data["content"].iloc[0] == read_df["content"].iloc[0]
    assert np.allclose(df["embedding"].iloc[0], read_df["embedding"].iloc[0])
    assert df["n_tokens"].iloc[0] == read_df["n_tokens"].iloc[0]