import glob
import math
import os
import bs4
import pandas as pd
import tiktoken
from bs4 import BeautifulSoup
from openai.embeddings_utils import get_embedding
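# NOTE: `openai.embeddings_utils` only exists in the legacy openai<1.0 client; this
# import (and `get_embedding(..., engine=...)` below) assumes that version, which
# reads the API key from the OPENAI_API_KEY environment variable by default.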

EMBEDDING_MODEL = "text-embedding-ada-002"
EMBEDDING_ENCODING = "cl100k_base"  # this is the encoding for text-embedding-ada-002

BASE_URL_MILA = "https://docs.mila.quebec/"
BASE_URL_ORION = "https://orion.readthedocs.io/en/stable/"
BASE_URL_PYTORCH = "https://pytorch.org/docs/stable/"

PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".tar.xz", ".tar.bz2"]
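# NOTE: write_documents/read_documents route these suffixes to pandas' pickle I/O,
# which infers the compression scheme from the file extension.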


def parse_section(nodes: list[bs4.element.NavigableString]) -> str:
    section = []
    for node in nodes:
        if node.name == "table":
            node_text = pd.read_html(node.prettify())[0].to_markdown(index=False, tablefmt="github")
        else:
            node_text = node.text
        section.append(node_text)
    section = "".join(section)[1:]
    return section


def get_all_documents(
    root_dir: str, base_url: str, min_section_length: int = 100, max_section_length: int = 2000
) -> pd.DataFrame:
    """Parse all HTML files in `root_dir`, and extract all sections.

    Sections are broken into subsections if they are longer than `max_section_length`.
    Sections correspond to `section` HTML tags that have a headerlink attached.
    """
    files = glob.glob("**/*.html", root_dir=root_dir, recursive=True)

    def get_all_subsections(soup: BeautifulSoup) -> tuple[list[str], list[str], list[str]]:
        found = soup.find_all("a", href=True, class_="headerlink")

        sections = []
        urls = []
        names = []
        for section_found in found:
            section_soup = section_found.parent.parent
            section_href = section_soup.find_all("a", href=True, class_="headerlink")

            # If the section has subsections, keep only the part before the first subsection
            if len(section_href) > 1 and section_soup.section is not None:
                section_siblings = list(section_soup.section.previous_siblings)[::-1]
                section = parse_section(section_siblings)
            else:
                section = parse_section(section_soup.children)

            # Remove special characters, plus newlines in some urls and section names.
            section = section.strip()
            url = section_found["href"].strip().replace("\n", "")
            name = section_found.parent.text.strip()[:-1].replace("\n", "")

            # If the text is too long, split it into chunks of roughly equal size
            if len(section) > max_section_length:
                n_chunks = math.ceil(len(section) / float(max_section_length))
                separator_index = math.floor(len(section) / n_chunks)

                section_chunks = [section[separator_index * i : separator_index * (i + 1)] for i in range(n_chunks)]
                url_chunks = [url] * n_chunks
                name_chunks = [name] * n_chunks

                sections.extend(section_chunks)
                urls.extend(url_chunks)
                names.extend(name_chunks)

            # If the text is not too short, add it as a single chunk
            elif len(section) > min_section_length:
                sections.append(section)
                urls.append(url)
                names.append(name)

        return sections, urls, names

    sections = []
    urls = []
    names = []
    for file in files:
        filepath = os.path.join(root_dir, file)
        with open(filepath, "r") as f:
            source = f.read()

        soup = BeautifulSoup(source, "html.parser")
        sections_file, urls_file, names_file = get_all_subsections(soup)
        sections.extend(sections_file)

        # Build absolute urls: base url + relative file path + section anchor
        urls_file = [base_url + file + url for url in urls_file]
        urls.extend(urls_file)

        names.extend(names_file)

    documents_df = pd.DataFrame.from_dict({"name": names, "url": urls, "text": sections})

    return documents_df


def get_file_extension(filepath: str) -> str:
    # Note: splitext only returns the last suffix, e.g. "documents.tar.gz" -> ".gz"
    return os.path.splitext(filepath)[1]


def write_documents(filepath: str, documents_df: pd.DataFrame):
    ext = get_file_extension(filepath)

    if ext == ".csv":
        documents_df.to_csv(filepath, index=False)
    elif ext in PICKLE_EXTENSIONS:
        documents_df.to_pickle(filepath)
    else:
        raise ValueError(f"Unsupported format: {ext}.")


def read_documents(filepath: str) -> pd.DataFrame:
    ext = get_file_extension(filepath)

    if ext == ".csv":
        return pd.read_csv(filepath)
    elif ext in PICKLE_EXTENSIONS:
        return pd.read_pickle(filepath)
    else:
        raise ValueError(f"Unsupported format: {ext}.")


def compute_n_tokens(df: pd.DataFrame) -> pd.DataFrame:
    encoding = tiktoken.get_encoding(EMBEDDING_ENCODING)
    df["n_tokens"] = df.text.apply(lambda x: len(encoding.encode(x)))
    return df


def precompute_embeddings(df: pd.DataFrame) -> pd.DataFrame:
    df["embedding"] = df.text.apply(lambda x: get_embedding(x, engine=EMBEDDING_MODEL))
    return df


def generate_embeddings(filepath: str, output_file: str) -> pd.DataFrame:
    # Get all documents and precompute their embeddings
    df = read_documents(filepath)
    df = compute_n_tokens(df)
    df = precompute_embeddings(df)
    write_documents(output_file, df)
    return df


if __name__ == "__main__":
    root_dir = "/home/hadrien/perso/mila-docs/output/"
    save_filepath = "data/documents.tar.gz"

    # How to write
    documents_df = get_all_documents(root_dir, BASE_URL_MILA)  # base_url must match the docs under root_dir
    write_documents(save_filepath, documents_df)

    # How to load
    documents_df = read_documents(save_filepath)

    # precompute the document embeddings
    df = generate_embeddings(filepath=save_filepath, output_file="data/document_embeddings.tar.gz")
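    # The resulting dataframe has one row per (sub)section, with columns
    # name, url, text, n_tokens and embedding.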