Spaces:
Running
Running
import os | |
import argparse | |
import json | |
import re | |
from typing import Dict | |
import dotenv | |
from pathlib import Path | |
from tqdm import tqdm | |
from pymilvus import MilvusClient, model | |
_ = dotenv.load_dotenv() | |
def create_collection(client: MilvusClient, collection_name: str, dimension: int):
    """Recreate ``collection_name`` on ``client`` from scratch.

    If a collection with that name already exists it is dropped first, then
    a new one is created.

    Args:
        client: Milvus client that every call is issued on.
        collection_name: Name of the collection to (re)create.
        dimension: Forwarded as ``dimension`` to ``client.create_collection``.
    """
    target = {"collection_name": collection_name}
    if client.has_collection(**target):
        client.drop_collection(**target)
    client.create_collection(dimension=dimension, **target)
def clean_filename(s):
    """Turn a path-derived string into plain text suitable for embedding.

    Three passes are applied in order:

    1. Strip hierarchical numbering such as ``"28."`` or ``"28.1."``.
    2. Drop every character that is not a word character, whitespace,
       ``/``, ``.`` or ``-`` (this is what removes emojis and other symbols).
    3. Collapse whitespace runs to a single space and trim both ends.

    Args:
        s: Text to clean, e.g. a relative file path whose separators have
            already been replaced by spaces.

    Returns:
        The cleaned string.
    """
    # The lookbehind keeps digits glued to a word or to an earlier dot
    # ("llama3.1", "v2.0", "model3.") from being treated as numbering, and
    # the lookahead stops the match from backtracking into the middle of a
    # dotted number ("3.1 notes" must not become "1 notes").
    s = re.sub(r'(?<![\w.])\d+(?:\.\d+)*\.(?!\d)\s*', '', s)
    s = re.sub(r'[^\w\s/.-]', '', s)
    s = re.sub(r'\s+', ' ', s)
    return s.strip()
# Number of texts sent per embedding request / insert call. Kept well below
# the per-request input cap of the OpenAI embeddings endpoint.
_EMBED_BATCH_SIZE = 256


def main(args: argparse.Namespace):
    """Embed documentation file paths and store them in a Milvus collection.

    Opens the local ``milvus.db`` database, recreates the target collection,
    reads the repository list from ``args.repos_config_path`` and, for every
    repository, collects the ``*.md`` / ``*.mdx`` files found under
    ``<docs_dir>/<index>. <title>``. Each file path is turned into a short
    text (root prefix and extension removed, separators replaced by spaces,
    then passed through ``clean_filename``), embedded, and inserted together
    with its original path and repository title.

    Args:
        args: Parsed command-line options. Reads ``model_name``,
            ``dimension``, ``collection_name`` and ``repos_config_path``;
            ``docs_dir`` is optional and defaults to ``"docs"``.

    Raises:
        KeyError: If a repository entry in the config has no ``"title"``.
    """
    client = MilvusClient("milvus.db")
    embedding_fn = model.dense.OpenAIEmbeddingFunction(
        model_name=args.model_name,
        api_key=os.environ.get('OPENAI_API_KEY'),
        dimensions=args.dimension,
    )
    create_collection(client, args.collection_name, args.dimension)

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(args.repos_config_path, "r", encoding="utf-8") as f:
        repos = json.load(f)

    # Honour --docs_dir while staying compatible with callers whose namespace
    # does not define it.
    docs_root = Path(getattr(args, "docs_dir", "docs"))
    # With the default root this is 'docs/', and str.replace removes every
    # occurrence of it in the path, not only the leading one.
    root_prefix = f"{docs_root.as_posix()}/"

    docs, payloads = [], []
    for i, repo in enumerate(repos, 1):
        docs_path = docs_root / f"{i}. {repo['title']}"
        all_file_paths = list(docs_path.rglob('*.md')) + list(docs_path.rglob('*.mdx'))
        for file in all_file_paths:
            # as_posix() guarantees '/' separators, so the replacements below
            # also work on platforms whose native separator is different.
            embed_string = (
                file.as_posix()
                .replace(root_prefix, '')
                .replace('.mdx', '')
                .replace('.md', '')
                .replace('/', ' ')
            )
            docs.append(clean_filename(embed_string))
            payloads.append({'file_path': str(file), 'resource': repo['title']})

    # Embed and insert in fixed-size batches: one request with every text
    # fails for an empty list and for lists above the API's input cap.
    inserted = 0
    for start in tqdm(range(0, len(docs), _EMBED_BATCH_SIZE), desc="Embedding"):
        batch_docs = docs[start:start + _EMBED_BATCH_SIZE]
        batch_payloads = payloads[start:start + _EMBED_BATCH_SIZE]
        vectors = embedding_fn.encode_documents(batch_docs)
        data = [
            {"id": doc_id, "vector": vector, "text": text, **payload}
            for doc_id, (vector, text, payload) in enumerate(
                zip(vectors, batch_docs, batch_payloads), start
            )
        ]
        response = client.insert(collection_name=args.collection_name, data=data)
        inserted += response['insert_count']

    print(f"Inserted {inserted} vectors into collection {args.collection_name}")
if __name__ == "__main__":
    # Command-line options; every flag is optional and falls back to the
    # default listed next to it.
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--collection_name", str, "hf_docs"),
        ("--model_name", str, "text-embedding-3-large"),
        ("--dimension", int, 3072),
        ("--docs_dir", str, "docs"),
        ("--repos_config_path", str, "repos_config.json"),
    )
    for flag, value_type, fallback in option_specs:
        parser.add_argument(flag, type=value_type, default=fallback)
    args = parser.parse_args()

    # Delete any existing milvus.db before main() opens it.
    stale_db = Path('milvus.db')
    if stale_db.exists():
        print("Removing existing Milvus database...")
        stale_db.unlink()

    main(args)