File size: 2,811 Bytes
f126864
 
0955e72
 
f126864
 
 
 
 
 
797c083
f126864
 
 
 
 
 
 
 
 
 
 
0955e72
 
 
 
 
 
 
 
f126864
 
 
 
 
 
 
 
 
 
 
0955e72
 
 
f126864
0955e72
f126864
0955e72
 
 
 
 
 
 
 
 
 
 
 
 
f126864
 
 
 
 
 
 
 
 
 
 
 
 
 
0955e72
 
f126864
0955e72
f126864
 
0955e72
 
 
 
f126864
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import argparse
import json
import re
from typing import Dict
import dotenv
from pathlib import Path
from tqdm import tqdm
from pymilvus import MilvusClient, model

# Load environment variables via python-dotenv at import time; main() later
# reads OPENAI_API_KEY from os.environ. The return value is deliberately unused.
_ = dotenv.load_dotenv()


def create_collection(client: MilvusClient, collection_name: str, dimension: int):
    """Create a fresh collection, replacing any existing one of the same name.

    Args:
        client: Milvus client used to query, drop and create the collection.
        collection_name: Name of the collection to (re)create.
        dimension: Vector dimension passed through to ``create_collection``.
    """
    target = {"collection_name": collection_name}

    # Drop first so the collection is always rebuilt from scratch.
    already_present = client.has_collection(**target)
    if already_present:
        client.drop_collection(**target)

    client.create_collection(**target, dimension=dimension)


def clean_filename(s):
    """Normalize a path-derived string into plain text.

    Three regex substitutions are applied in order, then the result is
    stripped of leading/trailing whitespace:

    1. Drop dotted numbering followed by a dot (e.g. "28.", "28.1.") along
       with any whitespace right after it.
    2. Drop every character that is not a word character, whitespace,
       "/", "." or "-" (this is what removes emojis and punctuation).
    3. Collapse each run of whitespace into a single space.
    """
    substitutions = (
        (r'\d+(?:\.\d+)*\.\s*', ''),
        (r'[^\w\s/.-]', ''),
        (r'\s+', ' '),
    )

    cleaned = s
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()


# Maximum number of texts sent to the embedding endpoint per request. The
# OpenAI embeddings API caps the number of inputs in a single request (2048 at
# the time of writing), so the corpus is embedded in chunks well below that.
_EMBED_BATCH_SIZE = 512


def _collect_documents(repos, docs_root):
    """Build the embed text and payload for every .md/.mdx file of each repo.

    Args:
        repos: Sequence of repo entries; each must provide a ``'title'`` key.
        docs_root: ``Path`` of the directory holding one ``"{i}. {title}"``
            sub-directory per repo (``i`` is the 1-based position in *repos*).

    Returns:
        A ``(docs, payloads)`` tuple of equally long lists: the cleaned text to
        embed for each file and the matching ``{'file_path', 'resource'}`` dict.
    """
    docs, payloads = [], []
    for i, repo in enumerate(repos, 1):
        repo_dir = docs_root / f"{i}. {repo['title']}"
        # Markdown files first, then MDX, as before.
        file_paths = list(repo_dir.rglob('*.md')) + list(repo_dir.rglob('*.mdx'))

        for file in file_paths:
            # Work on path components instead of substring replaces so the
            # result does not depend on the OS separator and only the real
            # root prefix / file extension are removed.
            relative = file.relative_to(docs_root).with_suffix('')
            *dir_parts, stem = relative.parts
            # The previous string replace also removed nested "docs/"
            # directories; keep skipping them so embed text is unchanged.
            kept = [part for part in dir_parts if part != 'docs']
            kept.append(stem)

            docs.append(clean_filename(' '.join(kept)))
            payloads.append({'file_path': str(file), 'resource': repo['title']})

    return docs, payloads


def _embed_in_batches(embedding_fn, texts, batch_size=_EMBED_BATCH_SIZE):
    """Embed *texts* in chunks of *batch_size* and return all vectors in order."""
    vectors = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Embedding"):
        vectors.extend(embedding_fn.encode_documents(texts[start:start + batch_size]))
    return vectors


def main(args: argparse.Namespace):
    """Embed documentation file paths and store them in a local Milvus DB.

    Opens (or creates) ``milvus.db``, recreates the target collection, derives
    one cleaned text per .md/.mdx file found under the docs directory, embeds
    those texts with the OpenAI embedding function and inserts the vectors
    together with their text, file path and repo title.

    Args:
        args: Parsed CLI options. Reads ``model_name``, ``dimension``,
            ``collection_name``, ``repos_config_path`` (JSON list of repo
            entries with a ``'title'`` key) and ``docs_dir`` (falls back to
            ``"docs"`` when the attribute is absent).
    """
    client = MilvusClient("milvus.db")

    embedding_fn = model.dense.OpenAIEmbeddingFunction(
        model_name=args.model_name,
        api_key=os.environ.get('OPENAI_API_KEY'),
        dimensions=args.dimension,
    )

    create_collection(client, args.collection_name, args.dimension)

    with open(args.repos_config_path, "r", encoding="utf-8") as f:
        repos = json.load(f)

    # Honor --docs_dir; getattr keeps callers that pass a namespace without
    # this attribute working with the historical default.
    docs_root = Path(getattr(args, 'docs_dir', 'docs'))
    docs, payloads = _collect_documents(repos, docs_root)

    if not docs:
        # Nothing to embed: avoid an API call and an insert with empty input.
        print(f"No .md/.mdx files found under {docs_root}; nothing to insert.")
        return

    vectors = _embed_in_batches(embedding_fn, docs)

    data = [
        {"id": i, "vector": vector, "text": text, **payload}
        for i, (vector, text, payload) in enumerate(zip(vectors, docs, payloads))
    ]

    response = client.insert(collection_name=args.collection_name, data=data)
    print(f"Inserted {response['insert_count']} vectors into collection {args.collection_name}")

if __name__ == "__main__":
    # Command-line options as (flag, type, default) triples.
    cli_options = (
        ("--collection_name", str, "hf_docs"),
        ("--model_name", str, "text-embedding-3-large"),
        ("--dimension", int, 3072),
        ("--docs_dir", str, "docs"),
        ("--repos_config_path", str, "repos_config.json"),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()

    # Delete a database file left over from a previous run before main()
    # opens "milvus.db" again.
    db_file = Path('milvus.db')
    if db_file.exists():
        print("Removing existing Milvus database...")
        db_file.unlink()

    main(args)