Spaces:
Running
Running
File size: 2,811 Bytes
f126864 0955e72 f126864 797c083 f126864 0955e72 f126864 0955e72 f126864 0955e72 f126864 0955e72 f126864 0955e72 f126864 0955e72 f126864 0955e72 f126864 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import os
import argparse
import json
import re
from typing import Dict
import dotenv
from pathlib import Path
from tqdm import tqdm
from pymilvus import MilvusClient, model
# Populate os.environ from a local .env file (python-dotenv), so that
# OPENAI_API_KEY, which main() reads via os.environ.get, can be supplied there.
_ = dotenv.load_dotenv()
def create_collection(client: MilvusClient, collection_name: str, dimension: int):
    """Create ``collection_name`` on ``client`` from scratch.

    If a collection with that name already exists it is dropped first (its
    data is discarded), then a new one is created with vectors of size
    ``dimension``.
    """
    target = {"collection_name": collection_name}
    if client.has_collection(**target):
        client.drop_collection(**target)
    client.create_collection(**target, dimension=dimension)
def clean_filename(s):
    """Normalize a path-derived title string for embedding.

    The regex passes are applied in order:

    1. drop hierarchical numbering prefixes such as ``"28. "`` or ``"28.1. "``
       (digits, optional ``.digits`` groups, a final dot, trailing spaces);
    2. drop every character that is not a word character, whitespace,
       ``/``, ``.`` or ``-``. This removes emojis as well as other
       punctuation such as ``,`` or ``!``;
    3. collapse whitespace runs to a single space.

    The result is stripped of leading and trailing whitespace.
    """
    passes = (
        (r'\d+(?:\.\d+)*\.\s*', ''),  # hierarchical numbering
        (r'[^\w\s/.-]', ''),          # emojis and other symbols/punctuation
        (r'\s+', ' '),                # whitespace runs
    )
    text = s
    for pattern, replacement in passes:
        text = re.sub(pattern, replacement, text)
    return text.strip()
def _doc_text(file: Path, docs_root: Path) -> str:
    """Build the text embedded for one docs file from its path.

    The path is taken relative to ``docs_root``, its file suffix is removed,
    and its components are joined with spaces before ``clean_filename`` is
    applied. For example, ``docs/1. Title/sub/page.md`` with root ``docs``
    becomes ``clean_filename("1. Title sub page")``. Working on path parts
    (instead of substring replacement on ``str(file)``) leaves directory
    names that contain "docs/" or ".md" intact and is separator-agnostic.
    """
    relative = file.relative_to(docs_root).with_suffix('')
    return clean_filename(' '.join(relative.parts))


def _collect_documents(repos, docs_root: Path):
    """Gather embed texts and Milvus payloads for every Markdown file.

    Repos are numbered from 1 in config order, and each repo's files are
    looked up under ``<docs_root>/<n>. <title>``. Within a repo, all ``*.md``
    files come first, then all ``*.mdx`` files.

    Returns:
        Two parallel lists: the texts to embed, and per-file payload dicts
        with ``'file_path'`` and ``'resource'`` (the repo title).
    """
    docs, payloads = [], []
    for number, repo in enumerate(repos, 1):
        repo_dir = docs_root / f"{number}. {repo['title']}"
        for pattern in ('*.md', '*.mdx'):
            for file in repo_dir.rglob(pattern):
                docs.append(_doc_text(file, docs_root))
                payloads.append({'file_path': str(file), 'resource': repo['title']})
    return docs, payloads


def _encode_in_batches(embedding_fn, texts, batch_size=1000):
    """Embed ``texts`` in chunks of at most ``batch_size`` and return all vectors in order.

    Chunking keeps each request below the OpenAI embeddings API limit of
    2048 inputs per call, so large docs trees do not fail in one request.
    """
    vectors = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Embedding"):
        vectors.extend(embedding_fn.encode_documents(texts[start:start + batch_size]))
    return vectors


def main(args: argparse.Namespace):
    """Index docs file paths into a freshly (re)created Milvus collection.

    For each repo in the JSON list at ``args.repos_config_path`` (entries
    need a ``'title'`` key), every ``*.md``/``*.mdx`` file under
    ``<docs_dir>/<n>. <title>`` is processed as follows:

    - its path-derived text is embedded with the OpenAI model
      ``args.model_name`` at ``args.dimension`` dimensions;
    - the result is inserted into ``args.collection_name`` in the database
      file ``milvus.db``, along with the file path and repo title.

    An existing collection with that name is dropped first.

    Args:
        args: Parsed CLI arguments. Reads ``model_name``, ``dimension``,
            ``collection_name``, ``repos_config_path`` and ``docs_dir``
            (``docs_dir`` defaults to ``"docs"`` if absent).
    """
    client = MilvusClient("milvus.db")
    embedding_fn = model.dense.OpenAIEmbeddingFunction(
        model_name=args.model_name,
        api_key=os.environ.get('OPENAI_API_KEY'),
        dimensions=args.dimension,
    )
    create_collection(client, args.collection_name, args.dimension)

    # Repo titles may contain emojis; JSON is UTF-8 regardless of platform locale.
    with open(args.repos_config_path, "r", encoding="utf-8") as f:
        repos = json.load(f)

    docs_root = Path(getattr(args, "docs_dir", "docs"))
    docs, payloads = _collect_documents(repos, docs_root)
    if not docs:
        # Avoid an embeddings request with an empty input list.
        print(f"No .md/.mdx files found under {docs_root}; nothing to insert.")
        return

    vectors = _encode_in_batches(embedding_fn, docs)
    data = [
        {"id": idx, "vector": vector, "text": text, **payload}
        for idx, (vector, text, payload) in enumerate(zip(vectors, docs, payloads))
    ]
    response = client.insert(collection_name=args.collection_name, data=data)
    print(f"Inserted {response['insert_count']} vectors into collection {args.collection_name}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Embed documentation file paths into a local Milvus collection."
    )
    parser.add_argument("--collection_name", type=str, default="hf_docs",
                        help="Milvus collection to (re)create and fill.")
    parser.add_argument("--model_name", type=str, default="text-embedding-3-large",
                        help="OpenAI embedding model name.")
    parser.add_argument("--dimension", type=int, default=3072,
                        help="Embedding size; also used as the collection vector dimension.")
    parser.add_argument("--docs_dir", type=str, default="docs",
                        help="Root directory of the docs tree.")
    parser.add_argument("--repos_config_path", type=str, default="repos_config.json",
                        help="JSON file listing the repos (each entry has a 'title').")
    args = parser.parse_args()

    # Start every run from an empty local database file.
    db_path = Path('milvus.db')
    if db_path.exists():
        print("Removing existing Milvus database...")
        db_path.unlink()

    main(args)