# hfcontext7 / scripts / make_rag_db.py
# Abdullah Meda
# refactoring edits
# c6fe03c
import os
import argparse
import json
import re
from typing import Dict
import dotenv
from pathlib import Path
from tqdm import tqdm
from pymilvus import MilvusClient, model
_ = dotenv.load_dotenv()
def create_collection(client: MilvusClient, collection_name: str, dimension: int):
    """Create ``collection_name`` on ``client``, replacing any existing one.

    Args:
        client: Client object used to query, drop and create the collection.
        collection_name: Name of the collection to (re)create.
        dimension: Value forwarded as the ``dimension`` of the new collection.

    Returns:
        None. A collection that already exists under the same name is dropped
        before the new one is created, so its previous contents are lost.
    """
    already_present = client.has_collection(collection_name=collection_name)
    if already_present:
        # Start from a clean slate rather than appending to stale data.
        client.drop_collection(collection_name=collection_name)

    creation_kwargs = {"collection_name": collection_name, "dimension": dimension}
    client.create_collection(**creation_kwargs)
# Hierarchical section numbering such as "28." or "28.1.". The look-behind
# stops a match from starting inside a word or a longer dotted number, and the
# look-ahead rejects a dot that is followed by another digit, so version-like
# tokens ("v4.30.0", "3.5") are left intact instead of being partially eaten.
_NUMBERING_RE = re.compile(r'(?<![\w.])\d+(?:\.\d+)*\.(?!\d)\s*')
# Anything that is not a word character, whitespace, "/", "." or "-".
# This removes emojis, but also other punctuation and symbols.
_DISALLOWED_RE = re.compile(r'[^\w\s/.-]')
_WHITESPACE_RE = re.compile(r'\s+')


def clean_filename(s):
    """Normalise a path-derived string into plain text.

    Three passes are applied in order:

    1. Hierarchical numbering prefixes ("28.", "28.1.") are removed. Only
       numbering that starts at a token boundary is stripped; digits embedded
       in names or version numbers are preserved.
    2. Every character other than word characters, whitespace, "/", "." and
       "-" is removed (emojis and most punctuation).
    3. Runs of whitespace are collapsed to a single space.

    Args:
        s: The string to clean.

    Returns:
        The cleaned string with leading and trailing whitespace stripped.
    """
    s = _NUMBERING_RE.sub('', s)
    s = _DISALLOWED_RE.sub('', s)
    s = _WHITESPACE_RE.sub(' ', s)
    return s.strip()
# Number of texts sent per embedding request. Sending the whole corpus in one
# call can exceed the provider's per-request input limits.
# NOTE(review): OpenAI documents a cap on inputs per request (2048 at the time
# of writing) -- confirm against the current API reference.
_EMBED_BATCH_SIZE = 512


def _collect_documents(repos: list, docs_root: Path) -> tuple:
    """Return ``(docs, payloads)`` for every .md/.mdx file of every repo."""
    docs, payloads = [], []
    for position, repo in enumerate(repos, 1):
        # Each repo's files live under "<docs_root>/<1-based index>. <title>".
        repo_dir = docs_root / f"{position}. {repo['title']}"
        # Sorted so the row ids assigned later are reproducible across runs.
        files = sorted(repo_dir.rglob('*.md')) + sorted(repo_dir.rglob('*.mdx'))
        if not files:
            print(f"Warning: no .md/.mdx files found under {repo_dir}")
            continue
        for file in files:
            # Strip only the docs root and the file extension, then join the
            # remaining path components with spaces. Working on path parts
            # keeps this independent of the OS path separator.
            relative = file.relative_to(docs_root).with_suffix('')
            docs.append(clean_filename(' '.join(relative.parts)))
            payloads.append({'file_path': str(file), 'resource': repo['title']})
    return docs, payloads


def _embed_in_batches(embedding_fn, docs: list) -> list:
    """Embed ``docs`` in fixed-size batches and return one vector per doc."""
    vectors = []
    for start in tqdm(range(0, len(docs), _EMBED_BATCH_SIZE), desc="Embedding"):
        batch = docs[start:start + _EMBED_BATCH_SIZE]
        vectors.extend(embedding_fn.encode_documents(batch))
    return vectors


def main(args: argparse.Namespace):
    """Build the ``milvus.db`` collection of embedded documentation paths.

    The repos listed in ``args.repos_config_path`` (a JSON list of objects
    with a ``title`` key) are scanned for ``.md``/``.mdx`` files. Each file's
    cleaned relative path is embedded and inserted together with its
    ``file_path`` and ``resource`` (repo title).

    Args:
        args: Parsed command-line options. Reads ``model_name``,
            ``dimension``, ``collection_name`` and ``repos_config_path``;
            ``docs_dir`` is honoured when present and defaults to ``"docs"``.

    Returns:
        None. Prints the number of inserted vectors, or a notice when no
        documentation files were found (nothing is embedded or inserted then).
    """
    # Fall back to "docs" so callers passing a namespace without docs_dir
    # keep working.
    docs_root = Path(getattr(args, 'docs_dir', 'docs'))

    with open(args.repos_config_path, "r", encoding="utf-8") as f:
        repos = json.load(f)

    docs, payloads = _collect_documents(repos, docs_root)
    if not docs:
        # Avoid calling the embedding API with an empty input list.
        print(f"No documentation files found under {docs_root}; nothing to insert.")
        return

    client = MilvusClient("milvus.db")
    embedding_fn = model.dense.OpenAIEmbeddingFunction(
        model_name=args.model_name,
        api_key=os.environ.get('OPENAI_API_KEY'),
        dimensions=args.dimension,
    )
    create_collection(client, args.collection_name, args.dimension)

    vectors = _embed_in_batches(embedding_fn, docs)
    data = [
        {"id": row_id, "vector": vector, "text": text, **payload}
        for row_id, (vector, text, payload) in enumerate(zip(vectors, docs, payloads))
    ]

    response = client.insert(collection_name=args.collection_name, data=data)
    print(f"Inserted {response['insert_count']} vectors into collection {args.collection_name}")
if __name__ == "__main__":
    # Command-line interface: every option has a default, so the script can
    # be run without arguments.
    parser = argparse.ArgumentParser()
    for flag, value_type, fallback in (
        ("--collection_name", str, "hf_docs"),
        ("--model_name", str, "text-embedding-3-large"),
        ("--dimension", int, 3072),
        ("--docs_dir", str, "docs"),
        ("--repos_config_path", str, "repos_config.json"),
    ):
        parser.add_argument(flag, type=value_type, default=fallback)
    args = parser.parse_args()

    # Always rebuild from scratch: delete a database file left by an earlier run.
    stale_db = Path('milvus.db')
    if stale_db.exists():
        print("Removing existing Milvus database...")
        stale_db.unlink()

    main(args)