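"""Parse HTML documentation (e.g. Sphinx output) into sections and precompute
OpenAI embeddings for every section.

The resulting DataFrame has the columns `name`, `url`, `text`, `n_tokens` and
`embedding`, and can be written to CSV or to a compressed pickle.
"""
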
import glob
import math
import os
from collections.abc import Iterable

import bs4
import pandas as pd
import tiktoken
from bs4 import BeautifulSoup
from openai.embeddings_utils import get_embedding

EMBEDDING_MODEL = "text-embedding-ada-002"
EMBEDDING_ENCODING = "cl100k_base"  # this is the encoding for text-embedding-ada-002


BASE_URL_MILA = "https://docs.mila.quebec/"
BASE_URL_ORION = "https://orion.readthedocs.io/en/stable/"
BASE_URL_PYTORCH = "https://pytorch.org/docs/stable/"


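# File extensions handed to pandas' to_pickle / read_pickle; compression is
# inferred from the file name.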
PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".tar.xz", ".tar.bz2"]


def parse_section(nodes: Iterable[bs4.element.PageElement]) -> str:
    section = []
    for node in nodes:
        if node.name == "table":
            node_text = pd.read_html(node.prettify())[0].to_markdown(index=False, tablefmt="github")
        else:
            node_text = node.text
        section.append(node_text)
    # Join the node texts and drop the leading character (typically a stray newline).
    section = "".join(section)[1:]

    return section


def get_all_documents(
    root_dir: str, base_url: str, min_section_length: int = 100, max_section_length: int = 2000
) -> pd.DataFrame:
    """Parse all HTML files in `root_dir`, and extract all sections.

    Sections correspond to `section` HTML tags that have a headerlink attached.
    Sections longer than `max_section_length` are split into equally sized chunks,
    and sections shorter than `min_section_length` are discarded.
    """
    files = glob.glob("**/*.html", root_dir=root_dir, recursive=True)

    def get_all_subsections(soup: BeautifulSoup) -> tuple[list[str], list[str], list[str]]:
        found = soup.find_all("a", href=True, class_="headerlink")

        sections = []
        urls = []
        names = []
        for section_found in found:
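            # In Sphinx HTML the headerlink anchor sits inside a heading tag, whose
            # parent is the enclosing <section> element.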
            section_soup = section_found.parent.parent
            section_href = section_soup.find_all("a", href=True, class_="headerlink")

            # If the section has subsections, keep only the part before the first subsection
            if len(section_href) > 1 and section_soup.section is not None:
                section_siblings = list(section_soup.section.previous_siblings)[::-1]
                section = parse_section(section_siblings)
            else:
                section = parse_section(section_soup.children)

            # Strip whitespace and newlines from the section text, URL and name;
            # the trailing character of the name is the headerlink symbol, so drop it.
            section = section.strip()
            url = section_found["href"].strip().replace("\n", "")
            name = section_found.parent.text.strip()[:-1].replace("\n", "")

            # If text is too long, split into chunks of equal sizes
            if len(section) > max_section_length:
                n_chunks = math.ceil(len(section) / float(max_section_length))
                separator_index = math.floor(len(section) / n_chunks)
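                # e.g. a 4,500-character section with max_section_length=2000 gives
                # n_chunks = 3 and three chunks of 1,500 characters each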

                section_chunks = [section[separator_index * i : separator_index * (i + 1)] for i in range(n_chunks)]
                url_chunks = [url] * n_chunks
                name_chunks = [name] * n_chunks

                sections.extend(section_chunks)
                urls.extend(url_chunks)
                names.extend(name_chunks)
            # Otherwise, if the text is long enough, keep the section as a single chunk
            elif len(section) > min_section_length:
                sections.append(section)
                urls.append(url)
                names.append(name)

        return sections, urls, names

    sections = []
    urls = []
    names = []
    for file in files:
        filepath = os.path.join(root_dir, file)
        with open(filepath, "r") as f:
            source = f.read()

        soup = BeautifulSoup(source, "html.parser")
        sections_file, urls_file, names_file = get_all_subsections(soup)
        sections.extend(sections_file)

        urls_file = [base_url + file + url for url in urls_file]
        urls.extend(urls_file)

        names.extend(names_file)

    documents_df = pd.DataFrame.from_dict({"name": names, "url": urls, "text": sections})

    return documents_df


def get_file_extension(filepath: str) -> str:
    return os.path.splitext(filepath)[1]


def write_documents(filepath: str, documents_df: pd.DataFrame):
    ext = get_file_extension(filepath)

    if ext == ".csv":
        documents_df.to_csv(filepath, index=False)
    elif ext in PICKLE_EXTENSIONS:
        documents_df.to_pickle(filepath)
    else:
        raise ValueError(f"Unsupported format: {ext}.")


def read_documents(filepath: str) -> pd.DataFrame:
    ext = get_file_extension(filepath)

    if ext == ".csv":
        return pd.read_csv(filepath)
    elif ext in PICKLE_EXTENSIONS:
        return pd.read_pickle(filepath)
    else:
        raise ValueError(f"Unsupported format: {ext}.")


def compute_n_tokens(df: pd.DataFrame) -> pd.DataFrame:
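    # Token counts use the same encoding as the embedding model (cl100k_base), which
    # is useful for estimating API cost and staying within context limits.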
    encoding = tiktoken.get_encoding(EMBEDDING_ENCODING)
    df["n_tokens"] = df.text.apply(lambda x: len(encoding.encode(x)))
    return df


def precompute_embeddings(df: pd.DataFrame) -> pd.DataFrame:
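    # Each row triggers a call to the OpenAI embeddings endpoint; this assumes the
    # API key is already configured (e.g. via the OPENAI_API_KEY environment variable).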
    df["embedding"] = df.text.apply(lambda x: get_embedding(x, engine=EMBEDDING_MODEL))
    return df


def generate_embeddings(filepath: str, output_file: str) -> pd.DataFrame:
    # Get all documents and precompute their embeddings
    df = read_documents(filepath)
    df = compute_n_tokens(df)
    df = precompute_embeddings(df)
    write_documents(output_file, df)
    return df


if __name__ == "__main__":
    root_dir = "/home/hadrien/perso/mila-docs/output/"
    save_filepath = "data/documents.tar.gz"

    # How to write
    documents_df = get_all_documents(root_dir, base_url=BASE_URL_MILA)
    write_documents(save_filepath, documents_df)

    # How to load
    documents_df = read_documents(save_filepath)

    # precompute the document embeddings
    df = generate_embeddings(filepath=save_filepath, output_file="data/document_embeddings.tar.gz")