hfcontext7 / make_rag_db.py
Abdullah Meda
minor edits
797c083
raw
history blame
2.01 kB
import os
import argparse
from typing import Dict
import dotenv
from pathlib import Path
from tqdm import tqdm
from pymilvus import MilvusClient, model
# Populate os.environ from a .env file via python-dotenv; main() reads
# OPENAI_API_KEY from the environment when building the embedding function.
dotenv.load_dotenv()
def create_collection(client: MilvusClient, collection_name: str, dimension: int):
    """(Re)create a Milvus collection, discarding any existing one with the same name.

    Args:
        client: Milvus client used for all collection operations.
        collection_name: Name of the collection to (re)create.
        dimension: Vector dimensionality passed to ``create_collection``.
    """
    already_exists = client.has_collection(collection_name=collection_name)
    if already_exists:
        # Drop the old collection so every run starts from an empty one.
        client.drop_collection(collection_name=collection_name)
    client.create_collection(collection_name=collection_name, dimension=dimension)
# Number of texts sent to the embedding API per request. The OpenAI embeddings
# endpoint caps the number of inputs per request, so large doc trees must be
# split into several calls.
_EMBED_BATCH_SIZE = 256


def _collect_doc_files(docs_root: Path) -> list:
    """Return every ``.md`` and ``.mdx`` file under *docs_root*, sorted so ids are stable across runs."""
    return sorted([*docs_root.rglob('*.md'), *docs_root.rglob('*.mdx')])


def _embed_text_for(path: Path, docs_root: Path) -> str:
    """Build the text embedded for *path*: its location under *docs_root*, without
    the file extension, with path components joined by spaces.

    e.g. ``docs/transformers/main_classes/trainer.md`` -> ``"transformers main_classes trainer"``
    """
    relative = path.relative_to(docs_root).with_suffix('')
    return ' '.join(relative.parts)


def _embed_in_batches(embedding_fn, texts: list, batch_size: int = _EMBED_BATCH_SIZE) -> list:
    """Embed *texts* via ``embedding_fn.encode_documents`` in chunks of *batch_size*, preserving order."""
    vectors = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Embedding"):
        vectors.extend(embedding_fn.encode_documents(texts[start:start + batch_size]))
    return vectors


def main(args: argparse.Namespace):
    """Embed the markdown files under ``args.docs_dir`` and store them in Milvus.

    Each ``.md``/``.mdx`` file is represented by a short text derived from its
    path (not its contents). The vector, that text and the file path are
    inserted into ``args.collection_name`` in the local ``milvus.db`` database.
    Any existing collection with that name is replaced, but only after all
    embeddings have been computed successfully.

    Args:
        args: Parsed CLI arguments providing ``collection_name``,
            ``model_name``, ``dimension`` and ``docs_dir``.
    """
    client = MilvusClient("milvus.db")
    embedding_fn = model.dense.OpenAIEmbeddingFunction(
        model_name=args.model_name,
        api_key=os.environ.get('OPENAI_API_KEY'),
        dimensions=args.dimension,
    )

    docs_root = Path(args.docs_dir)
    file_paths = _collect_doc_files(docs_root)
    if not file_paths:
        # Bail out before touching the existing collection: an empty or
        # mistyped docs_dir must not wipe a previously built index.
        print(f"No .md/.mdx files found under {docs_root}; collection {args.collection_name} left unchanged")
        return

    texts = [_embed_text_for(path, docs_root) for path in file_paths]
    vectors = _embed_in_batches(embedding_fn, texts)

    # Replace the collection only once every embedding succeeded, so an API
    # failure midway does not leave an empty index behind.
    create_collection(client, args.collection_name, args.dimension)
    data = [
        {"id": i, "vector": vector, "text": text, "file_path": str(path)}
        for i, (vector, text, path) in enumerate(zip(vectors, texts, file_paths))
    ]
    response = client.insert(collection_name=args.collection_name, data=data)
    print(f"Inserted {response['insert_count']} vectors into collection {args.collection_name}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build a Milvus vector collection from the .md/.mdx files under a docs directory.",
    )
    parser.add_argument(
        "--collection_name", type=str, default="hf_docs",
        help="Milvus collection to (re)create (default: %(default)s)",
    )
    parser.add_argument(
        "--model_name", type=str, default="text-embedding-3-small",
        help="OpenAI embedding model name (default: %(default)s)",
    )
    parser.add_argument(
        "--dimension", type=int, default=1536,
        help="embedding size requested from the model and used for the collection (default: %(default)s)",
    )
    parser.add_argument(
        "--docs_dir", type=str, default="docs",
        help="directory searched recursively for .md/.mdx files (default: %(default)s)",
    )
    main(parser.parse_args())