import argparse
import json
import os
import re
from pathlib import Path

import dotenv
from pymilvus import MilvusClient, model

# Load OPENAI_API_KEY (and any other settings) from a local .env file, if present.
_ = dotenv.load_dotenv()

# OpenAI's embeddings endpoint accepts at most 2048 inputs per request,
# so documents are encoded in chunks no larger than this.
EMBED_BATCH_SIZE = 1000


def create_collection(client: MilvusClient, collection_name: str, dimension: int):
    """(Re)create a collection, dropping any existing one with the same name."""
    if client.has_collection(collection_name=collection_name):
        client.drop_collection(collection_name=collection_name)
    client.create_collection(
        collection_name=collection_name,
        dimension=dimension,
    )


def clean_filename(s: str) -> str:
    """Turn a file path into a short title string suitable for embedding."""
    s = re.sub(r'\d+(?:\.\d+)*\.\s*', '', s)  # Remove hierarchical numbering (e.g., "28.", "28.1.")
    s = re.sub(r'[^\w\s/.-]', '', s)  # Remove emojis and other symbols/punctuation
    s = re.sub(r'\s+', ' ', s)  # Collapse runs of whitespace
    return s.strip()


def main(args: argparse.Namespace):
    client = MilvusClient("milvus.db")
    embedding_fn = model.dense.OpenAIEmbeddingFunction(
        model_name=args.model_name,
        api_key=os.environ.get('OPENAI_API_KEY'),
        dimensions=args.dimension,
    )
    create_collection(client, args.collection_name, args.dimension)

    with open(args.repos_config_path, "r") as f:
        repos = json.load(f)

    docs_root = Path(args.docs_dir)
    docs, payloads = [], []
    for i, repo in enumerate(repos, 1):
        # Each repository's docs live in a folder named "<index>. <title>".
        docs_path = docs_root / f"{i}. {repo['title']}"
        all_file_paths = list(docs_path.rglob('*.md')) + list(docs_path.rglob('*.mdx'))
        for file in all_file_paths:
            # Embed the path relative to the docs root, without its extension,
            # with path separators turned into spaces.
            rel_path = file.relative_to(docs_root).with_suffix('')
            embed_string = clean_filename(' '.join(rel_path.parts))
            docs.append(embed_string)
            payloads.append({'file_path': str(file), 'resource': repo['title']})

    if not docs:
        print(f"No .md/.mdx files found under {docs_root}; nothing to insert.")
        return

    # Encode in batches to stay under the per-request input limit.
    vectors = []
    for start in range(0, len(docs), EMBED_BATCH_SIZE):
        vectors.extend(embedding_fn.encode_documents(docs[start:start + EMBED_BATCH_SIZE]))

    data = [
        {"id": i, "vector": vectors[i], "text": docs[i], **payloads[i]}
        for i in range(len(vectors))
    ]
    response = client.insert(collection_name=args.collection_name, data=data)
    print(f"Inserted {response['insert_count']} vectors into collection {args.collection_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--collection_name", type=str, default="hf_docs")
    parser.add_argument("--model_name", type=str, default="text-embedding-3-large")
    parser.add_argument("--dimension", type=int, default=3072)
    parser.add_argument("--docs_dir", type=str, default="docs")
    parser.add_argument("--repos_config_path", type=str, default="repos_config.json")
    args = parser.parse_args()

    # Start from a fresh local Milvus Lite database on every run.
    if Path('milvus.db').exists():
        print("Removing existing Milvus database...")
        os.remove('milvus.db')

    main(args)