MilanM committed
Commit 5008fbd · verified · 1 Parent(s): 5293c6e

Upload visualized_app_v2_2.py

Files changed (1):
  1. visualized_app_v2_2.py +2369 -0
visualized_app_v2_2.py ADDED
@@ -0,0 +1,2369 @@
# /// script
# [tool.marimo.display]
# custom_css = ["./custom_header_font.css"]
# ///

import marimo

__generated_with = "0.13.14"
app = marimo.App(width="full")

with app.setup:
    # Initialization code that runs before all other cells
    import marimo as mo
    from typing import Dict, Optional, List, Union, Any
    from ibm_watsonx_ai import APIClient, Credentials
    from pathlib import Path
    import pandas as pd
    import mimetypes
    import requests
    import zipfile
    import tempfile
    import certifi
    import base64
    import polars
    import nltk
    import time
    import json
    import ast
    import os
    import io
    import re

    from dotenv import load_dotenv
    load_dotenv()

    def get_iam_token(api_key):
        return requests.post(
            'https://iam.cloud.ibm.com/identity/token',
            headers={'Content-Type': 'application/x-www-form-urlencoded'},
            data={'grant_type': 'urn:ibm:params:oauth:grant-type:apikey', 'apikey': api_key},
            verify=certifi.where()
        ).json()['access_token']

    def setup_task_credentials(client):
        # Get existing task credentials
        existing_credentials = client.task_credentials.get_details()

        # Delete existing credentials if any
        if "resources" in existing_credentials and existing_credentials["resources"]:
            for cred in existing_credentials["resources"]:
                cred_id = client.task_credentials.get_id(cred)
                client.task_credentials.delete(cred_id)

        # Store new credentials
        return client.task_credentials.store()

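# Illustrative usage sketch (not in the original commit): get_iam_token trades
# an IBM Cloud API key for a short-lived IAM bearer token, and
# setup_task_credentials wipes and re-stores the client's task credentials.
# Assuming WATSONX_APIKEY is set, a manual call would look like:
#
#     token = get_iam_token(os.environ["WATSONX_APIKEY"])
#     headers = {"Authorization": f"Bearer {token}"}
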
@app.cell
def client_variables(client_instantiation_form):
    client_setup = client_instantiation_form.value or None

    ### Extract Credential Variables:
    if client_setup:
        wx_url = client_setup["wx_region"] if client_setup["wx_region"] else "EU"
        wx_api_key = client_setup["wx_api_key"].strip() if client_setup["wx_api_key"] else None
        os.environ["WATSONX_APIKEY"] = wx_api_key or ""

        project_id = client_setup["project_id"].strip() if client_setup["project_id"] else None
        space_id = client_setup["space_id"].strip() if client_setup["space_id"] else None
    else:
        os.environ["WATSONX_APIKEY"] = ""
        project_id = space_id = wx_api_key = wx_url = None
    return client_setup, project_id, space_id, wx_api_key, wx_url


@app.cell
def _():
    from baked_in_credentials.creds import credentials
    from base_variables import wx_regions, wx_platform_url
    from helper_functions.helper_functions import wrap_with_spaces, get_key_by_value, markdown_spacing, get_cell_values, create_parameter_table, get_cred_value
    return (
        create_parameter_table,
        credentials,
        get_cell_values,
        get_cred_value,
        get_key_by_value,
        wrap_with_spaces,
        wx_regions,
    )


@app.cell
def client_instantiation(
    client_setup,
    project_id,
    space_id,
    wx_api_key,
    wx_url,
):
    ### Instantiate the watsonx.ai client
    if client_setup:
        try:
            wx_credentials = Credentials(url=wx_url, api_key=wx_api_key)
            project_client = (
                APIClient(credentials=wx_credentials, project_id=project_id)
                if project_id
                else None
            )
            deployment_client = (
                APIClient(credentials=wx_credentials, space_id=space_id)
                if space_id
                else None
            )
            instantiation_success = True
            instantiation_error = None
        except Exception as e:
            instantiation_success = False
            instantiation_error = str(e)
            wx_credentials = project_client = deployment_client = None
    else:
        wx_credentials = project_client = deployment_client = None
        instantiation_success = None
        instantiation_error = None

    return (
        deployment_client,
        instantiation_error,
        instantiation_success,
        project_client,
    )


@app.cell
def _():
    mo.md(
        r"""
        # watsonx.ai Embedding Visualizer - Marimo Notebook

        #### This marimo notebook can be used to develop a more intuitive understanding of how vector embeddings work by creating a 3D visualization of vector embeddings based on chunked PDF document pages.

        #### It can also serve as a useful tool for identifying gaps in model choice, chunking strategy, or the contents used in building collections, by showing how far you are from what you want.
        <br>

        /// admonition
        Created by ***Milan Mrdenovic*** [[email protected]] for IBM Ecosystem Client Engineering, NCEE - ***version 5.3*** - *20.04.2025*
        ///

        > Licensed under Apache 2.0; users hold full accountability for any use or modification of the code.
        > <br>This asset is part of a set meant to support IBMers, IBM Partners, and clients in developing an understanding of how to better utilize various watsonx features and generative AI as a subject matter.

        <br>
        """
    )
    return

@app.cell
def _():
    mo.md("""### Part 1 - Client Setup, File Preparation and Chunking""")
    return


@app.cell
def accordion_client_setup(
    client_selector,
    client_stack,
    current_mode,
    switch_file_loader_type,
):
    ui_accordion_part_1_1 = mo.accordion(
        {
            "Instantiate Client": mo.vstack([
                client_stack,
                mo.hstack([client_selector, switch_file_loader_type], justify="space-around", gap=2),
                current_mode
            ], align="center"),
        }
    )

    ui_accordion_part_1_1
    return


@app.cell
def _(switch_file_loader_type):
    if switch_file_loader_type.value:
        current_mode = mo.md("**Current Mode:** Using pre-made embedding/text files.")
    else:
        current_mode = mo.md("**Current Mode:** Using loaded PDF files and chunking.")
    return (current_mode,)


@app.cell
def accordion_file_upload(select_stack):
    ui_accordion_part_1_2 = mo.accordion(
        {
            "Select Model & Upload Files": select_stack
        }
    )

    ui_accordion_part_1_2
    return


@app.cell
def loaded_texts(
    create_temp_files_from_uploads,
    file_loader,
    pdf_reader,
    run_upload_button,
    set_text_state,
    switch_file_loader_type,
):
    if file_loader.value is not None and run_upload_button.value:
        filepaths = create_temp_files_from_uploads(file_loader.value)
        if switch_file_loader_type.value:
            loaded_texts = load_json_csv_data_with_progress(filepaths, file_loader.value, show_progress=True)
        else:
            loaded_texts = load_pdf_data_with_progress(pdf_reader, filepaths, file_loader.value, show_progress=True)

        set_text_state(loaded_texts)
    else:
        filepaths = None
        loaded_texts = None
    return (loaded_texts,)


@app.cell
def _(chunker_setup, file_column_setup, switch_file_loader_type):
    if switch_file_loader_type.value:
        ui_accordion_part_1_3 = mo.accordion(
            {
                "Column Selector": file_column_setup
            }
        )
    else:
        ui_accordion_part_1_3 = mo.accordion(
            {
                "Chunker Setup": chunker_setup
            }
        )

    ui_accordion_part_1_3
    return


@app.cell
def accordion_chunker_setup():
    # ui_accordion_part_1_3 = mo.accordion(
    #     {
    #         "Chunker Setup": chunker_setup
    #     }
    # )

    # ui_accordion_part_1_3
    return


@app.cell
def chunk_documents_to_nodes(
    get_text_state,
    sentence_splitter,
    sentence_splitter_config,
    set_chunk_state,
):
    if sentence_splitter_config.value and sentence_splitter and get_text_state() is not None:
        chunked_texts = chunk_documents(get_text_state(), sentence_splitter, show_progress=True)
        set_chunk_state(chunked_texts)
    else:
        chunked_texts = None
    return (chunked_texts,)


@app.cell
def _():
    mo.md(r"""### Part 2 - Query Setup and Visualization""")
    return


@app.cell
def accordion_chunk_range():
    # ui_accordion_part_2_1 = mo.accordion(
    #     {
    #         "Chunk Range Selection": chart_range_selection
    #     }
    # )
    # ui_accordion_part_2_1
    return


@app.cell
def _(chart_range_selection, switch_file_loader_type):
    ui_accordion_part_2_1 = mo.accordion(
        {
            "Chunk Range Selection": chart_range_selection
        }
    )
    ui_accordion_part_2_1 if not switch_file_loader_type.value else None
    return


@app.cell
def chunk_embedding(
    chunks_to_process,
    embedding,
    sentence_splitter_config,
    set_embedding_state,
):
    if sentence_splitter_config.value is not None and chunks_to_process is not None:
        with mo.status.spinner(title="Embedding Documents...", remove_on_exit=True) as _spinner:
            output_embeddings = embedding.embed_documents(chunks_to_process)
            _spinner.update("Almost Done")
            time.sleep(1.5)
            set_embedding_state(output_embeddings)
            _spinner.update("Documents Embedded")
    else:
        output_embeddings = None
    return

@app.cell
def preview_chunks(chunks_dict):
    if chunks_dict is not None:
        stats = create_stats(chunks_dict,
                             bordered=True,
                             object_names=['text', 'text'],
                             group_by_row=True,
                             items_per_row=5,
                             gap=1,
                             label="Chunk")
        ui_chunk_viewer = mo.accordion(
            {
                "View Chunks": stats,
            }
        )
    else:
        ui_chunk_viewer = None

    ui_chunk_viewer
    return


@app.cell
def accordion_query_view(chart_visualization, query_stack):
    ui_accordion_part_2_2 = mo.accordion(
        {
            "Query": mo.vstack([query_stack, mo.hstack([chart_visualization])], align="center", gap=3)
        }
    )
    ui_accordion_part_2_2
    return


@app.cell
def chunker_setup(sentence_splitter_config):
    chunker_setup = mo.hstack([sentence_splitter_config], justify="space-around", align="center", widths=[0.55])
    return (chunker_setup,)


@app.cell
def file_and_model_select(
    file_loader,
    get_embedding_model_list,
    run_upload_button,
):
    select_stack = mo.hstack(
        [
            get_embedding_model_list(),
            mo.vstack(
                [
                    mo.md("Drag & drop or double-click to select PDFs, then press **Load Files**"),
                    file_loader,
                    run_upload_button,
                ],
                align="center",
            ),
        ],
        justify="space-around",
        align="center",
        widths=[0.3, 0.3],
    )
    return (select_stack,)

@app.cell
def client_instantiation_form(credentials, get_cred_value, wx_regions):
    baked_in_creds = credentials
    # Create a form with multiple elements
    client_instantiation_form = (
        mo.md('''
    ### **watsonx.ai credentials:**

    {wx_region}

    {wx_api_key}

    {project_id}

    {space_id}

    > You can add either a project_id, a space_id, or both; **only one is required**.
    > If you provide both, you can switch the active one in the dropdown.
    ''')
        .batch(
            wx_region=mo.ui.dropdown(
                wx_regions,
                label="Select your watsonx.ai region:",
                value=get_cred_value('region', creds_var_name='baked_in_creds') or "EU",
                searchable=True
            ),
            wx_api_key=mo.ui.text(
                placeholder="Add your IBM Cloud api-key...",
                label="IBM Cloud Api-key:",
                kind="password",
                value=get_cred_value('api_key', creds_var_name='baked_in_creds')
            ),
            project_id=mo.ui.text(
                placeholder="Add your watsonx.ai project_id...",
                label="Project_ID:",
                kind="text",
                value=get_cred_value('project_id', creds_var_name='baked_in_creds')
            ),
            space_id=mo.ui.text(
                placeholder="Add your watsonx.ai space_id...",
                label="Space_ID:",
                kind="text",
                value=get_cred_value('space_id', creds_var_name='baked_in_creds')
            ),
        )
        .form(show_clear_button=True, bordered=False)
    )
    return (client_instantiation_form,)


@app.cell
def instantiation_status(
    client_callout_kind,
    client_instantiation_form,
    client_status,
):
    client_callout = mo.callout(client_status, kind=client_callout_kind)
    client_stack = mo.hstack([client_instantiation_form, client_callout], align="center", justify="space-around", gap=10)
    return (client_stack,)

@app.cell
def _(
    client_key,
    client_options,
    client_selector,
    client_setup,
    get_key_by_value,
    instantiation_error,
    instantiation_success,
    wrap_with_spaces,
):
    active_client_name = get_key_by_value(client_options, client_key)

    if client_setup:
        if instantiation_success:
            client_status = mo.md(
                f"### ✅ Client Instantiation Successful ✅\n\n"
                f"{client_selector}\n\n"
                f"**Active Client:**{wrap_with_spaces(active_client_name, prefix_spaces=5)}"
            )
            client_callout_kind = "success"
        else:
            client_status = mo.md(
                f"### ❌ Client Instantiation Failed\n**Error:** {instantiation_error}\n\nCheck your region selection and credentials."
            )
            client_callout_kind = "danger"
    else:
        client_status = mo.md(
            f"### Client Instantiation Status will turn green when ready\n\n"
            f"{client_selector}\n\n"
            f"**Active Client:**{wrap_with_spaces(active_client_name, prefix_spaces=5)}"
        )
        client_callout_kind = "neutral"

    return client_callout_kind, client_status


@app.cell
def client_selector(deployment_client, project_client):
    if project_client is not None and deployment_client is not None:
        client_options = {"Project Client": project_client, "Deployment Client": deployment_client}
    elif project_client is not None:
        client_options = {"Project Client": project_client}
    elif deployment_client is not None:
        client_options = {"Deployment Client": deployment_client}
    else:
        client_options = {"No Client": "Instantiate a Client"}

    default_client = next(iter(client_options))
    client_selector = mo.ui.dropdown(client_options, value=default_client, label="**Switch your active client:**")
    return client_options, client_selector


@app.cell
def active_client(client_selector):
    client_key = client_selector.value
    if client_key == "Instantiate a Client":
        client = None
    else:
        client = client_key
    return client, client_key

@app.cell
def emb_model_selection(client, set_embedding_model_list):
    if client is not None:
        model_specs = client.foundation_models.get_embeddings_model_specs()
        # model_specs = client.foundation_models.get_model_specs()
        resources = model_specs["resources"]
        # Define embedding models reference data
        embedding_models = {
            "ibm/granite-embedding-107m-multilingual": {"max_tokens": 512, "embedding_dimensions": 384},
            "ibm/granite-embedding-278m-multilingual": {"max_tokens": 512, "embedding_dimensions": 768},
            "ibm/slate-125m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 768},
            "ibm/slate-125m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 768},
            "ibm/slate-30m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 384},
            "ibm/slate-30m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 384},
            "sentence-transformers/all-minilm-l6-v2": {"max_tokens": 128, "embedding_dimensions": 384},
            "sentence-transformers/all-minilm-l12-v2": {"max_tokens": 128, "embedding_dimensions": 384},
            "intfloat/multilingual-e5-large": {"max_tokens": 512, "embedding_dimensions": 1024}
        }

        # Get model IDs from resources
        model_id_list = []
        for resource in resources:
            model_id_list.append(resource["model_id"])

        # Create enhanced model data for the table
        embedding_model_data = []
        for model_id in model_id_list:
            model_entry = {"model_id": model_id}

            # Add properties if the model exists in our reference, otherwise use 0
            if model_id in embedding_models:
                model_entry["max_tokens"] = embedding_models[model_id]["max_tokens"]
                model_entry["embedding_dimensions"] = embedding_models[model_id]["embedding_dimensions"]
            else:
                model_entry["max_tokens"] = 0
                model_entry["embedding_dimensions"] = 0

            embedding_model_data.append(model_entry)

        embedding_model_selection = mo.ui.table(
            embedding_model_data,
            selection="single",  # Only allow selecting one row
            label="Select an embedding model to use.",
            page_size=30,
            initial_selection=[1]
        )
        set_embedding_model_list(embedding_model_selection)
    else:
        default_model_data = [{
            "model_id": "ibm/granite-embedding-107m-multilingual",
            "max_tokens": 512,
            "embedding_dimensions": 384
        }]

        set_embedding_model_list(create_emb_model_selection_table(default_model_data, initial_selection=0, selection_type="single", label="Select a model to use."))
    return


@app.function
def create_emb_model_selection_table(model_data, initial_selection=0, selection_type="single", label="Select a model to use."):
    embedding_model_selection = mo.ui.table(
        model_data,
        selection=selection_type,  # "single" allows selecting only one row
        label=label,
        page_size=30,
        initial_selection=[initial_selection]
    )
    return embedding_model_selection


@app.cell
def embedding_model():
    get_embedding_model_list, set_embedding_model_list = mo.state(None)
    return get_embedding_model_list, set_embedding_model_list

@app.cell
def emb_model_parameters(emb_model_max_tk, embedding_model):
    from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
    if embedding_model is not None:
        embed_params = {
            EmbedParams.TRUNCATE_INPUT_TOKENS: emb_model_max_tk,
            EmbedParams.RETURN_OPTIONS: {
                'input_text': True
            }
        }
    else:
        embed_params = {
            EmbedParams.TRUNCATE_INPUT_TOKENS: 128,
            EmbedParams.RETURN_OPTIONS: {
                'input_text': True
            }
        }
    return (embed_params,)


@app.cell
def emb_model_state(get_embedding_model_list):
    embedding_model = get_embedding_model_list()
    return (embedding_model,)


@app.cell
def emb_model_setup(embedding_model):
    if embedding_model is not None:
        emb_model = embedding_model.value[0]['model_id']
        emb_model_max_tk = embedding_model.value[0]['max_tokens']
        emb_model_emb_dim = embedding_model.value[0]['embedding_dimensions']
    else:
        emb_model = None
        emb_model_max_tk = None
        emb_model_emb_dim = None
    return emb_model, emb_model_emb_dim, emb_model_max_tk


@app.cell
def emb_model_instantiation(client, emb_model, embed_params):
    from ibm_watsonx_ai.foundation_models import Embeddings
    if client is not None:
        embedding = Embeddings(
            model_id=emb_model,
            api_client=client,
            params=embed_params,
            batch_size=1000,
            concurrency_limit=10
        )
    else:
        embedding = None
    return (embedding,)

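# Illustrative note (not in the original commit): once instantiated, the
# Embeddings client is used further down via embedding.embed_documents(texts),
# which returns one embedding vector (a list of floats) per input string, e.g.:
#
#     vectors = embedding.embed_documents(["first chunk", "second chunk"])
#     assert len(vectors) == 2
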
@app.cell
def _():
    get_embedding_state, set_embedding_state = mo.state(None)
    return get_embedding_state, set_embedding_state


@app.cell
def _():
    get_query_state, set_query_state = mo.state(None)
    return get_query_state, set_query_state


@app.cell
def file_loader_input(switch_file_loader_type):
    if switch_file_loader_type.value:
        file_loader = mo.ui.file(
            kind="area",
            filetypes=[".json", ".csv"],
            label=" Load pre-made embedding/text pair files (.json, .csv) ",
            multiple=False
        )
    else:
        file_loader = mo.ui.file(
            kind="area",
            filetypes=[".pdf"],
            label=" Load .pdf files ",
            multiple=True
        )
    return (file_loader,)


@app.cell
def file_loader_run(file_loader):
    if file_loader.value:
        run_upload_button = mo.ui.run_button(label="Load Files")
    else:
        run_upload_button = mo.ui.run_button(disabled=True, label="Load Files")
    return (run_upload_button,)


@app.cell
def helper_function_tempfiles():
    def create_temp_files_from_uploads(upload_results) -> List[str]:
        """
        Creates temporary files from a tuple of FileUploadResults objects and returns their paths.

        Args:
            upload_results: Object containing a value attribute that is a tuple of FileUploadResults

        Returns:
            List of temporary file paths
        """
        temp_file_paths = []

        # Get the number of items in the tuple
        num_items = len(upload_results)

        # Process each item by index
        for i in range(num_items):
            result = upload_results[i]  # Get item by index

            # Create a temporary file with the original filename
            temp_dir = tempfile.gettempdir()
            file_name = result.name
            temp_path = os.path.join(temp_dir, file_name)
            # Write the contents to the temp file
            with open(temp_path, 'wb') as temp_file:
                temp_file.write(result.contents)
            # Add the path to our list
            temp_file_paths.append(temp_path)

        return temp_file_paths

    def cleanup_temp_files(temp_file_paths: List[str]) -> None:
        """Delete temporary files after use."""
        for path in temp_file_paths:
            if os.path.exists(path):
                os.unlink(path)
    return (create_temp_files_from_uploads,)

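# Illustrative note (not in the original commit): cleanup_temp_files is defined
# but not returned from the cell above, so only create_temp_files_from_uploads
# is visible to other cells. A typical round trip would look like:
#
#     paths = create_temp_files_from_uploads(file_loader.value)
#     ...  # read the files
#     cleanup_temp_files(paths)  # would require returning it from the cell
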
@app.function
def load_pdf_data_with_progress(pdf_reader, filepaths, file_loader_value, show_progress=True):
    """
    Loads PDF data for each file path and organizes results by original filename.

    Args:
        pdf_reader: The PyMuPDFReader instance
        filepaths: List of temporary file paths
        file_loader_value: The original upload results value containing file information
        show_progress: Whether to show a progress bar during loading (default: True)

    Returns:
        Dictionary mapping original filenames to their loaded text content
    """
    results = {}

    # Process files with or without progress bar
    if show_progress:
        import marimo as mo
        # Use progress bar with the length of filepaths as total
        with mo.status.progress_bar(
            total=len(filepaths),
            title="Loading PDFs",
            subtitle="Processing documents...",
            completion_title="PDF Loading Complete",
            completion_subtitle=f"{len(filepaths)} documents processed",
            remove_on_exit=True
        ) as bar:
            # Process each file path
            for i, file_path in enumerate(filepaths):
                original_file_name = file_loader_value[i].name
                bar.update(subtitle=f"Processing {original_file_name}...")
                loaded_text = pdf_reader.load_data(file_path=file_path, metadata=True)

                # Store the result with the original filename as the key
                results[original_file_name] = loaded_text
                # Update progress bar
                bar.update(increment=1)
    else:
        # Original logic without progress bar
        for i, file_path in enumerate(filepaths):
            original_file_name = file_loader_value[i].name
            loaded_text = pdf_reader.load_data(file_path=file_path, metadata=True)
            results[original_file_name] = loaded_text

    return results


@app.cell
def file_readers():
    from llama_index.readers.file import PyMuPDFReader
    from llama_index.readers.file import FlatReader
    from llama_index.core.node_parser import SentenceSplitter

    ### File Readers
    pdf_reader = PyMuPDFReader()
    # flat_file_reader = FlatReader()
    return SentenceSplitter, pdf_reader

@app.cell
def sentence_splitter_setup():
    ### Chunker Setup
    sentence_splitter_config = (
        mo.md('''
    ### **Chunking Setup:**

    > Unless you want to do some advanced sentence splitting, it's best to stick to adjusting only the chunk size and overlap. Changing the other settings might produce unexpected results.

    The separator value is set to **" "** by default, while the paragraph separator is **"\\n\\n\\n"**.

    {chunk_size}

    {chunk_overlap}

    {separator} {paragraph_separator}

    {secondary_chunking_regex} {include_metadata}

    ''')
        .batch(
            chunk_size=mo.ui.slider(start=100, stop=5000, step=1, label="**Chunk Size:**", value=275, show_value=True, full_width=True),
            chunk_overlap=mo.ui.slider(start=0, stop=1000, step=1, label="**Chunk Overlap** *(Must always be smaller than Chunk Size)* **:**", value=0, show_value=True, full_width=True),
            separator=mo.ui.text(placeholder="Define a separator", label="**Separator:**", kind="text", value=" "),
            paragraph_separator=mo.ui.text(placeholder="Define a paragraph separator",
                                           label="**Paragraph Separator:**", kind="text",
                                           value="\n\n\n"),
            secondary_chunking_regex=mo.ui.text(placeholder="Define a secondary chunking regex",
                                                label="**Chunking Regex:**", kind="text",
                                                value="[^,.;?!]+[,.;?!]?"),
            include_metadata=mo.ui.checkbox(value=True, label="**Include Metadata**")
        )
        .form(show_clear_button=True, bordered=False, submit_button_label="Chunk Documents")
    )
    return (sentence_splitter_config,)


@app.cell
def sentence_splitter_instantiation(
    SentenceSplitter,
    sentence_splitter_config,
):
    ### Chunker/Sentence Splitter
    def simple_whitespace_tokenizer(text):
        return text.split()

    if sentence_splitter_config.value is not None:
        sentence_splitter_config_values = sentence_splitter_config.value
        # Cap the overlap at 30% of the chunk size to keep the splitter valid
        validated_chunk_overlap = min(sentence_splitter_config_values.get("chunk_overlap"),
                                      int(sentence_splitter_config_values.get("chunk_size") * 0.3))

        sentence_splitter = SentenceSplitter(
            chunk_size=sentence_splitter_config_values.get("chunk_size"),
            chunk_overlap=validated_chunk_overlap,
            separator=sentence_splitter_config_values.get("separator"),
            paragraph_separator=sentence_splitter_config_values.get("paragraph_separator"),
            secondary_chunking_regex=sentence_splitter_config_values.get("secondary_chunking_regex"),
            include_metadata=sentence_splitter_config_values.get("include_metadata"),
            tokenizer=simple_whitespace_tokenizer
        )
    else:
        sentence_splitter = SentenceSplitter(
            chunk_size=2048,
            chunk_overlap=204,
            separator=" ",
            paragraph_separator="\n\n\n",
            secondary_chunking_regex="[^,.;?!]+[,.;?!]?",
            include_metadata=True,
            tokenizer=simple_whitespace_tokenizer
        )
    return (sentence_splitter,)

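# Illustrative sketch (not in the original commit): with the whitespace
# tokenizer above, chunk_size effectively counts words rather than model
# tokens. Assuming the defaults (chunk_size=275, chunk_overlap=0), splitting
# works roughly like:
#
#     from llama_index.core.schema import Document
#     nodes = sentence_splitter.get_nodes_from_documents([Document(text="...")])
#     # each node.text holds at most ~275 whitespace-delimited words
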
@app.cell
def text_state():
    get_text_state, set_text_state = mo.state(None)
    return get_text_state, set_text_state


@app.cell
def chunk_state():
    get_chunk_state, set_chunk_state = mo.state(None)
    return get_chunk_state, set_chunk_state


@app.function
def chunk_documents(loaded_texts, sentence_splitter, show_progress=True):
    """
    Process each document in the loaded_texts dictionary using the sentence_splitter,
    with an optional marimo progress bar tracking progress at the document level.

    Args:
        loaded_texts (dict): Dictionary containing lists of Document objects
        sentence_splitter: The sentence splitter object with a get_nodes_from_documents method
        show_progress (bool): Whether to show a progress bar during processing

    Returns:
        dict: Dictionary with the same structure but containing chunked texts
    """
    chunked_texts_dict = {}

    # Get the total number of documents across all keys
    total_docs = sum(len(docs) for docs in loaded_texts.values())
    processed_docs = 0

    # Process with or without progress bar
    if show_progress:
        import marimo as mo
        # Use progress bar with the total number of documents as total
        with mo.status.progress_bar(
            total=total_docs,
            title="Processing Documents",
            subtitle="Chunking documents...",
            completion_title="Processing Complete",
            completion_subtitle=f"{total_docs} documents processed",
            remove_on_exit=True
        ) as bar:
            # Process each key-value pair in the loaded_texts dictionary
            for key, documents in loaded_texts.items():
                # Update progress bar subtitle to show the current key
                doc_count = len(documents)
                bar.update(subtitle=f"Chunking {key}... ({doc_count} documents)")

                # Apply the sentence splitter to each list of documents
                chunked_texts = sentence_splitter.get_nodes_from_documents(
                    documents,
                    show_progress=False  # Disable internal progress to avoid nested bars
                )

                # Store the result with the same key
                chunked_texts_dict[key] = chunked_texts
                time.sleep(0.15)

                # Update progress bar with the number of documents in this batch
                bar.update(increment=doc_count)
                processed_docs += doc_count
    else:
        # Process without progress bar
        for key, documents in loaded_texts.items():
            chunked_texts = sentence_splitter.get_nodes_from_documents(
                documents,
                show_progress=True  # Use the internal progress bar if no marimo bar
            )
            chunked_texts_dict[key] = chunked_texts

    return chunked_texts_dict

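# Illustrative note (not in the original commit): chunk_documents preserves the
# filename -> documents mapping produced by load_pdf_data_with_progress, e.g.
# {"report.pdf": [Document, ...]} becomes {"report.pdf": [chunk nodes, ...]},
# so downstream cells can keep attributing chunks to their source file.
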
@app.cell
def chunked_nodes(chunked_texts, get_chunk_state, sentence_splitter):
    if chunked_texts is not None and sentence_splitter:
        chunked_documents = get_chunk_state()
    else:
        chunked_documents = None
    return (chunked_documents,)


@app.cell
def prep_cumulative_df(chunked_documents, llamaindex_convert_docs_multi):
    if chunked_documents is not None:
        dict_from_nodes = llamaindex_convert_docs_multi(chunked_documents)
        nodes_from_dict = llamaindex_convert_docs_multi(dict_from_nodes)
    else:
        dict_from_nodes = None
        nodes_from_dict = None
    return (dict_from_nodes,)


@app.cell
def chunks_to_process(
    dict_from_nodes,
    document_range_stack,
    get_data_in_range_triplequote,
):
    if dict_from_nodes is not None and document_range_stack is not None:
        chunk_dict_df = create_cumulative_dataframe(dict_from_nodes)

        if document_range_stack.value is not None:
            chunk_start_idx = document_range_stack.value[0]
            chunk_end_idx = document_range_stack.value[1]
        else:
            chunk_start_idx = 0
            chunk_end_idx = len(chunk_dict_df)

        chunk_range_index = [chunk_start_idx, chunk_end_idx]
        chunks_dict = get_data_in_range_triplequote(chunk_dict_df,
                                                    index_range=chunk_range_index,
                                                    columns_to_include=["text"])

        chunks_to_process = chunks_dict['text'] if 'text' in chunks_dict else []
    else:
        chunk_objects = None
        chunks_dict = None
        chunks_to_process = None
    return chunks_dict, chunks_to_process

@app.cell
def helper_function_doc_formatting():
    def llamaindex_convert_docs_multi(items):
        """
        Automatically convert between document objects and dictionaries.

        This function handles:
        - Converting dictionaries to document objects
        - Converting document objects to dictionaries
        - Processing lists or individual items
        - Supporting dictionary structures where values are lists of documents

        Args:
            items: A document object, dictionary, or list of either.
                Can also be a dictionary mapping filenames to lists of documents.

        Returns:
            Converted item(s) maintaining the original structure
        """
        # Handle empty or None input
        if not items:
            return []

        # Handle dictionary mapping filenames to document lists (from load_pdf_data)
        if isinstance(items, dict) and all(isinstance(v, list) for v in items.values()):
            result = {}
            for filename, doc_list in items.items():
                result[filename] = llamaindex_convert_docs(doc_list)
            return result

        # Handle single items (not in a list)
        if not isinstance(items, list):
            # Single dictionary to document
            if isinstance(items, dict):
                # Determine document class
                doc_class = None
                if 'doc_type' in items:
                    import importlib
                    module_path, class_name = items['doc_type'].rsplit('.', 1)
                    module = importlib.import_module(module_path)
                    doc_class = getattr(module, class_name)
                if not doc_class:
                    from llama_index.core.schema import Document
                    doc_class = Document
                return doc_class.from_dict(items)
            # Single document to dictionary
            elif hasattr(items, 'to_dict'):
                return items.to_dict()
            # Return as is if it can't be converted
            return items

        # Handle list input
        result = []

        # Handle empty list
        if len(items) == 0:
            return result

        # Determine the type of conversion based on the first non-None item
        first_item = next((item for item in items if item is not None), None)

        # If we found no non-None items, return an empty list
        if first_item is None:
            return result

        # Convert dictionaries to documents
        if isinstance(first_item, dict):
            # Get the right document class from the items themselves
            doc_class = None
            # Try to get the doc class from metadata if available
            if 'doc_type' in first_item:
                import importlib
                module_path, class_name = first_item['doc_type'].rsplit('.', 1)
                module = importlib.import_module(module_path)
                doc_class = getattr(module, class_name)
            if not doc_class:
                # Fall back to the default Document class from llama_index
                from llama_index.core.schema import Document
                doc_class = Document

            # Convert each dictionary to a document
            for item in items:
                if isinstance(item, dict):
                    result.append(doc_class.from_dict(item))
                elif item is None:
                    result.append(None)
                elif isinstance(item, list):
                    result.append(llamaindex_convert_docs(item))
                else:
                    result.append(item)

        # Convert documents to dictionaries
        else:
            for item in items:
                if hasattr(item, 'to_dict'):
                    result.append(item.to_dict())
                elif item is None:
                    result.append(None)
                elif isinstance(item, list):
                    result.append(llamaindex_convert_docs(item))
                else:
                    result.append(item)

        return result

    def llamaindex_convert_docs(items):
        """
        Automatically convert between document objects and dictionaries.

        Args:
            items: A list of document objects or dictionaries

        Returns:
            List of converted items (dictionaries or document objects)
        """
        result = []

        # Handle empty or None input
        if not items:
            return result

        # Determine the type of conversion based on the first item
        if isinstance(items[0], dict):
            # Get the right document class from the items themselves
            # Look for a 'doc_type' or '__class__' field in the dictionary
            doc_class = None

            # Try to get the doc class from metadata if available
            if 'doc_type' in items[0]:
                import importlib
                module_path, class_name = items[0]['doc_type'].rsplit('.', 1)
                module = importlib.import_module(module_path)
                doc_class = getattr(module, class_name)

            if not doc_class:
                # Fall back to the default Document class from llama_index
                from llama_index.core.schema import Document
                doc_class = Document

            # Convert dictionaries to documents
            for item in items:
                if isinstance(item, dict):
                    result.append(doc_class.from_dict(item))
        else:
            # Convert documents to dictionaries
            for item in items:
                if hasattr(item, 'to_dict'):
                    result.append(item.to_dict())

        return result
    return (llamaindex_convert_docs_multi,)

@app.cell
def helper_function_create_df():
    def create_document_dataframes(dict_from_docs):
        """
        Creates a pandas DataFrame for each file in the dictionary.

        Args:
            dict_from_docs: Dictionary mapping filenames to lists of documents

        Returns:
            List of pandas DataFrames, each representing all documents from a single file
        """
        dataframes = []

        for filename, docs in dict_from_docs.items():
            # Create a list to hold all document records for this file
            file_records = []

            for i, doc in enumerate(docs):
                # Convert the document to a format compatible with DataFrame
                if hasattr(doc, 'to_dict'):
                    doc_data = doc.to_dict()
                elif isinstance(doc, dict):
                    doc_data = doc
                else:
                    doc_data = {'content': str(doc)}

                # Add document index information
                doc_data['doc_index'] = i

                # Add to the list of records for this file
                file_records.append(doc_data)

            # Create a single DataFrame for all documents in this file
            if file_records:
                df = pd.DataFrame(file_records)
                df['filename'] = filename  # Add filename as a column
                dataframes.append(df)

        return dataframes

    def create_dataframe_previews(dataframe_list, page_size=5):
        """
        Creates a list of mo.ui.dataframe components, one for each DataFrame in the input list.

        Args:
            dataframe_list: List of pandas DataFrames (output from create_document_dataframes)
            page_size: Number of rows to show per page for each component

        Returns:
            List of mo.ui.dataframe components
        """
        # Create a list of mo.ui.dataframe components
        preview_components = []

        for df in dataframe_list:
            # Create a mo.ui.dataframe component for this DataFrame
            preview = mo.ui.dataframe(df, page_size=page_size)
            preview_components.append(preview)

        return preview_components
    return


@app.cell
def _():
    switch_file_loader_type = mo.ui.switch(label="**Switch** to pre-made Embedding/Text pairs")
    return (switch_file_loader_type,)

@app.cell
def _():
    import csv

    def csv_to_json(csv_file_path, json_file_path):
        """
        Convert CSV file to JSON format.

        Args:
            csv_file_path (str): Path to input CSV file
            json_file_path (str): Path to output JSON file
        """
        with open(csv_file_path, 'r', encoding='utf-8') as csv_file:
            csv_reader = csv.DictReader(csv_file)
            data = list(csv_reader)

        with open(json_file_path, 'w', encoding='utf-8') as json_file:
            json.dump(data, json_file, indent=2)

    return

@app.function
def load_json_csv_data_with_progress(filepaths, file_loader_value, show_progress=True):
    """
    Loads CSV (converted to JSON) or JSON data for a single file.
    Returns the raw JSON content without filename mapping.
    """
    import csv
    import json

    filepath = filepaths[0]  # Only process the first file
    original_file_name = file_loader_value[0].name
    result = None  # Stays None if the extension is neither .csv nor .json

    if show_progress:
        import marimo as mo
        with mo.status.progress_bar(
            total=1,
            title="Loading File",
            subtitle=f"Processing {original_file_name}...",
            completion_title="File Loading Complete",
            completion_subtitle="1 file processed",
            remove_on_exit=True
        ) as bar:
            if filepath.lower().endswith('.csv'):
                with open(filepath, 'r', encoding='utf-8') as csv_file:
                    csv_reader = csv.DictReader(csv_file)
                    result = list(csv_reader)
            elif filepath.lower().endswith('.json'):
                with open(filepath, 'r', encoding='utf-8') as json_file:
                    result = json.load(json_file)
            bar.update(increment=1)
    else:
        if filepath.lower().endswith('.csv'):
            with open(filepath, 'r', encoding='utf-8') as csv_file:
                csv_reader = csv.DictReader(csv_file)
                result = list(csv_reader)
        elif filepath.lower().endswith('.json'):
            with open(filepath, 'r', encoding='utf-8') as json_file:
                result = json.load(json_file)

    return result

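# Illustrative note (not in the original commit): for the pre-made pair flow,
# the loaded file is expected to parse into a list of records, e.g.
#
#     [{"text": "some chunk", "embedding": "[0.12, -0.03, ...]"}, ...]
#
# The column-selection cells below let the user point at the text and
# embedding fields by name.
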
@app.function
def organize_data_by_columns(loaded_texts, columns_to_use):
    """
    Organizes loaded text data into specified column groups based on configuration.

    Args:
        loaded_texts: List of dictionaries containing the data
        columns_to_use: Dictionary mapping column group names to their field configurations

    Returns:
        Dictionary with column group names as keys and lists of field values as values
    """
    result = {}

    for group_name, field_config in columns_to_use.items():
        result[group_name] = []

        # Get fields that are marked as True
        selected_fields = [field for field, include in field_config.items() if include]

        for record in loaded_texts:
            # Add each selected field's value directly to the list
            for field in selected_fields:
                if field in record:
                    result[group_name].append(record[field])

    return result

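# Illustrative worked example (hypothetical data, not in the original commit):
#
#     records = [{"text": "a", "emb": "[0.1]"}, {"text": "b", "emb": "[0.2]"}]
#     groups = {"texts": {"text": True, "emb": False},
#               "embeddings": {"emb": True}}
#     organize_data_by_columns(records, groups)
#     # -> {"texts": ["a", "b"], "embeddings": ["[0.1]", "[0.2]"]}
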
@app.cell
def _(
    create_parameter_table,
    get_text_state,
    loaded_texts,
    switch_file_loader_type,
):
    if switch_file_loader_type.value and loaded_texts:
        column_list = list(get_text_state()[0])
        text_column = create_parameter_table(
            label="Select the Embedded Text Column",
            input_list=column_list,
            column_name="Text Column",
            selection_type="single-cell",
            text_justify="center"
        )
        embedding_column = create_parameter_table(
            label="Select the Corresponding Embeddings Column",
            input_list=column_list,
            column_name="Embedding Column",
            selection_type="single-cell",
            text_justify="center"
        )
        column_selection_stack = mo.hstack([text_column, embedding_column], justify="space-around", widths=[0.4, 0.4])
        file_column_setup = mo.hstack([column_selection_stack], justify="space-around", align="center", widths=[0.75])
    else:
        text_column = embedding_column = column_selection_stack = file_column_setup = None
    return embedding_column, file_column_setup, text_column


@app.cell
def _(embedding_column, get_cell_values, switch_file_loader_type, text_column):
    if switch_file_loader_type.value:
        text_col_value = get_cell_values(text_column)
        emb_col_value = get_cell_values(embedding_column)
        columns_to_use = {
            "texts": text_col_value,
            "embeddings": emb_col_value
        }
    else:
        text_col_value = emb_col_value = columns_to_use = None
    return columns_to_use, text_col_value


@app.cell
def _(columns_to_use, loaded_texts, text_col_value):
    if text_col_value and columns_to_use and loaded_texts:
        premade_documents = organize_data_by_columns(columns_to_use=columns_to_use, loaded_texts=loaded_texts)
        text_col_state = validate_value(premade_documents["texts"])
        emb_col_state = validate_value(premade_documents["embeddings"])
        columns_selected = all_true(text_col_state, emb_col_state)
    else:
        premade_documents = text_col_state = emb_col_state = columns_selected = None
    return columns_selected, premade_documents


@app.function
def validate_value(value):
    """
    Check that a value is neither None nor an empty data object.
    """
    return value is not None and bool(value)

+ @app.cell
1348
+ def helper_function_chart_preparation():
1349
+ import altair as alt
1350
+ import numpy as np
1351
+ import plotly.express as px
1352
+ from sklearn.manifold import TSNE
1353
+
1354
+ def prepare_embedding_data(embeddings, texts, model_id=None, embedding_dimensions=None):
1355
+ """
1356
+ Prepare embedding data for visualization
1357
+
1358
+ Args:
1359
+ embeddings: List of embeddings arrays
1360
+ texts: List of text strings
1361
+ model_id: Embedding model ID (optional)
1362
+ embedding_dimensions: Embedding dimensions (optional)
1363
+
1364
+ Returns:
1365
+ DataFrame with processed data and metadata
1366
+ """
1367
+ # Flatten embeddings (in case they're nested)
1368
+ flattened_embeddings = []
1369
+ for emb in embeddings:
1370
+ if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
1371
+ flattened_embeddings.append(emb[0]) # Take first element if nested
1372
+ else:
1373
+ flattened_embeddings.append(emb)
1374
+
1375
+ # Convert to numpy array
1376
+ embedding_array = np.array(flattened_embeddings)
1377
+
1378
+ # Apply dimensionality reduction (t-SNE)
1379
+ tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embedding_array)-1))
1380
+ reduced_embeddings = tsne.fit_transform(embedding_array)
1381
+
1382
+ # Create truncated texts for display
1383
+ truncated_texts = [text[:50] + "..." if len(text) > 50 else text for text in texts]
1384
+
1385
+ # Create dataframe for visualization
1386
+ df = pd.DataFrame({
1387
+ "x": reduced_embeddings[:, 0],
1388
+ "y": reduced_embeddings[:, 1],
1389
+ "text": truncated_texts,
1390
+ "full_text": texts,
1391
+ "index": range(len(texts))
1392
+ })
1393
+
1394
+ # Add metadata
1395
+ metadata = {
1396
+ "model_id": model_id,
1397
+ "embedding_dimensions": embedding_dimensions
1398
+ }
1399
+
1400
+ return df, metadata
1401
+
1402
+ def create_embedding_chart(df, metadata=None):
1403
+ """
1404
+ Create an Altair chart for embedding visualization
1405
+
1406
+ Args:
1407
+ df: DataFrame with x, y coordinates and text
1408
+ metadata: Dictionary with model_id and embedding_dimensions
1409
+
1410
+ Returns:
1411
+ Altair chart
1412
+ """
1413
+ model_id = metadata.get("model_id") if metadata else None
1414
+ embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None
1415
+
1416
+ selection = alt.selection_multi(fields=['index'])
1417
+
1418
+ base = alt.Chart(df).encode(
1419
+ x=alt.X("x:Q", title="Dimension 1"),
1420
+ y=alt.Y("y:Q", title="Dimension 2"),
1421
+ tooltip=["text", "index"]
1422
+ )
1423
+
1424
+ points = base.mark_circle(size=100).encode(
1425
+ color=alt.Color("index:N", legend=None),
1426
+ opacity=alt.condition(selection, alt.value(1), alt.value(0.2))
1427
+ ).add_selection(selection) # Add this line to apply the selection
1428
+
1429
+ text = base.mark_text(align="left", dx=7).encode(
1430
+ text="index:N"
1431
+ )
1432
+
1433
+ return (points + text).properties(
1434
+ width=700,
1435
+ height=500,
1436
+ title=f"Embedding Visualization{f' - Model: {model_id}' if model_id else ''}{f' ({embedding_dimensions} dimensions)' if embedding_dimensions else ''}"
1437
+ ).interactive()
1438
+
1439
+ def show_selected_text(indices, texts):
1440
+ """
1441
+ Create markdown display for selected texts
1442
+
1443
+ Args:
1444
+ indices: List of selected indices
1445
+ texts: List of all texts
1446
+
1447
+ Returns:
1448
+ Markdown string
1449
+ """
1450
+ if not indices:
1451
+ return "No text selected"
1452
+
1453
+ selected_texts = [texts[i] for i in indices if i < len(texts)]
1454
+ return "\n\n".join([f"**Document {i}**:\n{text}" for i, text in zip(indices, selected_texts)])
1455
+
1456
+ def prepare_embedding_data_3d(embeddings, texts, model_id=None, embedding_dimensions=None):
1457
+ """
1458
+ Prepare embedding data for 3D visualization
1459
+
1460
+ Args:
1461
+ embeddings: List of embeddings arrays
1462
+ texts: List of text strings
1463
+ model_id: Embedding model ID (optional)
1464
+ embedding_dimensions: Embedding dimensions (optional)
1465
+
1466
+ Returns:
1467
+ DataFrame with processed data and metadata
1468
+ """
1469
+ # Flatten embeddings (in case they're nested)
1470
+ flattened_embeddings = []
1471
+ for emb in embeddings:
1472
+ if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
1473
+ flattened_embeddings.append(emb[0])
1474
+ else:
1475
+ flattened_embeddings.append(emb)
1476
+
1477
+ # Convert to numpy array
1478
+ embedding_array = np.array(flattened_embeddings)
1479
+
1480
+ # Handle the case of a single embedding differently
1481
+ if len(embedding_array) == 1:
1482
+ # For a single point, we don't need t-SNE, just use a fixed position
1483
+ reduced_embeddings = np.array([[0.0, 0.0, 0.0]])
1484
+ else:
1485
+ # Apply dimensionality reduction to 3D
1486
+ # Fix: Ensure perplexity is at least 1.0
1487
+ perplexity_value = max(1.0, min(30, len(embedding_array)-1))
1488
+ tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity_value)
1489
+ reduced_embeddings = tsne.fit_transform(embedding_array)
1490
+
1491
+ # Format texts for display
1492
+ formatted_texts = []
1493
+ for text in texts:
1494
+ # Truncate if needed
1495
+ if len(text) > 500:
1496
+ text = text[:500] + "..."
1497
+
1498
+ # Insert line breaks for wrapping
1499
+ wrapped_text = ""
1500
+ for i in range(0, len(text), 50):
1501
+ wrapped_text += text[i:i+50] + "<br>"
1502
+
1503
+ formatted_texts.append("<b>"+wrapped_text+"</b>")
1504
+
1505
+ # Create dataframe for visualization
1506
+ df = pd.DataFrame({
1507
+ "x": reduced_embeddings[:, 0],
1508
+ "y": reduced_embeddings[:, 1],
1509
+ "z": reduced_embeddings[:, 2],
1510
+ "text": formatted_texts,
1511
+ "full_text": texts,
1512
+ "index": range(len(texts)),
1513
+ "embedding": flattened_embeddings # Store the original embeddings for later use
1514
+ })
1515
+
1516
+ # Add metadata
1517
+ metadata = {
1518
+ "model_id": model_id,
1519
+ "embedding_dimensions": embedding_dimensions
1520
+ }
1521
+
1522
+ return df, metadata
1523
+
1524
+    def create_3d_embedding_chart(df, metadata=None, chart_width=1200, chart_height=800, marker_size_var: int = 3):
+        """
+        Create a 3D Plotly chart for embedding visualization with proximity-based coloring.
+        """
+        model_id = metadata.get("model_id") if metadata else None
+        embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None
+
+        # Calculate the proximity between points
+        from scipy.spatial.distance import pdist, squareform
+
+        # Get the coordinates as a numpy array
+        coords = df[['x', 'y', 'z']].values
+
+        # Calculate pairwise distances
+        dist_matrix = squareform(pdist(coords))
+
+        # For each point, find its average distance to all other points
+        avg_distances = np.mean(dist_matrix, axis=1)
+
+        # Add this to the dataframe - smaller values = closer to other points
+        df['proximity'] = avg_distances
+
+        # Create 3D scatter plot with proximity-based coloring
+        fig = px.scatter_3d(
+            df,
+            x='x',
+            y='y',
+            z='z',
+            color='proximity',  # Color based on proximity
+            color_continuous_scale='Viridis_r',  # Reversed so closer points get warmer colors
+            hover_data=['text', 'index', 'proximity'],
+            labels={'x': 'Dimension 1', 'y': 'Dimension 2', 'z': 'Dimension 3', 'proximity': 'Avg Distance'},
+            title=f"<b>3D Embedding Visualization</b>{f' - Model: <b>{model_id}</b>' if model_id else ''}{f' <i>({embedding_dimensions} dimensions)</i>' if embedding_dimensions else ''}",
+            text='index',
+        )
+
+        # Update marker size and styling
+        fig.update_traces(
+            marker=dict(
+                size=marker_size_var,  # Marker size (parameterized)
+                opacity=0.7,  # Slightly transparent
+                symbol="diamond",  # Diamond markers (other options: "circle", "square", "cross", "x")
+                line=dict(
+                    width=0.5,  # Very thin border
+                    color="white"  # White outline makes small dots more visible
+                )
+            ),
+            textfont=dict(
+                color="rgba(255, 255, 255, 0.3)",
+                size=8
+            ),
+            hovertemplate="text:<br><b>%{customdata[0]}</b><br>index: <b>%{text}</b><br><br>Avg Distance: <b>%{customdata[2]:.4f}</b><extra></extra>",
+            hoverinfo="text+name",
+            hoverlabel=dict(
+                bgcolor="white",  # White background for hover labels
+                font_size=12  # Font size for hover text
+            ),
+            selector=dict(type='scatter3d')
+        )
+
+        # Scene, camera, and color bar layout settings
+        fig.update_layout(
+            scene=dict(
+                xaxis=dict(
+                    title='Dimension 1',
+                    nticks=40,
+                    backgroundcolor="rgba(10, 10, 20, 0.1)",
+                    gridcolor="white",
+                    showbackground=True,
+                    gridwidth=0.35,
+                    zerolinecolor="white",
+                ),
+                yaxis=dict(
+                    title='Dimension 2',
+                    nticks=40,
+                    backgroundcolor="rgba(10, 10, 20, 0.1)",
+                    gridcolor="white",
+                    showbackground=True,
+                    gridwidth=0.35,
+                    zerolinecolor="white",
+                ),
+                zaxis=dict(
+                    title='Dimension 3',
+                    nticks=40,
+                    backgroundcolor="rgba(10, 10, 20, 0.1)",
+                    gridcolor="white",
+                    showbackground=True,
+                    gridwidth=0.35,
+                    zerolinecolor="white",
+                ),
+                # Control camera view angle
+                camera=dict(
+                    up=dict(x=0, y=0, z=1),
+                    center=dict(x=0, y=0, z=0),
+                    eye=dict(x=1.25, y=1.25, z=1.25),
+                ),
+                aspectratio=dict(x=1, y=1, z=1),
+                aspectmode='data'
+            ),
+            width=int(chart_width),
+            height=int(chart_height),
+            margin=dict(r=20, l=10, b=10, t=50),
+            paper_bgcolor="rgb(0, 0, 0)",
+            plot_bgcolor="rgb(0, 0, 0)",
+            coloraxis_colorbar=dict(
+                title="Average Distance",
+                thicknessmode="pixels", thickness=20,
+                lenmode="pixels", len=400,
+                yanchor="top", y=1,
+                ticks="outside",
+                dtick=0.1
+            )
+        )
+
+        return fig
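+
+    # Reading the chart (sketch): each point's color encodes its mean distance to
+    # every other point, so tight clusters read as "close" on the reversed Viridis
+    # scale; e.g. create_3d_embedding_chart(df, meta, 1200, 800, marker_size_var=9).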
+    return create_3d_embedding_chart, prepare_embedding_data_3d
+
+
+@app.cell
+def helper_function_text_preparation():
+    def convert_table_to_json_docs(df, selected_columns=None):
+        """
+        Convert a pandas DataFrame or dictionary to a list of JSON documents.
+        Dynamically includes columns based on user selection.
+        Column names are standardized to lowercase with underscores instead of spaces
+        and special characters removed.
+
+        Args:
+            df: The DataFrame or dictionary to process
+            selected_columns: List of column names to include in the output documents
+
+        Returns:
+            list: A list of dictionaries, each representing a row as a JSON document
+        """
+        import pandas as pd
+        import re
+
+        def standardize_key(key):
+            """Convert a column name to lowercase with underscores instead of spaces and no special characters"""
+            if not isinstance(key, str):
+                return str(key).lower()
+            # Replace spaces with underscores and convert to lowercase
+            key = key.lower().replace(' ', '_')
+            # Remove special characters (keeping alphanumeric and underscores)
+            return re.sub(r'[^\w]', '', key)
+
+        # Handle the case when the input is a dictionary
+        if isinstance(df, dict):
+            # Filter the dictionary to include only selected columns
+            if selected_columns:
+                return [{standardize_key(k): df.get(k, None) for k in selected_columns}]
+            else:
+                # If no columns are selected, return all key-value pairs with standardized keys
+                return [{standardize_key(k): v for k, v in df.items()}]
+
+        # Handle the case when df is None
+        if df is None:
+            return []
+
+        # Ensure df is a DataFrame
+        if not isinstance(df, pd.DataFrame):
+            try:
+                df = pd.DataFrame(df)
+            except Exception:
+                return []  # Return an empty list if conversion fails
+
+        # Now check whether the DataFrame is empty
+        if df.empty:
+            return []
+
+        # If no columns are specifically selected, use all available columns
+        if not selected_columns or not isinstance(selected_columns, list) or len(selected_columns) == 0:
+            selected_columns = list(df.columns)
+
+        # Determine which columns exist in the DataFrame (with a case-insensitive fallback)
+        available_columns = []
+        columns_lower = {col.lower(): col for col in df.columns if isinstance(col, str)}
+
+        for col in selected_columns:
+            if col in df.columns:
+                available_columns.append(col)
+            elif isinstance(col, str) and col.lower() in columns_lower:
+                available_columns.append(columns_lower[col.lower()])
+
+        # If no valid columns are found, return an empty list
+        if not available_columns:
+            return []
+
+        # Process rows
+        json_docs = []
+        for _, row in df.iterrows():
+            doc = {}
+            for col in available_columns:
+                value = row[col]
+                # Standardize the column name when adding to the document
+                std_col = standardize_key(col)
+                doc[std_col] = None if pd.isna(value) else value
+            json_docs.append(doc)
+
+        return json_docs
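+
+    # Example (hedged sketch; the column names are hypothetical):
+    #     convert_table_to_json_docs(pd.DataFrame({"Doc Title": ["a"], "Body Text!": ["b"]}))
+    #     -> [{"doc_title": "a", "body_text": "b"}]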
+
+    def get_column_values(df, columns_to_include):
+        """
+        Extract values from specified columns of a dataframe as lists.
+
+        Args:
+            df: A pandas DataFrame
+            columns_to_include: A list of column names to extract
+
+        Returns:
+            Dictionary with column names as keys and their values as lists
+        """
+        result = {}
+
+        # Validate that the columns exist in the dataframe
+        valid_columns = [col for col in columns_to_include if col in df.columns]
+        invalid_columns = set(columns_to_include) - set(valid_columns)
+
+        if invalid_columns:
+            print(f"Warning: These columns don't exist in the dataframe: {list(invalid_columns)}")
+
+        # Extract values for each valid column
+        for col in valid_columns:
+            result[col] = df[col].tolist()
+
+        return result
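+
+    # Example (illustrative): get_column_values(df, ["a", "missing"]) returns
+    #     {"a": [...]} and prints a warning about "missing".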
+
+    def get_data_in_range(doc_dict_df, index_range, columns_to_include):
+        """
+        Extract values from specified columns of a dataframe within a given index range.
+
+        Args:
+            doc_dict_df: The pandas DataFrame to extract data from
+            index_range: An integer specifying the number of rows to include (from 0 to index_range-1)
+            columns_to_include: A list of column names to extract
+
+        Returns:
+            Dictionary with column names as keys and their values (within the index range) as lists
+        """
+        # Validate the index range
+        max_index = len(doc_dict_df)
+        if index_range <= 0:
+            print(f"Warning: Invalid index range {index_range}. Must be positive.")
+            return {}
+
+        # Adjust index_range if it exceeds the dataframe length
+        if index_range > max_index:
+            print(f"Warning: Index range {index_range} exceeds dataframe length {max_index}. Using maximum length.")
+            index_range = max_index
+
+        # Slice the dataframe to get rows from 0 to index_range-1
+        df_subset = doc_dict_df.iloc[:index_range]
+
+        # Use get_column_values to extract the column data
+        return get_column_values(df_subset, columns_to_include)
+
+    def get_data_in_range_triplequote(doc_dict_df, index_range, columns_to_include):
+        """
+        Extract values from specified columns of a dataframe within a given index range.
+        Escapes URLs in string values; the original triple-quote wrapping is kept in the
+        code below but disabled.
+
+        Args:
+            doc_dict_df: The pandas DataFrame to extract data from
+            index_range: A list of two integers specifying the start and end indices of rows to include
+                         (e.g., [0, 10] includes rows from index 0 to 9 inclusive)
+            columns_to_include: A list of column names to extract
+        """
+        # Validate the index range
+        start_idx, end_idx = index_range
+        max_index = len(doc_dict_df)
+
+        # Validate the start index
+        if start_idx < 0:
+            print(f"Warning: Invalid start index {start_idx}. Using 0 instead.")
+            start_idx = 0
+
+        # Validate the end index
+        if end_idx <= start_idx:
+            print(f"Warning: End index {end_idx} must be greater than start index {start_idx}. Using {start_idx + 1} instead.")
+            end_idx = start_idx + 1
+
+        # Adjust the end index if it exceeds the dataframe length
+        if end_idx > max_index:
+            print(f"Warning: End index {end_idx} exceeds dataframe length {max_index}. Using maximum length.")
+            end_idx = max_index
+
+        # Slice the dataframe to get rows from start_idx to end_idx-1
+        # (positional slice; the subset keeps its original index labels)
+        df_subset = doc_dict_df.iloc[start_idx:end_idx]
+
+        # Use get_column_values to extract the column data
+        result = get_column_values(df_subset, columns_to_include)
+
+        # Escape URLs in each string value
+        for col in result:
+            if isinstance(result[col], list):
+                processed_items = []
+                for item in result[col]:
+                    if isinstance(item, str):
+                        # Escape http:// and https:// so they render as plain text
+                        item = item.replace("http://", "http\\://").replace("https://", "https\\://")
+                        # processed_items.append('"""' + item + '"""')  # triple-quote wrapping disabled
+                        processed_items.append(item)
+                    else:
+                        processed_items.append(item)
+                result[col] = processed_items
+        return result
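+
+    # Example (hedged; the dataframe and column are hypothetical):
+    #     get_data_in_range_triplequote(df, [0, 2], ["text"])
+    #     -> {"text": [...]} with "https://" escaped as "https\://"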
+    return (get_data_in_range_triplequote,)
+
+
+@app.cell
+def prepare_doc_select(sentence_splitter_config):
+    def prepare_document_selection(node_dict):
+        """
+        Creates the document selection UI component.
+
+        Args:
+            node_dict: Dictionary mapping filenames to lists of documents
+
+        Returns:
+            mo.ui component for document selection
+        """
+        # Calculate the total number of documents across all files
+        total_docs = sum(len(docs) for docs in node_dict.values())
+
+        # Build a combined record list of all documents (kept for table-style selection)
+        all_docs_records = []
+        doc_index_global = 0
+        for filename, docs in node_dict.items():
+            for i, doc in enumerate(docs):
+                # Convert the document to a format compatible with a DataFrame
+                if hasattr(doc, 'to_dict'):
+                    doc_data = doc.to_dict()
+                elif isinstance(doc, dict):
+                    doc_data = doc
+                else:
+                    doc_data = {'content': str(doc)}
+
+                # Add metadata
+                doc_data['filename'] = filename
+                doc_data['doc_index'] = i
+                doc_data['global_index'] = doc_index_global
+                all_docs_records.append(doc_data)
+                doc_index_global += 1
+
+        # Create the UI component
+        stop_value = max(total_docs, 1)
+        llama_docs = mo.ui.range_slider(
+            start=1,
+            stop=stop_value,
+            step=1,
+            full_width=True,
+            show_value=True,
+            label="**Select a Range of Chunks to Visualize:**"
+        ).form(submit_button_disabled=check_state(sentence_splitter_config.value), submit_button_label="Change Document View Range")
+
+        return llama_docs
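+
+    # Note (sketch): the slider is 1-based and spans every chunk across all files,
+    # e.g. {"a.pdf": [d1, d2], "b.pdf": [d3]} yields a range slider from 1 to 3.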
+    return (prepare_document_selection,)
+
+
+@app.cell
+def document_range_selection(
+    dict_from_nodes,
+    prepare_document_selection,
+    set_range_slider_state,
+):
+    if dict_from_nodes is not None:
+        llama_docs = prepare_document_selection(dict_from_nodes)
+        set_range_slider_state(llama_docs)
+    else:
+        bare_dict = {}
+        llama_docs = prepare_document_selection(bare_dict)
+    return
+
+
+@app.function
+def create_cumulative_dataframe(dict_from_docs):
+    """
+    Creates a cumulative DataFrame from a nested dictionary of documents.
+
+    Args:
+        dict_from_docs: Dictionary mapping filenames to lists of documents
+
+    Returns:
+        DataFrame with all documents flattened with global indices
+    """
+    # Collect all document records
+    all_records = []
+    global_idx = 1  # Start from 1 to match range slider expectations
+
+    for filename, docs in dict_from_docs.items():
+        for i, doc in enumerate(docs):
+            # Convert the document to a dict format
+            if hasattr(doc, 'to_dict'):
+                doc_data = doc.to_dict()
+            elif isinstance(doc, dict):
+                doc_data = doc.copy()
+            else:
+                doc_data = {'content': str(doc)}
+
+            # Add additional metadata
+            doc_data['filename'] = filename
+            doc_data['doc_index'] = i
+            doc_data['global_index'] = global_idx
+
+            # If there's 'content' but no 'text', create a 'text' field
+            if 'content' in doc_data and 'text' not in doc_data:
+                doc_data['text'] = doc_data['content']
+
+            all_records.append(doc_data)
+            global_idx += 1
+
+    # Create a DataFrame from all records
+    return pd.DataFrame(all_records)
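+
+# Shape sketch (illustrative): {"a.pdf": [d1], "b.pdf": [d2]} flattens to rows with
+# filename/doc_index/global_index columns and global_index running 1, 2, ...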
+
+
+@app.function
+def create_stats(texts_dict, bordered=False, object_names=None, group_by_row=False, items_per_row=6, gap=2, label="Chunk"):
+    """
+    Create a list of stat objects for each item in the specified dictionary.
+
+    Parameters:
+    - texts_dict (dict): Dictionary containing the text data
+    - bordered (bool): Whether the stats should be bordered
+    - object_names (list or tuple): Two object names to use for label and value
+                                    [label_object, value_object]
+    - group_by_row (bool): Whether to group stats in rows (horizontal stacks)
+    - items_per_row (int): Number of stat objects per row when group_by_row is True
+    - gap (int): Gap between stats within a row
+    - label (str): Label prefix used in each stat's value text
+
+    Returns:
+    - object: A vertical stack of stat objects or rows of stat objects
+    """
+    if not object_names or len(object_names) < 2:
+        raise ValueError("You must provide two object names as a list or tuple")
+
+    label_object = object_names[0]
+    value_object = object_names[1]
+
+    # Validate that both objects exist in the dictionary
+    if label_object not in texts_dict:
+        raise ValueError(f"Label object '{label_object}' not found in texts_dict")
+    if value_object not in texts_dict:
+        raise ValueError(f"Value object '{value_object}' not found in texts_dict")
+
+    # Determine how many items to process (based on the label object length)
+    num_items = len(texts_dict[label_object])
+
+    # Create individual stat objects
+    individual_stats = []
+    for i in range(num_items):
+        stat = mo.stat(
+            label=texts_dict[label_object][i],
+            value=f"{label} Number: {len(texts_dict[value_object][i])}",
+            bordered=bordered
+        )
+        individual_stats.append(stat)
+
+    # If grouping is not enabled, just return a vertical stack of all stats
+    if not group_by_row:
+        return mo.vstack(individual_stats, wrap=False)
+
+    # Group stats into rows based on items_per_row
+    rows = []
+    for i in range(0, num_items, items_per_row):
+        # Get a slice of stats for this row (up to items_per_row items)
+        row_stats = individual_stats[i:i + items_per_row]
+        # Create a horizontal stack for this row
+        widths = [0.35] * len(row_stats)
+        row = mo.hstack(row_stats, gap=gap, align="start", justify="center", widths=widths)
+        rows.append(row)
+
+    # Return a vertical stack of all rows
+    return mo.vstack(rows)
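+
+# Usage sketch (hedged; the dict keys are hypothetical):
+#     create_stats({"names": [...], "chunks": [...]},
+#                  object_names=["names", "chunks"], group_by_row=True)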
+
+
+@app.cell
+def prepare_chart_embeddings(
+    chunks_to_process,
+    emb_model,
+    emb_model_emb_dim,
+    get_embedding_state,
+    prepare_embedding_data_3d,
+):
+    if chunks_to_process is not None and get_embedding_state() is not None:
+        chart_dataframe, chart_metadata = prepare_embedding_data_3d(
+            get_embedding_state(),
+            chunks_to_process,
+            model_id=emb_model,
+            embedding_dimensions=emb_model_emb_dim
+        )
+    else:
+        chart_dataframe, chart_metadata = None, None
+    return chart_dataframe, chart_metadata
+
+
+@app.function
+def all_true(*args):
+    """
+    Check if all provided boolean arguments are True.
+    """
+    return all(args)
+
+
+@app.cell
+def _(chart_dataframe_prem, columns_selected):
+    # Debug output for the premade-document path
+    print(columns_selected, chart_dataframe_prem)
+    return
+
+
+@app.cell
+def _(
+    columns_selected,
+    emb_model,
+    emb_model_emb_dim,
+    premade_documents,
+    prepare_embedding_data_3d,
+):
+    if premade_documents and columns_selected:
+        chart_dataframe_prem, chart_metadata_prem = prepare_embedding_data_3d(
+            premade_documents["embeddings"],
+            premade_documents["texts"],
+            model_id=emb_model,
+            embedding_dimensions=emb_model_emb_dim
+        )
+    else:
+        chart_dataframe_prem = chart_metadata_prem = None
+    return chart_dataframe_prem, chart_metadata_prem
+
+
+@app.cell
+def chart_dims():
+    chart_dimensions = (
+        mo.md('''
+        > **Adjust Chart Window**
+
+        {chart_height}
+
+        {chart_width}
+
+        ''').batch(
+            chart_height=mo.ui.slider(start=500, step=30, stop=1000, label="**Height:**", value=800, show_value=True),
+            chart_width=mo.ui.slider(start=900, step=50, stop=1400, label="**Width:**", value=1200, show_value=True)
+        )
+    )
+    return (chart_dimensions,)
+
+
+@app.cell
+def chart_dim_values(chart_dimensions):
+    chart_height = chart_dimensions.value['chart_height']
+    chart_width = chart_dimensions.value['chart_width']
+    return chart_height, chart_width
+
+
+@app.cell
+def create_baseline_chart(
+    chart_dataframe,
+    chart_dataframe_prem,
+    chart_height,
+    chart_metadata,
+    chart_metadata_prem,
+    chart_width,
+    create_3d_embedding_chart,
+):
+    if chart_dataframe is not None and chart_metadata is not None:
+        emb_plot = create_3d_embedding_chart(chart_dataframe, chart_metadata, chart_width, chart_height, marker_size_var=9)
+        chart = mo.ui.plotly(emb_plot)
+        chart_reference = chart_dataframe
+
+    elif chart_dataframe_prem is not None and chart_metadata_prem is not None:
+        emb_plot = create_3d_embedding_chart(chart_dataframe_prem, chart_metadata_prem, chart_width, chart_height, marker_size_var=9)
+        chart = mo.ui.plotly(emb_plot)
+        chart_reference = chart_dataframe_prem
+
+    else:
+        emb_plot = chart = chart_reference = None
+    return chart, chart_reference, emb_plot
+
+
+@app.cell
+def test_query(get_chunk_state, premade_documents, switch_file_loader_type):
+    placeholder = """How can I use watsonx.data to perform vector search?"""
+    if switch_file_loader_type.value:
+        query = mo.ui.text_area(label="**Write text to check:**", full_width=True, rows=8, value=placeholder).form(
+            show_clear_button=True,
+            submit_button_disabled=check_state(premade_documents),
+            submit_button_label="Query and View Visualization")
+    else:
+        query = mo.ui.text_area(label="**Write text to check:**", full_width=True, rows=8, value=placeholder).form(
+            show_clear_button=True,
+            submit_button_disabled=check_state(get_chunk_state()),
+            submit_button_label="Query and View Visualization")
+    return (query,)
+
+
+@app.cell
+def query_stack(chart_dimensions, query):
+    query_stack = mo.hstack([query, chart_dimensions], justify="space-around", align="center", gap=15)
+    return (query_stack,)
+
+
+@app.function
+def check_state(variable):
+    """Return True when the given value is still unset (None)."""
+    return variable is None
+
+
+@app.cell
+def helper_function_add_query_to_chart():
+    def add_query_to_embedding_chart(existing_chart, query_coords, query_text, marker_size=12):
+        """
+        Add a query point to an existing 3D embedding chart as a large red dot.
+
+        Args:
+            existing_chart: The existing plotly figure or chart data
+            query_coords: Dictionary with 'x', 'y', 'z' coordinates for the query point
+            query_text: Text of the query to display on hover
+            marker_size: Size of the query marker (default: 12, larger than the document markers)
+
+        Returns:
+            A modified plotly figure with the query point added as a red dot
+        """
+        import copy
+
+        import plotly.graph_objects as go
+
+        # Create a deep copy of the existing chart to avoid modifying the original
+        chart_copy = copy.deepcopy(existing_chart)
+
+        # Handle the case where chart_copy is a dictionary or list (from mo.ui.plotly)
+        if isinstance(chart_copy, (dict, list)):
+            if isinstance(chart_copy, list):
+                # If it's a list, assume it's a list of traces
+                fig = go.Figure(data=chart_copy)
+            else:
+                # If it's a dict with 'data' and 'layout'
+                fig = go.Figure(data=chart_copy.get('data', []), layout=chart_copy.get('layout', {}))
+
+            chart_copy = fig
+
+        # Create the query trace
+        query_trace = go.Scatter3d(
+            x=[query_coords['x']],
+            y=[query_coords['y']],
+            z=[query_coords['z']],
+            mode='markers',
+            name='Query',
+            marker=dict(
+                size=marker_size,  # Larger than the document markers
+                color='red',  # Bright red color
+                symbol='circle',  # Circle shape
+                opacity=0.70,  # Slightly transparent
+                line=dict(
+                    width=1,  # Thin white border
+                    color='white'
+                )
+            ),
+            # Wrap the hover text every 50 characters
+            text=['<b>Query:</b><br>' + '<br>'.join([query_text[i:i + 50] for i in range(0, len(query_text), 50)])],
+            hoverinfo="text+name"
+        )
+
+        # Add the query trace to the chart copy
+        chart_copy.add_trace(query_trace)
+
+        return chart_copy
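+
+    # Usage sketch (illustrative; the coordinates are placeholders):
+    #     fig = add_query_to_embedding_chart(emb_plot, {'x': 0.0, 'y': 0.0, 'z': 0.0}, "my query")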
+
+
+    def get_query_coordinates(reference_embeddings=None, query_embedding=None):
+        """
+        Calculate appropriate coordinates for a query point based on reference embeddings.
+
+        This function handles several scenarios:
+        1. If both reference embeddings and a query embedding are provided, it places
+           the query near similar documents.
+        2. If only reference embeddings are provided, it places the query at a visible
+           location near the center of the chart.
+        3. If neither is provided, it returns default origin coordinates.
+
+        Args:
+            reference_embeddings: DataFrame with x, y, z coordinates from the main chart
+            query_embedding: The embedding vector of the query
+
+        Returns:
+            Dictionary with x, y, z coordinates for the query point
+        """
+        import numpy as np
+
+        # Default coordinates (the origin)
+        default_coords = {'x': 0.0, 'y': 0.0, 'z': 0.0}
+
+        # If we don't have reference embeddings, return the default
+        if reference_embeddings is None or len(reference_embeddings) == 0:
+            return default_coords
+
+        # If we have reference embeddings but no query embedding,
+        # position the query at a visible location near the center
+        if query_embedding is None:
+            center_coords = {
+                'x': reference_embeddings['x'].mean(),
+                'y': reference_embeddings['y'].mean(),
+                'z': reference_embeddings['z'].mean()
+            }
+            return center_coords
+
+        # If we have both reference embeddings and a query embedding,
+        # try to position the query near similar documents
+        try:
+            from sklearn.metrics.pairwise import cosine_similarity
+
+            # Check if the original embeddings are in the dataframe
+            if 'embedding' in reference_embeddings.columns:
+                # Get all document embeddings as a 2D array
+                if isinstance(reference_embeddings['embedding'].iloc[0], list):
+                    doc_embeddings = np.array(reference_embeddings['embedding'].tolist())
+                else:
+                    doc_embeddings = np.array([emb for emb in reference_embeddings['embedding'].values])
+
+                # Reshape the query embedding for comparison
+                query_emb_array = np.array(query_embedding)
+                if query_emb_array.ndim == 1:
+                    query_emb_array = query_emb_array.reshape(1, -1)
+
+                # Calculate cosine similarities
+                similarities = cosine_similarity(query_emb_array, doc_embeddings)[0]
+
+                # Find the closest document
+                closest_idx = np.argmax(similarities)
+
+                # Use the position of the closest document, with a slight offset for visibility
+                query_coords = {
+                    'x': reference_embeddings['x'].iloc[closest_idx] + 0.2,
+                    'y': reference_embeddings['y'].iloc[closest_idx] + 0.2,
+                    'z': reference_embeddings['z'].iloc[closest_idx] + 0.2
+                }
+                return query_coords
+        except Exception as e:
+            print(f"Error positioning query near similar documents: {e}")
+
+        # Fall back to the center position if the similarity calculation fails
+        center_coords = {
+            'x': reference_embeddings['x'].mean(),
+            'y': reference_embeddings['y'].mean(),
+            'z': reference_embeddings['z'].mean()
+        }
+        return center_coords
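+
+    # Placement note (sketch): cosine similarity compares the raw query vector
+    # against the stored 'embedding' column; the query is then drawn at the
+    # closest document's reduced (x, y, z) position plus a 0.2 offset.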
+    return add_query_to_embedding_chart, get_query_coordinates
+
+
+@app.cell
+def combined_chart_visualization(
+    add_query_to_embedding_chart,
+    chart,
+    chart_reference,
+    emb_plot,
+    embedding,
+    get_query_coordinates,
+    get_query_state,
+    query,
+    set_chart_state,
+    set_query_state,
+):
+    # Embed the submitted query and overlay it on the baseline chart
+    if chart is not None and query.value:
+        with mo.status.spinner(title="Embedding Query...", remove_on_exit=True) as _spinner:
+            query_emb = embedding.embed_documents([query.value])
+            set_query_state(query_emb)
+
+            _spinner.update("Preparing Query Coordinates")
+            time.sleep(1.0)
+
+            # Get appropriate coordinates for the query
+            query_coords = get_query_coordinates(
+                reference_embeddings=chart_reference,
+                query_embedding=get_query_state()
+            )
+
+            _spinner.update("Adding Query to Chart")
+            time.sleep(1.0)
+
+            # Add the query point to the chart
+            result = add_query_to_embedding_chart(
+                existing_chart=emb_plot,
+                query_coords=query_coords,
+                query_text=query.value,
+            )
+
+            chart_with_query = result
+
+            _spinner.update("Preparing Visualization")
+            time.sleep(1.0)
+
+            # Create the visualization
+            combined_viz = mo.ui.plotly(chart_with_query)
+            set_chart_state(combined_viz)
+
+            _spinner.update("Done")
+    else:
+        combined_viz = None
+    return
+
+
+@app.cell
+def _():
+    get_range_slider_state, set_range_slider_state = mo.state(None)
+    return get_range_slider_state, set_range_slider_state
+
+
+@app.cell
+def _(get_range_slider_state):
+    if get_range_slider_state() is not None:
+        document_range_stack = get_range_slider_state()
+    else:
+        document_range_stack = None
+    return (document_range_stack,)
+
+
+@app.cell
+def _():
+    get_chart_state, set_chart_state = mo.state(None)
+    return get_chart_state, set_chart_state
+
+
+@app.cell
+def _(get_chart_state, query):
+    if query.value is not None:
+        chart_visualization = get_chart_state()
+    else:
+        chart_visualization = None
+    return (chart_visualization,)
+
+
+@app.cell
+def c(document_range_stack):
+    chart_range_selection = mo.hstack([document_range_stack], justify="space-around", align="center", widths=[0.65])
+    return (chart_range_selection,)
+
+
+if __name__ == "__main__":
+    app.run()