hbertrand committed on
Commit
4dcc0d8
·
unverified ·
1 Parent(s): 6e7e500

Support pickle format (#16)

Browse files
Files changed (1) hide show
  1. buster/docparser.py +27 -6
buster/docparser.py CHANGED
@@ -16,6 +16,9 @@ BASE_URL_MILA = "https://docs.mila.quebec/"
16
  BASE_URL_ORION = "https://orion.readthedocs.io/en/stable/"
17
 
18
 
 
 
 
19
  def parse_section(nodes: list[bs4.element.NavigableString]) -> str:
20
  section = []
21
  for node in nodes:
@@ -100,12 +103,30 @@ def get_all_documents(root_dir: str, base_url: str, max_section_length: int = 20
100
  return documents_df
101
 
102
 
 
 
 
 
103
  def write_documents(filepath: str, documents_df: pd.DataFrame):
104
- documents_df.to_csv(filepath, index=False)
 
 
 
 
 
 
 
105
 
106
 
107
  def read_documents(filepath: str) -> pd.DataFrame:
108
- return pd.read_csv(filepath)
 
 
 
 
 
 
 
109
 
110
 
111
  def compute_n_tokens(df: pd.DataFrame) -> pd.DataFrame:
@@ -119,18 +140,18 @@ def precompute_embeddings(df: pd.DataFrame) -> pd.DataFrame:
119
  return df
120
 
121
 
122
- def generate_embeddings(filepath: str, output_csv: str) -> pd.DataFrame:
123
  # Get all documents and precompute their embeddings
124
  df = read_documents(filepath)
125
  df = compute_n_tokens(df)
126
  df = precompute_embeddings(df)
127
- write_documents(output_csv, df)
128
  return df
129
 
130
 
131
  if __name__ == "__main__":
132
  root_dir = "/home/hadrien/perso/mila-docs/output/"
133
- save_filepath = "data/documents.csv"
134
 
135
  # How to write
136
  documents_df = get_all_documents(root_dir)
@@ -140,4 +161,4 @@ if __name__ == "__main__":
140
  documents_df = read_documents(save_filepath)
141
 
142
  # precompute the document embeddings
143
- df = generate_embeddings(filepath=save_filepath, output_csv="data/document_embeddings.csv")
 
16
  BASE_URL_ORION = "https://orion.readthedocs.io/en/stable/"
17
 
18
 
19
+ PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".tar.xz", ".tar.bz2"]
20
+
21
+
22
  def parse_section(nodes: list[bs4.element.NavigableString]) -> str:
23
  section = []
24
  for node in nodes:
 
103
  return documents_df
104
 
105
 
106
+ def get_file_extension(filepath: str) -> str:
107
+ return os.path.splitext(filepath)[1]
108
+
109
+
110
  def write_documents(filepath: str, documents_df: pd.DataFrame):
111
+ ext = get_file_extension(filepath)
112
+
113
+ if ext == ".csv":
114
+ documents_df.to_csv(filepath, index=False)
115
+ elif ext in PICKLE_EXTENSIONS:
116
+ documents_df.to_pickle(filepath)
117
+ else:
118
+ raise ValueError(f"Unsupported format: {ext}.")
119
 
120
 
121
  def read_documents(filepath: str) -> pd.DataFrame:
122
+ ext = get_file_extension(filepath)
123
+
124
+ if ext == ".csv":
125
+ return pd.read_csv(filepath)
126
+ elif ext in PICKLE_EXTENSIONS:
127
+ return pd.read_pickle(filepath)
128
+ else:
129
+ raise ValueError(f"Unsupported format: {ext}.")
130
 
131
 
132
  def compute_n_tokens(df: pd.DataFrame) -> pd.DataFrame:
 
140
  return df
141
 
142
 
143
+ def generate_embeddings(filepath: str, output_file: str) -> pd.DataFrame:
144
  # Get all documents and precompute their embeddings
145
  df = read_documents(filepath)
146
  df = compute_n_tokens(df)
147
  df = precompute_embeddings(df)
148
+ write_documents(output_file, df)
149
  return df
150
 
151
 
152
  if __name__ == "__main__":
153
  root_dir = "/home/hadrien/perso/mila-docs/output/"
154
+ save_filepath = "data/documents.tar.gz"
155
 
156
  # How to write
157
  documents_df = get_all_documents(root_dir)
 
161
  documents_df = read_documents(save_filepath)
162
 
163
  # precompute the document embeddings
164
+ df = generate_embeddings(filepath=save_filepath, output_file="data/document_embeddings.tar.gz")