mwitiderrick commited on
Commit
ee8097c
·
verified ·
1 Parent(s): f502519

Upload index_miriad_to_qdrant.py

Browse files
Files changed (1) hide show
  1. index_miriad_to_qdrant.py +73 -0
index_miriad_to_qdrant.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # index_miriad_to_qdrant.py
2
+
3
+ from datasets import load_dataset
4
+ from qdrant_client import QdrantClient, models
5
+ from dotenv import load_dotenv
6
+ import os
7
+
8
+ load_dotenv()
9
+
10
+ # Connect to Qdrant Cloud
11
+ client = QdrantClient(
12
+ url=os.environ.get("QDRANT_CLOUD_URL"),
13
+ api_key=os.environ.get("QDRANT_API_KEY"),
14
+ timeout=60.0,
15
+ prefer_grpc=True
16
+ )
17
+
18
+ # Load MIRIAD dataset (sample for demo)
19
+ ds = load_dataset("miriad/miriad-5.8M", split="train").select(range(100000))
20
+
21
+ dense_documents = [
22
+ models.Document(text=doc, model="BAAI/bge-small-en")
23
+ for doc in ds['passage_text']
24
+ ]
25
+
26
+ colbert_documents = [
27
+ models.Document(text=doc, model="colbert-ir/colbertv2.0")
28
+ for doc in ds['passage_text']
29
+ ]
30
+
31
+ collection_name = "medical_chat_bot"
32
+
33
+ # Create collection
34
+ if not client.collection_exists(collection_name):
35
+ client.recreate_collection(
36
+ collection_name=collection_name,
37
+ vectors_config={
38
+ "dense": models.VectorParams(size=384, distance=models.Distance.COSINE),
39
+ "colbert": models.VectorParams(
40
+ size=128,
41
+ distance=models.Distance.COSINE,
42
+ multivector_config=models.MultiVectorConfig(
43
+ comparator=models.MultiVectorComparator.MAX_SIM
44
+ ),
45
+ hnsw_config=models.HnswConfigDiff(m=0) # reranker: no indexing
46
+ )
47
+ }
48
+ )
49
+
50
+ # Batch upload in chunks
51
+ BATCH_SIZE = 3
52
+ points_batch = []
53
+
54
+ for i in range(len(ds['passage_text'])):
55
+ point = models.PointStruct(
56
+ id=i,
57
+ vector={
58
+ "dense": dense_documents[i],
59
+ "colbert": colbert_documents[i]
60
+ },
61
+ payload={"passage_text": ds['passage_text'][i], "paper_id": ds['paper_id'][i]}
62
+ )
63
+ points_batch.append(point)
64
+
65
+ if len(points_batch) == BATCH_SIZE:
66
+ client.upsert(collection_name=collection_name, points=points_batch)
67
+ print(f"Uploaded batch ending at index {i}")
68
+ points_batch = []
69
+
70
+ # Final flush
71
+ if points_batch:
72
+ client.upsert(collection_name=collection_name, points=points_batch)
73
+ print("Uploaded final batch.")