Spaces:
Running
Running
Move the share link under the input field on the sparse reps tab
Browse files
app.py
CHANGED
|
@@ -9,56 +9,9 @@ import os
|
|
| 9 |
|
| 10 |
# Custom CSS for the app; passed to gr.Blocks(css=css) where the UI is built
|
| 11 |
css = """
|
| 12 |
-
/*
|
| 13 |
-
.
|
| 14 |
-
|
| 15 |
-
top: 20px !important;
|
| 16 |
-
right: 20px !important;
|
| 17 |
-
z-index: 1000 !important;
|
| 18 |
-
background: #4CAF50 !important;
|
| 19 |
-
color: white !important;
|
| 20 |
-
border-radius: 8px !important;
|
| 21 |
-
padding: 8px 16px !important;
|
| 22 |
-
font-weight: bold !important;
|
| 23 |
-
box-shadow: 0 2px 10px rgba(0,0,0,0.2) !important;
|
| 24 |
-
}
|
| 25 |
-
|
| 26 |
-
.share-button:hover {
|
| 27 |
-
background: #45a049 !important;
|
| 28 |
-
transform: translateY(-1px) !important;
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
/* Alternative positions - uncomment the one you want instead */
|
| 32 |
-
|
| 33 |
-
/* Top-left corner */
|
| 34 |
-
/*
|
| 35 |
-
.share-button {
|
| 36 |
-
position: fixed !important;
|
| 37 |
-
top: 20px !important;
|
| 38 |
-
left: 20px !important;
|
| 39 |
-
z-index: 1000 !important;
|
| 40 |
-
}
|
| 41 |
-
*/
|
| 42 |
-
|
| 43 |
-
/* Bottom-right corner (mobile-friendly) */
|
| 44 |
-
/*
|
| 45 |
-
.share-button {
|
| 46 |
-
position: fixed !important;
|
| 47 |
-
bottom: 20px !important;
|
| 48 |
-
right: 20px !important;
|
| 49 |
-
z-index: 1000 !important;
|
| 50 |
-
}
|
| 51 |
-
*/
|
| 52 |
-
|
| 53 |
-
/* Bottom-center */
|
| 54 |
-
/*
|
| 55 |
-
.share-button {
|
| 56 |
-
position: fixed !important;
|
| 57 |
-
bottom: 20px !important;
|
| 58 |
-
left: 50% !important;
|
| 59 |
-
transform: translateX(-50%) !important;
|
| 60 |
-
z-index: 1000 !important;
|
| 61 |
-
}
|
| 62 |
*/
|
| 63 |
"""
|
| 64 |
|
|
@@ -130,11 +83,11 @@ def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
|
|
| 130 |
tokenizer.unk_token_id
|
| 131 |
]:
|
| 132 |
meaningful_token_ids.append(token_id)
|
| 133 |
-
|
| 134 |
if meaningful_token_ids:
|
| 135 |
# Apply mask to the current row in the batch
|
| 136 |
bow_masks[i, list(set(meaningful_token_ids))] = 1
|
| 137 |
-
|
| 138 |
return bow_masks
|
| 139 |
|
| 140 |
|
|
@@ -185,7 +138,7 @@ def get_splade_cocondenser_representation(text):
|
|
| 185 |
|
| 186 |
info_output = f"" # Line 1
|
| 187 |
info_output += f"Total non-zero terms in vector: {len(indices)}\n" # Line 2 (and onwards for sparsity)
|
| 188 |
-
|
| 189 |
|
| 190 |
return formatted_output, info_output
|
| 191 |
|
|
@@ -243,7 +196,7 @@ def get_splade_lexical_representation(text):
|
|
| 243 |
|
| 244 |
info_output = f"" # Line 1
|
| 245 |
info_output += f"Total non-zero terms in vector: {len(indices)}\n" # Line 2 (and onwards for sparsity)
|
| 246 |
-
|
| 247 |
|
| 248 |
return formatted_output, info_output
|
| 249 |
|
|
@@ -260,11 +213,11 @@ def get_splade_doc_representation(text):
|
|
| 260 |
binary_bow_vector = create_lexical_bow_mask(
|
| 261 |
inputs['input_ids'], vocab_size, tokenizer_splade_doc
|
| 262 |
).squeeze() # Squeeze back for single output
|
| 263 |
-
|
| 264 |
indices = torch.nonzero(binary_bow_vector).squeeze().cpu().tolist()
|
| 265 |
if not isinstance(indices, list):
|
| 266 |
indices = [indices] if indices else []
|
| 267 |
-
|
| 268 |
values = [1.0] * len(indices) # All values are 1 for binary representation
|
| 269 |
token_weights = dict(zip(indices, values))
|
| 270 |
|
|
@@ -338,12 +291,12 @@ def get_splade_lexical_vector(text):
|
|
| 338 |
torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
|
| 339 |
dim=1
|
| 340 |
)[0].squeeze()
|
| 341 |
-
|
| 342 |
vocab_size = tokenizer_splade_lexical.vocab_size
|
| 343 |
bow_mask = create_lexical_bow_mask(
|
| 344 |
inputs['input_ids'], vocab_size, tokenizer_splade_lexical
|
| 345 |
).squeeze()
|
| 346 |
-
|
| 347 |
splade_vector = splade_vector * bow_mask
|
| 348 |
return splade_vector
|
| 349 |
return None
|
|
@@ -377,7 +330,7 @@ def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
|
|
| 377 |
values = [1.0] * len(indices)
|
| 378 |
else:
|
| 379 |
values = splade_vector[indices].cpu().tolist()
|
| 380 |
-
|
| 381 |
token_weights = dict(zip(indices, values))
|
| 382 |
|
| 383 |
meaningful_tokens = {}
|
|
@@ -408,8 +361,8 @@ def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
|
|
| 408 |
|
| 409 |
# This is the line that will now always be split into two
|
| 410 |
info_output = f"Total non-zero terms: {len(indices)}\n" # Line 1
|
| 411 |
-
|
| 412 |
-
|
| 413 |
return formatted_output, info_output
|
| 414 |
|
| 415 |
|
|
@@ -451,7 +404,7 @@ def calculate_dot_product_and_representations_independent(query_model_choice, do
|
|
| 451 |
# Combine output into a single string for the Markdown component
|
| 452 |
full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
|
| 453 |
full_output += "---\n\n"
|
| 454 |
-
|
| 455 |
# Query Representation
|
| 456 |
full_output += f"Query Representation ({query_model_name_display}):\n\n"
|
| 457 |
full_output += query_main_rep_str + "\n\n" + query_info_str # Added an extra newline for better spacing
|
|
@@ -460,7 +413,7 @@ def calculate_dot_product_and_representations_independent(query_model_choice, do
|
|
| 460 |
# Document Representation
|
| 461 |
full_output += f"Document Representation ({doc_model_name_display}):\n\n"
|
| 462 |
full_output += doc_main_rep_str + "\n\n" + doc_info_str # Added an extra newline for better spacing
|
| 463 |
-
|
| 464 |
return full_output
|
| 465 |
|
| 466 |
|
|
@@ -488,7 +441,13 @@ with gr.Blocks(title="SPLADE Demos", css=css) as demo:
|
|
| 488 |
label="Enter your query or document text here:",
|
| 489 |
placeholder="e.g., Why is Padua the nicest city in Italy?"
|
| 490 |
)
|
| 491 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
info_output_display = gr.Markdown(
|
| 493 |
value="",
|
| 494 |
label="Vector Information",
|
|
@@ -496,7 +455,7 @@ with gr.Blocks(title="SPLADE Demos", css=css) as demo:
|
|
| 496 |
)
|
| 497 |
with gr.Column(scale=2): # Right column for the main representation output
|
| 498 |
main_representation_output = gr.Markdown()
|
| 499 |
-
|
| 500 |
# Connect the interface elements
|
| 501 |
model_radio.change(
|
| 502 |
fn=predict_representation_explorer,
|
|
@@ -508,15 +467,16 @@ with gr.Blocks(title="SPLADE Demos", css=css) as demo:
|
|
| 508 |
inputs=[model_radio, input_text],
|
| 509 |
outputs=[main_representation_output, info_output_display]
|
| 510 |
)
|
| 511 |
-
|
| 512 |
# Initial call to populate on load (optional, but good for demo)
|
| 513 |
demo.load(
|
| 514 |
fn=lambda: predict_representation_explorer(model_radio.value, input_text.value),
|
| 515 |
outputs=[main_representation_output, info_output_display]
|
| 516 |
)
|
| 517 |
|
| 518 |
-
with gr.TabItem("Compare Encoders"): #
|
| 519 |
-
|
|
|
|
| 520 |
# Define the common model choices for cleaner code
|
| 521 |
model_choices = [
|
| 522 |
"MLM encoder (SPLADE-cocondenser-distil)",
|
|
@@ -549,7 +509,7 @@ with gr.Blocks(title="SPLADE Demos", css=css) as demo:
|
|
| 549 |
)
|
| 550 |
],
|
| 551 |
outputs=gr.Markdown(),
|
| 552 |
-
allow_flagging="never"
|
| 553 |
)
|
| 554 |
|
| 555 |
demo.launch()
|
|
|
|
| 9 |
|
| 10 |
# Custom CSS for the app; passed to gr.Blocks(css=css) where the UI is built
|
| 11 |
css = """
|
| 12 |
+
/* The global fixed positioning for the share button is no longer needed
|
| 13 |
+
because we'll place gr.ShareButton directly in the UI.
|
| 14 |
+
You can remove or comment out any previous .share-button CSS if it was there.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
*/
|
| 16 |
"""
|
| 17 |
|
|
|
|
| 83 |
tokenizer.unk_token_id
|
| 84 |
]:
|
| 85 |
meaningful_token_ids.append(token_id)
|
| 86 |
+
|
| 87 |
if meaningful_token_ids:
|
| 88 |
# Apply mask to the current row in the batch
|
| 89 |
bow_masks[i, list(set(meaningful_token_ids))] = 1
|
| 90 |
+
|
| 91 |
return bow_masks
|
| 92 |
|
| 93 |
|
|
|
|
| 138 |
|
| 139 |
info_output = f"" # Line 1
|
| 140 |
info_output += f"Total non-zero terms in vector: {len(indices)}\n" # Line 2 (and onwards for sparsity)
|
| 141 |
+
|
| 142 |
|
| 143 |
return formatted_output, info_output
|
| 144 |
|
|
|
|
| 196 |
|
| 197 |
info_output = f"" # Line 1
|
| 198 |
info_output += f"Total non-zero terms in vector: {len(indices)}\n" # Line 2 (and onwards for sparsity)
|
| 199 |
+
|
| 200 |
|
| 201 |
return formatted_output, info_output
|
| 202 |
|
|
|
|
| 213 |
binary_bow_vector = create_lexical_bow_mask(
|
| 214 |
inputs['input_ids'], vocab_size, tokenizer_splade_doc
|
| 215 |
).squeeze() # Squeeze back for single output
|
| 216 |
+
|
| 217 |
indices = torch.nonzero(binary_bow_vector).squeeze().cpu().tolist()
|
| 218 |
if not isinstance(indices, list):
|
| 219 |
indices = [indices] if indices else []
|
| 220 |
+
|
| 221 |
values = [1.0] * len(indices) # All values are 1 for binary representation
|
| 222 |
token_weights = dict(zip(indices, values))
|
| 223 |
|
|
|
|
| 291 |
torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
|
| 292 |
dim=1
|
| 293 |
)[0].squeeze()
|
| 294 |
+
|
| 295 |
vocab_size = tokenizer_splade_lexical.vocab_size
|
| 296 |
bow_mask = create_lexical_bow_mask(
|
| 297 |
inputs['input_ids'], vocab_size, tokenizer_splade_lexical
|
| 298 |
).squeeze()
|
| 299 |
+
|
| 300 |
splade_vector = splade_vector * bow_mask
|
| 301 |
return splade_vector
|
| 302 |
return None
|
|
|
|
| 330 |
values = [1.0] * len(indices)
|
| 331 |
else:
|
| 332 |
values = splade_vector[indices].cpu().tolist()
|
| 333 |
+
|
| 334 |
token_weights = dict(zip(indices, values))
|
| 335 |
|
| 336 |
meaningful_tokens = {}
|
|
|
|
| 361 |
|
| 362 |
# This is the line that will now always be split into two
|
| 363 |
info_output = f"Total non-zero terms: {len(indices)}\n" # Line 1
|
| 364 |
+
|
| 365 |
+
|
| 366 |
return formatted_output, info_output
|
| 367 |
|
| 368 |
|
|
|
|
| 404 |
# Combine output into a single string for the Markdown component
|
| 405 |
full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
|
| 406 |
full_output += "---\n\n"
|
| 407 |
+
|
| 408 |
# Query Representation
|
| 409 |
full_output += f"Query Representation ({query_model_name_display}):\n\n"
|
| 410 |
full_output += query_main_rep_str + "\n\n" + query_info_str # Added an extra newline for better spacing
|
|
|
|
| 413 |
# Document Representation
|
| 414 |
full_output += f"Document Representation ({doc_model_name_display}):\n\n"
|
| 415 |
full_output += doc_main_rep_str + "\n\n" + doc_info_str # Added an extra newline for better spacing
|
| 416 |
+
|
| 417 |
return full_output
|
| 418 |
|
| 419 |
|
|
|
|
| 441 |
label="Enter your query or document text here:",
|
| 442 |
placeholder="e.g., Why is Padua the nicest city in Italy?"
|
| 443 |
)
|
| 444 |
+
# --- NEW: Place the gr.ShareButton here ---
|
| 445 |
+
gr.ShareButton(
|
| 446 |
+
value="Share My Sparse Representation",
|
| 447 |
+
components=[input_text, model_radio], # You can specify components to share
|
| 448 |
+
visible=True # Make sure it's visible
|
| 449 |
+
)
|
| 450 |
+
# --- End New ---
|
| 451 |
info_output_display = gr.Markdown(
|
| 452 |
value="",
|
| 453 |
label="Vector Information",
|
|
|
|
| 455 |
)
|
| 456 |
with gr.Column(scale=2): # Right column for the main representation output
|
| 457 |
main_representation_output = gr.Markdown()
|
| 458 |
+
|
| 459 |
# Connect the interface elements
|
| 460 |
model_radio.change(
|
| 461 |
fn=predict_representation_explorer,
|
|
|
|
| 467 |
inputs=[model_radio, input_text],
|
| 468 |
outputs=[main_representation_output, info_output_display]
|
| 469 |
)
|
| 470 |
+
|
| 471 |
# Initial call to populate on load (optional, but good for demo)
|
| 472 |
demo.load(
|
| 473 |
fn=lambda: predict_representation_explorer(model_radio.value, input_text.value),
|
| 474 |
outputs=[main_representation_output, info_output_display]
|
| 475 |
)
|
| 476 |
|
| 477 |
+
with gr.TabItem("Compare Encoders"): # Reverted to original gr.Interface setup
|
| 478 |
+
gr.Markdown("### Calculate Dot Product Similarity Between Encoded Query and Document")
|
| 479 |
+
|
| 480 |
# Define the common model choices for cleaner code
|
| 481 |
model_choices = [
|
| 482 |
"MLM encoder (SPLADE-cocondenser-distil)",
|
|
|
|
| 509 |
)
|
| 510 |
],
|
| 511 |
outputs=gr.Markdown(),
|
| 512 |
+
allow_flagging="never" # Keep this to keep the share button at the bottom of THIS interface
|
| 513 |
)
|
| 514 |
|
| 515 |
demo.launch()
|