Geraldine commited on
Commit
5bbaba7
·
verified ·
1 Parent(s): e6c6c2d

Update lancedb_client.py

Browse files
Files changed (1) hide show
  1. lancedb_client.py +129 -100
lancedb_client.py CHANGED
@@ -1,100 +1,129 @@
1
- import numpy as np
2
- import lancedb
3
- import pyarrow as pa
4
- import logging
5
- from dotenv import load_dotenv
6
- import os
7
- import ast
8
-
9
- # Load env vars
10
- load_dotenv(os.path.join(os.getcwd(), ".env"),override = True)
11
- metadata_keys_raw = os.getenv("_DEFAULT_PARSE_METADATA", "").split(",")
12
- metadata_keys = [key.replace(" ", "").replace(")", "").strip("'") for key in metadata_keys_raw]
13
-
14
- # Setup logger
15
- logging.basicConfig(level=logging.INFO)
16
- logger = logging.getLogger(__name__)
17
-
18
-
19
- class LanceDBManager:
20
-
21
- def __init__(self, db_uri="lancedb", embedding_dim=1536):
22
- self.db = lancedb.connect(db_uri)
23
- self.embedding_dim = embedding_dim
24
- logger.info(f"Connected to LanceDB at {db_uri}")
25
-
26
- def _build_schema(self):
27
- """Build LanceDB schema with dynamic metadata fields and embedding vector."""
28
- fields = [
29
- pa.field("id", pa.int64()),
30
- pa.field("item_id", pa.string()),
31
- pa.field("images_urls", pa.string()),
32
- pa.field("text", pa.string()),
33
- pa.field("Cluster", pa.string()),
34
- pa.field("Topic", pa.string()),
35
- pa.field("embeddings", pa.list_(pa.float32(), self.embedding_dim)),
36
- pa.field("umap_embeddings", pa.list_(pa.float32(), 2)),
37
- ]
38
-
39
- # Add fields from metadata
40
- for key in metadata_keys:
41
- sanitized_key = key.split(":")[1].strip().capitalize() # remove the vocabulary prefix in key label and capitalize
42
- fields.append(pa.field(sanitized_key, pa.string()))
43
-
44
- return pa.schema(fields)
45
-
46
- def create_table(self, table_name):
47
- """Create table using dynamic schema."""
48
- try:
49
- schema = self._build_schema()
50
- table = self.db.create_table(table_name, schema=schema)
51
- logger.info(f"Created LanceDB table '{table_name}'")
52
- return table
53
- except Exception as e:
54
- logger.error(f"Failed to create table '{table_name}': {e}")
55
- raise
56
-
57
- def retrieve_table(self, table_name):
58
- try:
59
- table = self.db.open_table(table_name)
60
- logger.info(f"Opened existing LanceDB table '{table_name}'")
61
- return table
62
- except Exception as e:
63
- logger.error(f"Failed to open table '{table_name}': {e}")
64
- raise
65
-
66
- def initialize_table(self, table_name):
67
- try:
68
- return self.retrieve_table(table_name)
69
- except Exception:
70
- logger.info(f"Table '{table_name}' not found. Creating new.")
71
- return self.create_table(table_name)
72
-
73
- def add_entry(self, table_name, items):
74
- table = self.initialize_table(table_name)
75
- table.add(items)
76
- logger.info(f"Added items to table '{table_name}'")
77
-
78
- def list_tables(self):
79
- """List all existing tables in the LanceDB instance."""
80
- try:
81
- tables = self.db.table_names()
82
- logger.info("Retrieved list of tables.")
83
- return tables
84
- except Exception as e:
85
- logger.error(f"Failed to list tables: {e}")
86
- raise
87
-
88
- def get_content_table(self, table_name):
89
- table = self.initialize_table(table_name)
90
- return table.to_pandas()
91
-
92
- def drop_table(self, table_name):
93
- """remove an existing table by name."""
94
- try:
95
- table = self.db.drop_table(table_name)
96
- logger.info(f"Remove existing LanceDB table '{table_name}' successfully.")
97
- return table
98
- except Exception as e:
99
- logger.error(f"Failed to remove existing table '{table_name}': {e}")
100
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import lancedb
3
+ import pyarrow as pa
4
+ import logging
5
+ from dotenv import load_dotenv
6
+ import os
7
+ import ast
8
+
9
+ # Load env vars
10
+ load_dotenv(os.path.join(os.getcwd(), ".env"),override = True)
11
+ metadata_keys_raw = os.getenv("_DEFAULT_PARSE_METADATA", "").split(",")
12
+ metadata_keys = [key.replace(" ", "").replace(")", "").strip("'") for key in metadata_keys_raw]
13
+
14
+ # Setup logger
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+ class LanceDBManager:
19
+
20
+ def __init__(self, db_uri="lancedb", embedding_dim=768):
21
+ self.db = lancedb.connect(db_uri)
22
+ self.embedding_dim = embedding_dim
23
+ logger.info(f"Connected to LanceDB at {db_uri}")
24
+
25
+ def _build_schema(self):
26
+ """Build LanceDB schema with dynamic metadata fields and embedding vector."""
27
+ fields = [
28
+ pa.field("id", pa.int64()),
29
+ pa.field("item_id", pa.string()),
30
+ pa.field("images_urls", pa.string()),
31
+ pa.field("text", pa.string()),
32
+ pa.field("Cluster", pa.string()),
33
+ pa.field("Topic", pa.string()),
34
+ pa.field("embeddings", pa.list_(pa.float32(), self.embedding_dim)),
35
+ pa.field("umap_embeddings", pa.list_(pa.float32(), 2)),
36
+ ]
37
+
38
+ # Add fields from metadata
39
+ for key in metadata_keys:
40
+ sanitized_key = key.split(":")[1].strip().capitalize() # remove the vocabulary prefix in key label and capitalize
41
+ fields.append(pa.field(sanitized_key, pa.string()))
42
+
43
+ return pa.schema(fields)
44
+
45
+ def create_table(self, table_name):
46
+ """Create table using dynamic schema."""
47
+ try:
48
+ schema = self._build_schema()
49
+ table = self.db.create_table(table_name, schema=schema)
50
+ logger.info(f"Created LanceDB table '{table_name}'")
51
+ return table
52
+ except Exception as e:
53
+ logger.error(f"Failed to create table '{table_name}': {e}")
54
+ raise
55
+
56
+ def retrieve_table(self, table_name):
57
+ try:
58
+ table = self.db.open_table(table_name)
59
+ logger.info(f"Opened existing LanceDB table '{table_name}'")
60
+ return table
61
+ except Exception as e:
62
+ logger.error(f"Failed to open table '{table_name}': {e}")
63
+ raise
64
+
65
+ def initialize_table(self, table_name):
66
+ try:
67
+ return self.retrieve_table(table_name)
68
+ except Exception:
69
+ logger.info(f"Table '{table_name}' not found. Creating new.")
70
+ return self.create_table(table_name)
71
+
72
+ def add_entry(self, table_name, items):
73
+ table = self.initialize_table(table_name)
74
+ table.add(items)
75
+ logger.info(f"Added items to table '{table_name}'")
76
+
77
+ def list_tables(self):
78
+ """List all existing tables in the LanceDB instance."""
79
+ try:
80
+ tables = self.db.table_names()
81
+ logger.info("Retrieved list of tables.")
82
+ return tables
83
+ except Exception as e:
84
+ logger.error(f"Failed to list tables: {e}")
85
+ raise
86
+
87
+ def get_content_table(self, table_name):
88
+ table = self.initialize_table(table_name)
89
+ return table.to_pandas()
90
+
91
+ def drop_table(self, table_name):
92
+ """remove an existing table by name."""
93
+ try:
94
+ table = self.db.drop_table(table_name)
95
+ logger.info(f"Remove existing LanceDB table '{table_name}' successfully.")
96
+ return table
97
+ except Exception as e:
98
+ logger.error(f"Failed to remove existing table '{table_name}': {e}")
99
+ raise
100
+
101
+ def semantic_search(self, table_name, query_embed, limit):
102
+ """
103
+ Perform a semantic search using a provided text query or image.
104
+
105
+ Args:
106
+ query_text (str): The text query for the search.
107
+ query_image_path (str): The path to the image for the search.
108
+ limit (int): The maximum number of results to return.
109
+
110
+ Returns:
111
+ str: JSON string of search results.
112
+ """
113
+ table = self.initialize_table(table_name)
114
+ #https://lancedb.github.io/lancedb/notebooks/DisappearingEmbeddingFunction/
115
+
116
+ try:
117
+ # Perform the search in LanceDB
118
+ results = (table
119
+ .search(query_embed,vector_column_name="embeddings")
120
+ .distance_type("cosine")
121
+ .select(["id"])
122
+ .limit(limit)
123
+ .to_pandas()
124
+ #.sort_values(by='_distance', ascending=True)
125
+ .to_json(orient="records")
126
+ )
127
+ return results
128
+ except Exception as e:
129
+ raise