SecureLLMSys committed on
Commit
c28f525
·
1 Parent(s): 383cea5
Files changed (2)
  1. app.py +70 -6
  2. app_no_config.py +1218 -0
app.py CHANGED
@@ -82,10 +82,34 @@ current_attr = None
82
  current_model_path = None
83
  current_explanation_level = None
84
  current_api_key = None
85
 
86
  def initialize_model_and_attr():
87
  """Initialize model and attribution with default configuration"""
88
- global current_llm, current_attr, current_model_path, current_explanation_level, current_api_key
89
 
90
  try:
91
  # Check if we need to reinitialize the model
@@ -95,7 +119,7 @@ def initialize_model_and_attr():
95
 
96
  # Check if we need to update attribution
97
  need_attr_update = (current_attr is None or
98
- current_explanation_level != DEFAULT_EXPLANATION_LEVEL or
99
  need_model_update)
100
 
101
  if need_model_update:
@@ -106,15 +130,19 @@ def initialize_model_and_attr():
106
  current_api_key = effective_api_key
107
 
108
  if need_attr_update:
109
- print(f"Initializing context traceback with explanation level: {DEFAULT_EXPLANATION_LEVEL}")
110
  current_attr = AttnTraceAttribution(
111
  current_llm,
112
- explanation_level=DEFAULT_EXPLANATION_LEVEL,
113
- K=3,
114
  q=0.4,
115
  B=30
116
  )
117
- current_explanation_level = DEFAULT_EXPLANATION_LEVEL
118
 
119
  return current_llm, current_attr, None
120
 
@@ -957,6 +985,36 @@ with gr.Blocks(theme=theme, css=custom_css) as demo:
957
  '**Color Legend for Context Traceback (by ranking):** <span style="background-color: #FF4444; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Red</span> = 1st (most important) | <span style="background-color: #FF8C42; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Orange</span> = 2nd | <span style="background-color: #FFD93D; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Golden</span> = 3rd | <span style="background-color: #FFF280; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Yellow</span> = 4th-5th | <span style="background-color: #FFF9C4; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Light</span> = 6th+'
958
  )
959
 
960
 
961
  # Top section: Wide Context box with tabs
962
  with gr.Row():
@@ -1209,6 +1267,12 @@ with gr.Blocks(theme=theme, css=custom_css) as demo:
1209
  outputs=[state, response_input_box, basic_response_box, basic_generate_error_box]
1210
  )
1211
 
1212
 
1213
  # gr.Markdown(
1214
  # "Please do not interact with elements while generation/attribution is in progress. This may cause errors. You can refresh the page if you run into issues because of this."
 
82
  current_model_path = None
83
  current_explanation_level = None
84
  current_api_key = None
85
+ current_top_k = 3 # Add top-k tracking
86
+
87
+ def update_configuration(explanation_level, top_k):
88
+ """Update the global configuration and reinitialize attribution if needed"""
89
+ global current_explanation_level, current_top_k, current_attr
90
+
91
+ # Convert top_k to int
92
+ top_k = int(top_k)
93
+
94
+ # Check if configuration has changed
95
+ config_changed = (current_explanation_level != explanation_level or
96
+ current_top_k != top_k)
97
+
98
+ if config_changed:
99
+ print(f"🔄 Updating configuration: explanation_level={explanation_level}, top_k={top_k}")
100
+ current_explanation_level = explanation_level
101
+ current_top_k = top_k
102
+
103
+ # Reset attribution to force reinitialization with new config
104
+ current_attr = None
105
+
106
+ return gr.update(value=f"✅ Configuration updated: {explanation_level} level, top-{top_k}")
107
+ else:
108
+ return gr.update(value="ℹ️ Configuration unchanged")
109
 
110
  def initialize_model_and_attr():
111
  """Initialize model and attribution with default configuration"""
112
+ global current_llm, current_attr, current_model_path, current_explanation_level, current_api_key, current_top_k
113
 
114
  try:
115
  # Check if we need to reinitialize the model
 
119
 
120
  # Check if we need to update attribution
121
  need_attr_update = (current_attr is None or
122
+ current_explanation_level != (current_explanation_level or DEFAULT_EXPLANATION_LEVEL) or
123
  need_model_update)
124
 
125
  if need_model_update:
 
130
  current_api_key = effective_api_key
131
 
132
  if need_attr_update:
133
+ # Use current configuration or defaults
134
+ explanation_level = current_explanation_level or DEFAULT_EXPLANATION_LEVEL
135
+ top_k = current_top_k or 3
136
+
137
+ print(f"Initializing context traceback with explanation level: {explanation_level}, top_k: {top_k}")
138
  current_attr = AttnTraceAttribution(
139
  current_llm,
140
+ explanation_level=explanation_level,
141
+ K=top_k,
142
  q=0.4,
143
  B=30
144
  )
145
+ current_explanation_level = explanation_level
146
 
147
  return current_llm, current_attr, None
148
 
 
985
  '**Color Legend for Context Traceback (by ranking):** <span style="background-color: #FF4444; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Red</span> = 1st (most important) | <span style="background-color: #FF8C42; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Orange</span> = 2nd | <span style="background-color: #FFD93D; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Golden</span> = 3rd | <span style="background-color: #FFF280; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Yellow</span> = 4th-5th | <span style="background-color: #FFF9C4; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Light</span> = 6th+'
986
  )
987
 
988
+ # Configuration bar
989
+ with gr.Row():
990
+ with gr.Column(scale=1):
991
+ explanation_level_dropdown = gr.Dropdown(
992
+ choices=["sentence", "paragraph", "text segment"],
993
+ value="sentence",
994
+ label="Explanation Level",
995
+ info="How to segment the context for traceback analysis"
996
+ )
997
+ with gr.Column(scale=1):
998
+ top_k_dropdown = gr.Dropdown(
999
+ choices=["3", "5", "10"],
1000
+ value="5",
1001
+ label="Top-K Value",
1002
+ info="Number of most important text segments to highlight"
1003
+ )
1004
+ with gr.Column(scale=1):
1005
+ apply_config_button = gr.Button(
1006
+ "Apply Configuration",
1007
+ variant="secondary",
1008
+ size="sm"
1009
+ )
1010
+ with gr.Column(scale=2):
1011
+ config_status_text = gr.Textbox(
1012
+ label="Configuration Status",
1013
+ value="Ready to apply configuration",
1014
+ interactive=False,
1015
+ lines=1
1016
+ )
1017
+
1018
 
1019
  # Top section: Wide Context box with tabs
1020
  with gr.Row():
 
1267
  outputs=[state, response_input_box, basic_response_box, basic_generate_error_box]
1268
  )
1269
 
1270
+ # Configuration update handler
1271
+ apply_config_button.click(
1272
+ fn=update_configuration,
1273
+ inputs=[explanation_level_dropdown, top_k_dropdown],
1274
+ outputs=[config_status_text]
1275
+ )
1276
 
1277
  # gr.Markdown(
1278
  # "Please do not interact with elements while generation/attribution is in progress. This may cause errors. You can refresh the page if you run into issues because of this."
app_no_config.py ADDED
@@ -0,0 +1,1218 @@
1
+ # Acknowledgement: This demo code is adapted from the original Hugging Face Space "ContextCite"
2
+ # (https://huggingface.co/spaces/contextcite/context-cite).
3
+ import os
4
+ from enum import Enum
5
+ from dataclasses import dataclass
6
+ from typing import Dict, List, Any, Optional
7
+ import gradio as gr
8
+ import numpy as np
9
+ import spaces
10
+ import nltk
11
+ import base64
12
+ import traceback
13
+ from src.utils import split_into_sentences as split_into_sentences_utils
14
+ # --- AttnTrace imports (from app_full.py) ---
15
+ from src.models import create_model
16
+ from src.attribution import AttnTraceAttribution
17
+ from src.prompts import wrap_prompt
18
+ from gradio_highlightedtextbox import HighlightedTextbox
19
+ from examples import run_example_1, run_example_2, run_example_3, run_example_4, run_example_5, run_example_6
20
+ from functools import partial
21
+ os.makedirs("/home/user/nltk_data", exist_ok=True)
22
+ # Download punkt to a known path
23
+ nltk.download("punkt", download_dir="/home/user/nltk_data")
24
+ # Tell nltk where to find it
25
+ nltk.data.path.append("/home/user/nltk_data")
26
+ from nltk.tokenize import sent_tokenize
27
+
28
+ # Load original app constants
29
+ APP_TITLE = '<div class="app-title"><span class="brand">AttnTrace: </span><span class="subtitle">Attention-based Context Traceback for Long-Context LLMs</span></div>'
30
+ APP_DESCRIPTION = """AttnTrace traces a model's generated statements back to specific parts of the context using attention-based traceback. Try it out with Meta-Llama-3.1-8B-Instruct here! See the [[paper](https://arxiv.org/abs/2506.04202)] and [[code](https://github.com/Wang-Yanting/TracLLM-Kit)] for more!
31
+ Maintained by the AttnTrace team."""
32
+ # NEW_TEXT = """Long-context large language models (LLMs), such as Gemini-2.5-Pro and Claude-Sonnet-4, are increasingly used to empower advanced AI systems, including retrieval-augmented generation (RAG) pipelines and autonomous agents. In these systems, an LLM receives an instruction along with a contextβ€”often consisting of texts retrieved from a knowledge database or memoryβ€”and generates a response that is contextually grounded by following the instruction. Recent studies have designed solutions to trace back to a subset of texts in the context that contributes most to the response generated by the LLM. These solutions have numerous real-world applications, including performing post-attack forensic analysis and improving the interpretability and trustworthiness of LLM outputs. While significant efforts have been made, state-of-the-art solutions such as TracLLM often lead to a high computation cost, e.g., it takes TracLLM hundreds of seconds to perform traceback for a single response-context pair. In this work, we propose {\name}, a new context traceback method based on the attention weights produced by an LLM for a prompt. To effectively utilize attention weights, we introduce two techniques designed to enhance the effectiveness of {\name}, and we provide theoretical insights for our design choice. %Moreover, we perform both theoretical analysis and empirical evaluation to demonstrate their effectiveness.
33
+ # We also perform a systematic evaluation for {\name}. The results demonstrate that {\name} is more accurate and efficient than existing state-of-the-art context traceback methods. We also show {\name} can improve state-of-the-art methods in detecting prompt injection under long contexts through the attribution-before-detection paradigm. As a real-world application, we demonstrate that {\name} can effectively pinpoint injected instructions in a paper designed to manipulate LLM-generated reviews.
34
+ # The code and data will be open-sourced. """
35
+ # EDIT_TEXT = "Feel free to edit!"
36
+ GENERATE_CONTEXT_TOO_LONG_TEXT = (
37
+ '<em style="color: red;">Context is too long for the current model.</em>'
38
+ )
39
+ ATTRIBUTE_CONTEXT_TOO_LONG_TEXT = '<em style="color: red;">Context is too long for the current traceback method.</em>'
40
+ CONTEXT_LINES = 20
41
+ CONTEXT_MAX_LINES = 40
42
+ SELECTION_DEFAULT_TEXT = "Click on a sentence in the response to traceback!"
43
+ SELECTION_DEFAULT_VALUE = [(SELECTION_DEFAULT_TEXT, None)]
44
+ SOURCES_INFO = 'These are the texts that contribute most to the response.'
45
+ # SOURCES_IN_CONTEXT_INFO = (
46
+ # "This shows the important sentences highlighted within their surrounding context from the text above. Colors indicate ranking: Red (1st), Orange (2nd), Golden (3rd), Yellow (4th-5th), Light (6th+)."
47
+ # )
48
+
49
+ MODEL_PATHS = [
50
+ "meta-llama/Meta-Llama-3.1-8B-Instruct",
51
+ ]
52
+ MAX_TOKENS = {
53
+ "meta-llama/Meta-Llama-3.1-8B-Instruct": 131072,
54
+ }
55
+ DEFAULT_MODEL_PATH = MODEL_PATHS[0]
56
+ EXPLANATION_LEVELS = ["sentence", "paragraph", "text segment"]
57
+ DEFAULT_EXPLANATION_LEVEL = "sentence"
58
+
59
+ class WorkflowState(Enum):
60
+ WAITING_TO_GENERATE = 0
61
+ WAITING_TO_SELECT = 1
62
+ READY_TO_ATTRIBUTE = 2
63
+
64
+ @dataclass
65
+ class State:
66
+ workflow_state: WorkflowState
67
+ context: str
68
+ query: str
69
+ response: str
70
+ start_index: int
71
+ end_index: int
72
+ scores: np.ndarray
73
+ answer: str
74
+ highlighted_context: str
75
+ full_response: str
76
+ explained_response_part: str
77
+ last_query_used: str = ""
78
+
79
+ # --- Dynamic Model and Attribution Management ---
80
+ current_llm = None
81
+ current_attr = None
82
+ current_model_path = None
83
+ current_explanation_level = None
84
+ current_api_key = None
85
+
86
+ def initialize_model_and_attr():
87
+ """Initialize model and attribution with default configuration"""
88
+ global current_llm, current_attr, current_model_path, current_explanation_level, current_api_key
89
+
90
+ try:
91
+ # Check if we need to reinitialize the model
92
+ need_model_update = (current_llm is None or
93
+ current_model_path != DEFAULT_MODEL_PATH or
94
+ current_api_key != os.getenv("HF_TOKEN"))
95
+
96
+ # Check if we need to update attribution
97
+ need_attr_update = (current_attr is None or
98
+ current_explanation_level != DEFAULT_EXPLANATION_LEVEL or
99
+ need_model_update)
100
+
101
+ if need_model_update:
102
+ print(f"Initializing model: {DEFAULT_MODEL_PATH}")
103
+ effective_api_key = os.getenv("HF_TOKEN")
104
+ current_llm = create_model(model_path=DEFAULT_MODEL_PATH, api_key=effective_api_key, device="cuda")
105
+ current_model_path = DEFAULT_MODEL_PATH
106
+ current_api_key = effective_api_key
107
+
108
+ if need_attr_update:
109
+ print(f"Initializing context traceback with explanation level: {DEFAULT_EXPLANATION_LEVEL}")
110
+ current_attr = AttnTraceAttribution(
111
+ current_llm,
112
+ explanation_level=DEFAULT_EXPLANATION_LEVEL,
113
+ K=3,
114
+ q=0.4,
115
+ B=30
116
+ )
117
+ current_explanation_level = DEFAULT_EXPLANATION_LEVEL
118
+
119
+ return current_llm, current_attr, None
120
+
121
+ except Exception as e:
122
+ error_msg = f"Error initializing model/traceback: {str(e)}"
123
+ print(error_msg)
124
+ traceback.print_exc()
125
+ return None, None, error_msg
126
+
127
+ # Initialize the model and attribution once, eagerly, at import time
128
+ llm, attr, error_msg = initialize_model_and_attr() # Note: this triggers model/CUDA initialization on the main thread at startup
129
+
130
+ # Images replaced with CSS textures and gradients - no longer needed
131
+
132
+ def clear_state():
133
+ return State(
134
+ workflow_state=WorkflowState.WAITING_TO_GENERATE,
135
+ context="",
136
+ query="",
137
+ response="",
138
+ start_index=0,
139
+ end_index=0,
140
+ scores=np.array([]),
141
+ answer="",
142
+ highlighted_context="",
143
+ full_response="",
144
+ explained_response_part="",
145
+ last_query_used=""
146
+ )
147
+
148
+ def load_an_example(example_loader_func, state: State):
149
+ context, query = example_loader_func()
150
+ # Update both UI and state
151
+ state.context = context
152
+ state.query = query
153
+ state.workflow_state = WorkflowState.WAITING_TO_GENERATE
154
+ # Clear previous results
155
+ state.response = ""
156
+ state.answer = ""
157
+ state.full_response = ""
158
+ state.explained_response_part = ""
159
+ print(f"Loaded example - Context: {len(context)} chars, Query: {query[:50]}...")
160
+ return (
161
+ context, # basic_context_box
162
+ query, # basic_query_box
163
+ state,
164
+ "", # response_input_box - clear it
165
+ gr.update(value=[("Click the 'Generate/Use Response' button above to see response text here for traceback analysis.", None)]), # basic_response_box - keep visible
166
+ gr.update(selected=0) # basic_context_tabs - switch to first tab
167
+ )
168
+
169
+
170
+ def get_max_tokens(model_path: str):
171
+ return MAX_TOKENS.get(model_path, 2048) # Default fallback
172
+
173
+
174
+ def get_scroll_js_code(elem_id):
175
+ return f"""
176
+ function scrollToElement() {{
177
+ const element = document.getElementById("{elem_id}");
178
+ element.scrollIntoView({{ behavior: "smooth", block: "nearest" }});
179
+ }}
180
+ """
181
+
182
+ def basic_update(context: str, query: str, state: State):
183
+ state.context = context
184
+ state.query = query
185
+ state.workflow_state = WorkflowState.WAITING_TO_GENERATE
186
+ return (
187
+ gr.update(value=[("Click the 'Generate/Use Response' button above to see response text here for traceback analysis.", None)]), # basic_response_box - keep visible
188
+ gr.update(selected=0), # basic_context_tabs - switch to first tab
189
+ state,
190
+ )
191
+
192
+
193
+
194
+
195
+
196
+ @spaces.GPU
197
+ def generate_model_response(state: State):
198
+ # Validate inputs first with debug info
199
+ print(f"Validation - Context length: {len(state.context) if state.context else 0}")
200
+ print(f"Validation - Query: {state.query[:50] if state.query else 'empty'}...")
201
+
202
+ if not state.context or not state.context.strip():
203
+ print("❌ Validation failed: No context")
204
+ return state, gr.update(value=[("❌ Please enter context before generating response! If you just changed configuration, try reloading an example.", None)], visible=True)
205
+
206
+ if not state.query or not state.query.strip():
207
+ print("❌ Validation failed: No query")
208
+ return state, gr.update(value=[("❌ Please enter a query before generating response! If you just changed configuration, try reloading an example.", None)], visible=True)
209
+
210
+ # Initialize model and attribution with default configuration
211
+ print(f"🔧 Generating response with explanation_level: {DEFAULT_EXPLANATION_LEVEL}")
212
+ #llm, attr, error_msg = initialize_model_and_attr()
213
+
214
+ if llm is None or attr is None:
215
+ error_text = error_msg if error_msg else "Model initialization failed!"
216
+ return state, gr.update(value=[(f"❌ {error_text}", None)], visible=True)
217
+
218
+ prompt = wrap_prompt(state.query, [state.context])
219
+ print(f"Generated prompt for {DEFAULT_MODEL_PATH}: {prompt[:200]}...") # Debug log
220
+
221
+ # Check context length
222
+ if len(prompt.split()) > get_max_tokens(DEFAULT_MODEL_PATH) - 512:
223
+ return state, gr.update(value=[(GENERATE_CONTEXT_TOO_LONG_TEXT, None)], visible=True)
224
+
225
+ answer = llm.query(prompt)
226
+ print(f"Model response: {answer}") # Debug log
227
+
228
+ state.response = answer
229
+ state.answer = answer
230
+ state.full_response = answer
231
+ state.workflow_state = WorkflowState.WAITING_TO_SELECT
232
+ return state, gr.update(visible=False)
233
+
234
+ def split_into_sentences(text: str):
235
+ def rule_based_split(text):
236
+ sentences = []
237
+ start = 0
238
+ for i, char in enumerate(text):
239
+ if char in ".?。":
240
+ if i + 1 == len(text) or text[i + 1] == " ":
241
+ sentences.append(text[start:i + 1].strip())
242
+ start = i + 1
243
+ if start < len(text):
244
+ sentences.append(text[start:].strip())
245
+ return sentences
246
+
247
+ lines = text.splitlines()
248
+ sentences = []
249
+ for line in lines:
250
+ #sentences.extend(sent_tokenize(line))
251
+ sentences.extend(rule_based_split(line))
252
+ separators = []
253
+ cur_start = 0
254
+ for sentence in sentences:
255
+ cur_end = text.find(sentence, cur_start)
256
+ separators.append(text[cur_start:cur_end])
257
+ cur_start = cur_end + len(sentence)
258
+ return sentences, separators
259
+
260
+
261
+ def basic_highlight_response(
262
+ response: str, selected_index: int, num_sources: int = -1
263
+ ):
264
+ sentences, separators = split_into_sentences(response)
265
+ ht = []
266
+ if num_sources == -1:
267
+ citations_text = "Traceback!"
268
+ elif num_sources == 0:
269
+ citations_text = "No important text!"
270
+ else:
271
+ citations_text = f"[{','.join(str(i) for i in range(1, num_sources + 1))}]"
272
+ for i, (sentence, separator) in enumerate(zip(sentences, separators)):
273
+ label = citations_text if i == selected_index else "Traceback"
274
+ # Hack to ignore punctuation
275
+ if len(sentence) >= 4:
276
+ ht.append((separator + sentence, label))
277
+ else:
278
+ ht.append((separator + sentence, None))
279
+ color_map = {"Click to cite!": "blue", citations_text: "yellow"}
280
+ return gr.HighlightedText(value=ht, color_map=color_map)
281
+
282
+ def basic_highlight_response_with_visibility(
283
+ response: str, selected_index: int, num_sources: int = -1, visible: bool = True
284
+ ):
285
+ """Version of basic_highlight_response that also sets visibility"""
286
+ sentences, separators = split_into_sentences(response)
287
+ ht = []
288
+ if num_sources == -1:
289
+ citations_text = "Traceback!"
290
+ elif num_sources == 0:
291
+ citations_text = "No important text!"
292
+ else:
293
+ citations_text = f"[{','.join(str(i) for i in range(1, num_sources + 1))}]"
294
+ for i, (sentence, separator) in enumerate(zip(sentences, separators)):
295
+ label = citations_text if i == selected_index else "Traceback"
296
+ # Hack to ignore punctuation
297
+ if len(sentence) >= 4:
298
+ ht.append((separator + sentence, label))
299
+ else:
300
+ ht.append((separator + sentence, None))
301
+ color_map = {"Click to cite!": "blue", citations_text: "yellow"}
302
+ return gr.update(value=ht, color_map=color_map, visible=visible)
303
+
304
+
305
+
306
+ def basic_update_highlighted_response(evt: gr.SelectData, state: State):
307
+ response_update = basic_highlight_response(state.response, evt.index)
308
+ return response_update, state
309
+
310
+ def unified_response_handler(response_text: str, state: State):
311
+ """Handle both LLM generation and manual input based on whether text is provided"""
312
+
313
+ # Check if instruction has changed from what was used to generate current response
314
+ instruction_changed = hasattr(state, 'last_query_used') and state.last_query_used != state.query
315
+
316
+ # If response_text is empty, whitespace, or instruction changed, generate from LLM
317
+ if not response_text or not response_text.strip() or instruction_changed:
318
+ if instruction_changed:
319
+ print("📝 Instruction changed, generating new response from LLM...")
320
+ else:
321
+ print("🤖 Generating response from LLM...")
322
+
323
+ # Validate inputs first
324
+ if not state.context or not state.context.strip():
325
+ return (
326
+ state,
327
+ response_text, # Keep current text box content
328
+ gr.update(visible=False), # Keep response box hidden
329
+ gr.update(value=[("❌ Please enter context before generating response!", None)], visible=True)
330
+ )
331
+
332
+ if not state.query or not state.query.strip():
333
+ return (
334
+ state,
335
+ response_text, # Keep current text box content
336
+ gr.update(visible=False), # Keep response box hidden
337
+ gr.update(value=[("❌ Please enter a query before generating response!", None)], visible=True)
338
+ )
339
+
340
+ # Initialize model and generate response
341
+ #llm, attr, error_msg = initialize_model_and_attr()
342
+
343
+ if llm is None:
344
+ error_text = error_msg if error_msg else "Model initialization failed!"
345
+ return (
346
+ state,
347
+ response_text, # Keep current text box content
348
+ gr.update(visible=False), # Keep response box hidden
349
+ gr.update(value=[(f"❌ {error_text}", None)], visible=True)
350
+ )
351
+
352
+ prompt = wrap_prompt(state.query, [state.context])
353
+
354
+ # Check context length
355
+ if len(prompt.split()) > get_max_tokens(DEFAULT_MODEL_PATH) - 512:
356
+ return (
357
+ state,
358
+ response_text, # Keep current text box content
359
+ gr.update(visible=False), # Keep response box hidden
360
+ gr.update(value=[(GENERATE_CONTEXT_TOO_LONG_TEXT, None)], visible=True)
361
+ )
362
+
363
+ # Generate response
364
+ answer = llm.query(prompt)
365
+ print(f"Generated response: {answer[:100]}...")
366
+
367
+ # Update state and UI
368
+ state.response = answer
369
+ state.answer = answer
370
+ state.full_response = answer
371
+ state.last_query_used = state.query # Track which query was used for this response
372
+ state.workflow_state = WorkflowState.WAITING_TO_SELECT
373
+
374
+ # Create highlighted response and show it
375
+ response_update = basic_highlight_response_with_visibility(state.response, -1, visible=True)
376
+
377
+ return (
378
+ state,
379
+ answer, # Put generated response in text box
380
+ response_update, # Update clickable response content
381
+ gr.update(visible=False) # Hide error box
382
+ )
383
+
384
+ else:
385
+ # Use provided text as manual response
386
+ print("✏️ Using manual response...")
387
+ manual_text = response_text.strip()
388
+
389
+ # Update state with manual response
390
+ state.response = manual_text
391
+ state.answer = manual_text
392
+ state.full_response = manual_text
393
+ state.last_query_used = state.query # Track current query for this response
394
+ state.workflow_state = WorkflowState.WAITING_TO_SELECT
395
+
396
+ # Create highlighted response for selection
397
+ response_update = basic_highlight_response_with_visibility(state.response, -1, visible=True)
398
+
399
+ return (
400
+ state,
401
+ manual_text, # Keep text in text box
402
+ response_update, # Update clickable response content
403
+ gr.update(visible=False) # Hide error box
404
+ )
405
+
406
+ def get_color_by_rank(rank, total_items):
407
+ """Get color based purely on rank position for better visual distinction"""
408
+ if total_items == 0:
409
+ return "#F0F0F0", "rgba(240, 240, 240, 0.8)"
410
+
411
+ # Pure ranking-based color assignment for clear visual hierarchy
412
+ if rank == 1: # Highest importance - Strong Red
413
+ bg_color = "#FF4444" # Bright red
414
+ rgba_color = "rgba(255, 68, 68, 0.9)"
415
+ elif rank == 2: # Second highest - Orange
416
+ bg_color = "#FF8C42" # Bright orange
417
+ rgba_color = "rgba(255, 140, 66, 0.8)"
418
+ elif rank == 3: # Third highest - Golden Yellow
419
+ bg_color = "#FFD93D" # Golden yellow
420
+ rgba_color = "rgba(255, 217, 61, 0.8)"
421
+ elif rank <= 5: # 4th-5th - Light Yellow
422
+ bg_color = "#FFF280" # Standard yellow
423
+ rgba_color = "rgba(255, 242, 128, 0.7)"
424
+ else: # Lower importance - Very Light Yellow
425
+ bg_color = "#FFF9C4" # Very light yellow
426
+ rgba_color = "rgba(255, 249, 196, 0.6)"
427
+
428
+ return bg_color, rgba_color
429
+
430
+ @spaces.GPU
431
+ def basic_get_scores_and_sources_full_response(state: State):
432
+ """Traceback the entire response instead of a selected segment"""
433
+
434
+
435
+ # Use the entire response as the explained part
436
+ state.explained_response_part = state.full_response
437
+
438
+ # Attribution using default configuration
439
+ #_, attr, error_msg = initialize_model_and_attr()
440
+
441
+ if attr is None:
442
+ error_text = error_msg if error_msg else "Traceback initialization failed!"
443
+ return (
444
+ gr.update(value=[("", None)], visible=False),
445
+ gr.update(selected=0),
446
+ gr.update(visible=False),
447
+ gr.update(value=""),
448
+ gr.update(value=[(f"❌ {error_text}", None)], visible=True),
449
+ state,
450
+ )
451
+ try:
452
+ # Validate attribution inputs
453
+ if not state.context or not state.context.strip():
454
+ return (
455
+ gr.update(value=[("", None)], visible=False),
456
+ gr.update(selected=0),
457
+ gr.update(visible=False),
458
+ gr.update(value=""),
459
+ gr.update(value=[("❌ No context available for traceback!", None)], visible=True),
460
+ state,
461
+ )
462
+
463
+ if not state.query or not state.query.strip():
464
+ return (
465
+ gr.update(value=[("", None)], visible=False),
466
+ gr.update(selected=0),
467
+ gr.update(visible=False),
468
+ gr.update(value=""),
469
+ gr.update(value=[("❌ No query available for traceback!", None)], visible=True),
470
+ state,
471
+ )
472
+
473
+ if not state.full_response or not state.full_response.strip():
474
+ return (
475
+ gr.update(value=[("", None)], visible=False),
476
+ gr.update(selected=0),
477
+ gr.update(visible=False),
478
+ gr.update(value=""),
479
+ gr.update(value=[("❌ No response available for traceback!", None)], visible=True),
480
+ state,
481
+ )
482
+
483
+ print(f"start full response traceback with explanation_level: {DEFAULT_EXPLANATION_LEVEL}")
484
+ print(f"context length: {len(state.context)}, query: {state.query[:100]}...")
485
+ print(f"full response: {state.full_response[:100]}...")
486
+ print(f"tracing entire response (length: {len(state.full_response)} chars)")
487
+
488
+ texts, important_ids, importance_scores, _, _ = attr.attribute(
489
+ state.query, [state.context], state.full_response, state.full_response
490
+ )
491
+ print("end full response traceback")
492
+ print(f"explanation_level: {DEFAULT_EXPLANATION_LEVEL}")
493
+ print(f"texts count: {len(texts)} (how context was segmented)")
494
+ if len(texts) > 0:
495
+ print(f"sample text segments: {[text[:50] + '...' if len(text) > 50 else text for text in texts[:3]]}")
496
+ print(f"important_ids: {important_ids}")
497
+ print("importance_scores: ", importance_scores)
498
+
499
+ if not importance_scores:
500
+ return (
501
+ gr.update(value=[("", None)], visible=False),
502
+ gr.update(selected=0),
503
+ gr.update(visible=False),
504
+ gr.update(value=""),
505
+ gr.update(value=[("❌ No traceback scores generated for full response!", None)], visible=True),
506
+ state,
507
+ )
508
+
509
+ state.scores = np.array(importance_scores)
510
+
511
+ # Highlighted sources with ranking-based colors
512
+ highlighted_text = []
513
+ sorted_indices = np.argsort(state.scores)[::-1]
514
+ total_sources = len(important_ids)
515
+
516
+ for rank, i in enumerate(sorted_indices):
517
+ source_text = texts[important_ids[i]]
518
+ _ = get_color_by_rank(rank + 1, total_sources)
519
+
520
+ highlighted_text.append(
521
+ (
522
+ source_text,
523
+ f"rank_{rank+1}",
524
+ )
525
+ )
526
+
527
+ # In-context highlights with ranking-based colors - show ALL text
528
+ in_context_highlighted_text = []
529
+ ranks = {important_ids[i]: rank for rank, i in enumerate(sorted_indices)}
530
+
531
+ for i in range(len(texts)):
532
+ source_text = texts[i]
533
+
534
+ # Skip or don't highlight segments that are only newlines or whitespace
535
+ if source_text.strip() == "":
536
+ # For whitespace-only segments, add them without highlighting
537
+ in_context_highlighted_text.append((source_text, None))
538
+ elif i in important_ids:
539
+ # Only highlight if the segment has actual content (not just newlines)
540
+ if source_text.strip(): # Has non-whitespace content
541
+ rank = ranks[i] + 1
542
+
543
+ # Split the segment to separate leading/trailing newlines from content
544
+ # This prevents newlines from being highlighted
545
+ leading_whitespace = ""
546
+ trailing_whitespace = ""
547
+ content = source_text
548
+
549
+ # Extract leading newlines/whitespace
550
+ while content and content[0] in ['\n', '\r', '\t', ' ']:
551
+ leading_whitespace += content[0]
552
+ content = content[1:]
553
+
554
+ # Extract trailing newlines/whitespace
555
+ while content and content[-1] in ['\n', '\r', '\t', ' ']:
556
+ trailing_whitespace = content[-1] + trailing_whitespace
557
+ content = content[:-1]
558
+
559
+ # Add the parts separately: whitespace unhighlighted, content highlighted
560
+ if leading_whitespace:
561
+ in_context_highlighted_text.append((leading_whitespace, None))
562
+ if content:
563
+ in_context_highlighted_text.append((content, f"rank_{rank}"))
564
+ if trailing_whitespace:
565
+ in_context_highlighted_text.append((trailing_whitespace, None))
566
+ else:
567
+ # Even if marked as important, don't highlight whitespace-only segments
568
+ in_context_highlighted_text.append((source_text, None))
569
+ else:
570
+ # Add unhighlighted text for non-important segments
571
+ in_context_highlighted_text.append((source_text, None))
572
+
573
+ # Enhanced color map with ranking-based colors
574
+ color_map = {}
575
+ for rank in range(len(important_ids)):
576
+ _, rgba_color = get_color_by_rank(rank + 1, total_sources)
577
+ color_map[f"rank_{rank+1}"] = rgba_color
578
+ dummy_update = gr.update(
579
+ value=f"AttnTrace_{state.response}_{state.start_index}_{state.end_index}"
580
+ )
581
+ attribute_error_update = gr.update(visible=False)
582
+
583
+ # Combine sources and highlighted context into a single display
584
+ # Sources at the top
585
+ combined_display = []
586
+
587
+ # Add sources header (no highlighting for UI elements)
588
+ combined_display.append(("═══ FULL RESPONSE TRACEBACK RESULTS ═══\n", None))
589
+ combined_display.append(("These are the text segments that contribute most to the entire response:\n\n", None))
590
+
591
+ # Add sources using available data
592
+ for rank, i in enumerate(sorted_indices):
593
+ if i < len(important_ids):
594
+ source_text = texts[important_ids[i]]
595
+
596
+ # Strip leading/trailing whitespace from source text to avoid highlighting newlines
597
+ clean_source_text = source_text.strip()
598
+
599
+ if clean_source_text: # Only add if there's actual content
600
+ # Add the source text with highlighting, then add spacing without highlighting
601
+ combined_display.append((clean_source_text, f"rank_{rank+1}"))
602
+ combined_display.append(("\n\n", None))
603
+
604
+ # Add separator (no highlighting for UI elements)
605
+ combined_display.append(("\n" + "═"*50 + "\n", None))
606
+ combined_display.append(("FULL CONTEXT WITH HIGHLIGHTS\n", None))
607
+ combined_display.append(("Scroll down to see the complete context with important segments highlighted:\n\n", None))
608
+
609
+ # Add highlighted context using in_context_highlighted_text
610
+ combined_display.extend(in_context_highlighted_text)
611
+
612
+ # Use only the ranking colors (no highlighting for UI elements)
613
+ enhanced_color_map = color_map.copy()
614
+
615
+ combined_sources_update = HighlightedTextbox(
616
+ value=combined_display, color_map=enhanced_color_map, visible=True
617
+ )
618
+
619
+ # Switch to the highlighted context tab and show results
620
+ basic_context_tabs_update = gr.update(selected=1)
621
+ basic_sources_in_context_tab_update = gr.update(visible=True)
622
+
623
+ return (
624
+ combined_sources_update,
625
+ basic_context_tabs_update,
626
+ basic_sources_in_context_tab_update,
627
+ dummy_update,
628
+ attribute_error_update,
629
+ state,
630
+ )
631
+ except Exception as e:
632
+ traceback.print_exc()
633
+ return (
634
+ gr.update(value=[("", None)], visible=False),
635
+ gr.update(selected=0),
636
+ gr.update(visible=False),
637
+ gr.update(value=""),
638
+ gr.update(value=[(f"❌ Error: {str(e)}", None)], visible=True),
639
+ state,
640
+ )
641
+
642
+ def basic_get_scores_and_sources(
643
+ evt: gr.SelectData,
644
+ highlighted_response: List[Dict[str, str]],
645
+ state: State,
646
+ ):
647
+
648
+ # Get the selected sentence
649
+ print("highlighted_response: ", highlighted_response[evt.index])
650
+ selected_text = highlighted_response[evt.index]['token']
651
+ state.explained_response_part = selected_text
652
+
653
+ # Attribution using default configuration
654
+ #_, attr, error_msg = initialize_model_and_attr()
655
+
656
+ if attr is None:
657
+ error_text = error_msg if error_msg else "Traceback initialization failed!"
658
+ return (
659
+ gr.update(value=[("", None)], visible=False),
660
+ gr.update(selected=0),
661
+ gr.update(visible=False),
662
+ gr.update(value=""),
663
+ gr.update(value=[(f"❌ {error_text}", None)], visible=True),
664
+ state,
665
+ )
666
+ try:
667
+ # Validate attribution inputs
668
+ if not state.context or not state.context.strip():
669
+ return (
670
+ gr.update(value=[("", None)], visible=False),
671
+ gr.update(selected=0),
672
+ gr.update(visible=False),
673
+ gr.update(value=""),
674
+ gr.update(value=[("❌ No context available for traceback!", None)], visible=True),
675
+ state,
676
+ )
677
+
678
+ if not state.query or not state.query.strip():
679
+ return (
680
+ gr.update(value=[("", None)], visible=False),
681
+ gr.update(selected=0),
682
+ gr.update(visible=False),
683
+ gr.update(value=""),
684
+ gr.update(value=[("❌ No query available for traceback!", None)], visible=True),
685
+ state,
686
+ )
687
+
688
+ if not state.full_response or not state.full_response.strip():
689
+ return (
690
+ gr.update(value=[("", None)], visible=False),
691
+ gr.update(selected=0),
692
+ gr.update(visible=False),
693
+ gr.update(value=""),
694
+ gr.update(value=[("❌ No response available for traceback!", None)], visible=True),
695
+ state,
696
+ )
697
+
698
+ print(f"start traceback with explanation_level: {DEFAULT_EXPLANATION_LEVEL}")
699
+ print(f"context length: {len(state.context)}, query: {state.query[:100]}...")
700
+ print(f"response: {state.full_response[:100]}...")
701
+ print(f"selected part: {state.explained_response_part[:100]}...")
702
+
703
+ texts, important_ids, importance_scores, _, _ = attr.attribute(
704
+ state.query, [state.context], state.full_response, state.explained_response_part
705
+ )
706
+ print("end traceback")
707
+ print(f"explanation_level: {DEFAULT_EXPLANATION_LEVEL}")
708
+ print(f"texts count: {len(texts)} (how context was segmented)")
709
+ if len(texts) > 0:
710
+ print(f"sample text segments: {[text[:50] + '...' if len(text) > 50 else text for text in texts[:3]]}")
711
+ print(f"important_ids: {important_ids}")
712
+ print("importance_scores: ", importance_scores)
713
+
714
+ if not importance_scores:
715
+ return (
716
+ gr.update(value=[("", None)], visible=False),
717
+ gr.update(selected=0),
718
+ gr.update(visible=False),
719
+ gr.update(value=""),
720
+ gr.update(value=[("❌ No traceback scores generated! Try a different text segment.", None)], visible=True),
721
+ state,
722
+ )
723
+
724
+ state.scores = np.array(importance_scores)
725
+
726
+ # Highlighted sources with ranking-based colors
727
+ highlighted_text = []
728
+ sorted_indices = np.argsort(state.scores)[::-1]
729
+ total_sources = len(important_ids)
730
+
731
+ for rank, i in enumerate(sorted_indices):
732
+ source_text = texts[important_ids[i]]
733
+ _ = get_color_by_rank(rank + 1, total_sources)
734
+
735
+ highlighted_text.append(
736
+ (
737
+ source_text,
738
+ f"rank_{rank+1}",
739
+ )
740
+ )
741
+
742
+ # In-context highlights with ranking-based colors - show ALL text
743
+ in_context_highlighted_text = []
744
+ ranks = {important_ids[i]: rank for rank, i in enumerate(sorted_indices)}
745
+
746
+ for i in range(len(texts)):
747
+ source_text = texts[i]
748
+
749
+ # Skip or don't highlight segments that are only newlines or whitespace
750
+ if source_text.strip() == "":
751
+ # For whitespace-only segments, add them without highlighting
752
+ in_context_highlighted_text.append((source_text, None))
753
+ elif i in important_ids:
754
+ # Only highlight if the segment has actual content (not just newlines)
755
+ if source_text.strip(): # Has non-whitespace content
756
+ rank = ranks[i] + 1
757
+
758
+ # Split the segment to separate leading/trailing newlines from content
759
+ # This prevents newlines from being highlighted
760
+ leading_whitespace = ""
761
+ trailing_whitespace = ""
762
+ content = source_text
763
+
764
+ # Extract leading newlines/whitespace
765
+ while content and content[0] in ['\n', '\r', '\t', ' ']:
766
+ leading_whitespace += content[0]
767
+ content = content[1:]
768
+
769
+ # Extract trailing newlines/whitespace
770
+ while content and content[-1] in ['\n', '\r', '\t', ' ']:
771
+ trailing_whitespace = content[-1] + trailing_whitespace
772
+ content = content[:-1]
773
+
774
+ # Add the parts separately: whitespace unhighlighted, content highlighted
775
+ if leading_whitespace:
776
+ in_context_highlighted_text.append((leading_whitespace, None))
777
+ if content:
778
+ in_context_highlighted_text.append((content, f"rank_{rank}"))
779
+ if trailing_whitespace:
780
+ in_context_highlighted_text.append((trailing_whitespace, None))
781
+ else:
782
+ # Even if marked as important, don't highlight whitespace-only segments
783
+ in_context_highlighted_text.append((source_text, None))
784
+ else:
785
+ # Add unhighlighted text for non-important segments
786
+ in_context_highlighted_text.append((source_text, None))
787
+
788
+ # Enhanced color map with ranking-based colors
789
+ color_map = {}
790
+ for rank in range(len(important_ids)):
791
+ _, rgba_color = get_color_by_rank(rank + 1, total_sources)
792
+ color_map[f"rank_{rank+1}"] = rgba_color
793
+ dummy_update = gr.update(
794
+ value=f"AttnTrace_{state.response}_{state.start_index}_{state.end_index}"
795
+ )
796
+ attribute_error_update = gr.update(visible=False)
797
+
798
+ # Combine sources and highlighted context into a single display
799
+ # Sources at the top
800
+ combined_display = []
801
+
802
+ # Add sources header (no highlighting for UI elements)
803
+ combined_display.append(("═══ TRACEBACK RESULTS ═══\n", None))
804
+ combined_display.append(("These are the text segments that contribute most to the response:\n\n", None))
805
+
806
+ # Add sources using available data
807
+ for rank, i in enumerate(sorted_indices):
808
+ if i < len(important_ids):
809
+ source_text = texts[important_ids[i]]
810
+
811
+ # Strip leading/trailing whitespace from source text to avoid highlighting newlines
812
+ clean_source_text = source_text.strip()
813
+
814
+ if clean_source_text: # Only add if there's actual content
815
+ # Add the source text with highlighting, then add spacing without highlighting
816
+ combined_display.append((clean_source_text, f"rank_{rank+1}"))
817
+ combined_display.append(("\n\n", None))
818
+
819
+ # Add separator (no highlighting for UI elements)
820
+ combined_display.append(("\n" + "═"*50 + "\n", None))
821
+ combined_display.append(("FULL CONTEXT WITH HIGHLIGHTS\n", None))
822
+ combined_display.append(("Scroll down to see the complete context with important segments highlighted:\n\n", None))
823
+
824
+ # Add highlighted context using in_context_highlighted_text
825
+ combined_display.extend(in_context_highlighted_text)
826
+
827
+ # Use only the ranking colors (no highlighting for UI elements)
828
+ enhanced_color_map = color_map.copy()
829
+
830
+ combined_sources_update = HighlightedTextbox(
831
+ value=combined_display, color_map=enhanced_color_map, visible=True
832
+ )
833
+
834
+ # Switch to the highlighted context tab and show results
835
+ basic_context_tabs_update = gr.update(selected=1)
836
+ basic_sources_in_context_tab_update = gr.update(visible=True)
837
+
838
+ return (
839
+ combined_sources_update,
840
+ basic_context_tabs_update,
841
+ basic_sources_in_context_tab_update,
842
+ dummy_update,
843
+ attribute_error_update,
844
+ state,
845
+ )
846
+ except Exception as e:
847
+ traceback.print_exc()
848
+ return (
849
+ gr.update(value=[("", None)], visible=False),
850
+ gr.update(selected=0),
851
+ gr.update(visible=False),
852
+ gr.update(value=""),
853
+ gr.update(value=[(f"❌ Error: {str(e)}", None)], visible=True),
854
+ state,
855
+ )
856
+
857
+ def load_custom_css():
858
+ """Load CSS from external file"""
859
+ try:
860
+ with open("assets/app_styles.css", "r") as f:
861
+ css_content = f.read()
862
+ return css_content
863
+ except FileNotFoundError:
864
+ print("Warning: CSS file not found, using minimal CSS")
865
+ return ""
866
+ except Exception as e:
867
+ print(f"Error loading CSS: {e}")
868
+ return ""
869
+
870
+ # Load CSS from external file
871
+ custom_css = load_custom_css()
872
+ theme = gr.themes.Citrus(
873
+ text_size="lg",
874
+ spacing_size="md",
875
+ )
876
+ with gr.Blocks(theme=theme, css=custom_css) as demo:
877
+ gr.Markdown(f"# {APP_TITLE}")
878
+ gr.Markdown(APP_DESCRIPTION, elem_classes="app-description")
879
+ # gr.Markdown(NEW_TEXT, elem_classes="app-description-2")
880
+
881
+ gr.Markdown("""
882
+ <div style="font-size: 18px;">
883
+ AttnTrace is an efficient context traceback method for long contexts (e.g., full papers). It is over 15× faster than the state-of-the-art context traceback method TracLLM. Compared to previous attention-based approaches, AttnTrace is more accurate, reliable, and memory-efficient.
884
+ """, elem_classes="feature-highlights")
885
+ # Feature highlights
886
+ gr.Markdown("""
887
+ <div style="font-size: 18px;">
888
+ AttnTrace can be used in many real-world applications, such as tracing back to:
889
+
890
+ - 📄 prompt injection instructions that manipulate LLM-generated paper reviews.
891
+ - 💻 malicious comment & code hiding in the codebase that misleads the AI coding assistant.
892
+ - 🤖 malicious instructions that mislead the action of the LLM agent.
893
+ - 🖋 source texts in the context from an AI summary.
894
+ - 🔍 evidence that supports the LLM-generated answer for a question.
895
+ - ❌ misinformation (corrupted knowledge) that manipulates LLM output for a question.
896
+ - And a lot more...
897
+
898
+ </div>
899
+ """, elem_classes="feature-highlights")
900
+
901
+ # Example buttons with topic-relevant images - moved here for better positioning
902
+ gr.Markdown("### 🚀 Try These Examples!", elem_classes="example-title")
903
+ with gr.Row(elem_classes=["example-button-container"]):
904
+ with gr.Column(scale=1):
905
+ example_1_btn = gr.Button(
906
+ "📄 Prompt Injection Attacks in AI Paper Review",
907
+ elem_classes=["example-button", "example-paper"],
908
+ elem_id="example_1_button",
909
+ scale=None,
910
+ size="sm"
911
+ )
912
+ with gr.Column(scale=1):
913
+ example_2_btn = gr.Button(
914
+ "💻 Malicious Comments & Code in Codebase",
915
+ elem_classes=["example-button", "example-movie"],
916
+ elem_id="example_2_button"
917
+ )
918
+ with gr.Column(scale=1):
919
+ example_3_btn = gr.Button(
920
+ "🤖 Malicious Instructions Misleading the LLM Agent",
921
+ elem_classes=["example-button", "example-code"],
922
+ elem_id="example_3_button"
923
+ )
924
+
925
+ with gr.Row(elem_classes=["example-button-container"]):
926
+ with gr.Column(scale=1):
927
+ example_4_btn = gr.Button(
928
+ "🖋 Source Texts for an AI Summary",
929
+ elem_classes=["example-button", "example-paper-alt"],
930
+ elem_id="example_4_button"
931
+ )
932
+ with gr.Column(scale=1):
933
+ example_5_btn = gr.Button(
934
+ "🔍 Evidence that Supports Question Answering",
935
+ elem_classes=["example-button", "example-movie-alt"],
936
+ elem_id="example_5_button"
937
+ )
938
+ with gr.Column(scale=1):
939
+ example_6_btn = gr.Button(
940
+ "❌ Misinformation (Corrupted Knowledge) in Question Answering",
941
+ elem_classes=["example-button", "example-code-alt"],
942
+ elem_id="example_6_button"
943
+ )
944
+
945
+ state = gr.State(
946
+ value=clear_state()
947
+ )
948
+
949
+ basic_tab = gr.Tab("Demo")
950
+ with basic_tab:
951
+ # gr.Markdown("## Demo")
952
+ gr.Markdown(
953
+ "Enter your context and instruction below to try out AttnTrace! You can also click on the example buttons above to load pre-configured examples."
954
+ )
955
+
956
+ gr.Markdown(
957
+ '**Color Legend for Context Traceback (by ranking):** <span style="background-color: #FF4444; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Red</span> = 1st (most important) | <span style="background-color: #FF8C42; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Orange</span> = 2nd | <span style="background-color: #FFD93D; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Golden</span> = 3rd | <span style="background-color: #FFF280; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Yellow</span> = 4th-5th | <span style="background-color: #FFF9C4; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Light</span> = 6th+'
958
+ )
959
+
960
+
961
+ # Top section: Wide Context box with tabs
962
+ with gr.Row():
963
+ with gr.Column(scale=1):
964
+ with gr.Tabs() as basic_context_tabs:
965
+ with gr.TabItem("Context", id=0):
966
+ basic_context_box = gr.Textbox(
967
+ placeholder="Enter context...",
968
+ show_label=False,
969
+ value="",
970
+ lines=6,
971
+ max_lines=6,
972
+ elem_id="basic_context_box",
973
+ autoscroll=False,
974
+ )
975
+ with gr.TabItem("Context with highlighted traceback results", id=1, visible=True) as basic_sources_in_context_tab:
976
+ basic_sources_in_context_box = HighlightedTextbox(
977
+ value=[("Click on a sentence in the response below to see highlighted traceback results here.", None)],
978
+ show_legend_label=False,
979
+ show_label=False,
980
+ show_legend=False,
981
+ interactive=False,
982
+ elem_id="basic_sources_in_context_box",
983
+ )
984
+
985
+ # Error messages
986
+ basic_generate_error_box = HighlightedTextbox(
987
+ show_legend_label=False,
988
+ show_label=False,
989
+ show_legend=False,
990
+ visible=False,
991
+ interactive=False,
992
+ container=False,
993
+ )
994
+
995
+ # Bottom section: Left (instruction + button + response), Right (response selection)
996
+ with gr.Row(equal_height=True):
997
+ # Left: Instruction + Button + Response
998
+ with gr.Column(scale=1):
999
+ basic_query_box = gr.Textbox(
1000
+ label="Instruction",
1001
+ placeholder="Enter an instruction...",
1002
+ value="",
1003
+ lines=3,
1004
+ max_lines=3,
1005
+ )
1006
+
1007
+ unified_response_button = gr.Button(
1008
+ "Generate/Use Response",
1009
+ variant="primary",
1010
+ size="lg"
1011
+ )
1012
+
1013
+ response_input_box = gr.Textbox(
1014
+ label="Response (Editable)",
1015
+ placeholder="Response will appear here after generation, or type your own response for traceback...",
1016
+ lines=8,
1017
+ max_lines=8,
1018
+ info="Leave empty and click button to generate from LLM, or type your own response to use for traceback"
1019
+ )
1020
+
1021
+ # Right: Response for attribution selection
1022
+ with gr.Column(scale=1):
1023
+ basic_response_box = gr.HighlightedText(
1024
+ label="Click to select text for traceback!",
1025
+ value=[("Click the 'Generate/Use Response' button on the left to see response text here for traceback analysis.", None)],
1026
+ interactive=False,
1027
+ combine_adjacent=False,
1028
+ show_label=True,
1029
+ show_legend=False,
1030
+ elem_id="basic_response_box",
1031
+ visible=True,
1032
+ )
1033
+
1034
+ # Button for full response traceback
1035
+ full_response_traceback_button = gr.Button(
1036
+ "🔍 Traceback Entire Response",
1037
+ variant="secondary",
1038
+ size="sm"
1039
+ )
1040
+
1041
+ # Hidden error box and dummy elements
1042
+ basic_attribute_error_box = HighlightedTextbox(
1043
+ show_legend_label=False,
1044
+ show_label=False,
1045
+ show_legend=False,
1046
+ visible=False,
1047
+ interactive=False,
1048
+ container=False,
1049
+ )
1050
+ dummy_basic_sources_box = gr.Textbox(
1051
+ visible=False, interactive=False, container=False
1052
+ )
1053
+
1054
+
1055
+ # Only a single (AttnTrace) method and model in this simplified version
1056
+
1057
+ def basic_clear_state():
1058
+ state = clear_state()
1059
+ return (
1060
+ "", # basic_context_box
1061
+ "", # basic_query_box
1062
+ "", # response_input_box
1063
+ gr.update(value=[("Click the 'Generate/Use Response' button above to see response text here for traceback analysis.", None)]), # basic_response_box - keep visible
1064
+ gr.update(selected=0), # basic_context_tabs - switch to first tab
1065
+ state,
1066
+ )
1067
+
1068
+ # Defining behavior of various interactions for the basic tab
1069
+ basic_tab.select(
1070
+ fn=basic_clear_state,
1071
+ inputs=[],
1072
+ outputs=[
1073
+ basic_context_box,
1074
+ basic_query_box,
1075
+ response_input_box,
1076
+ basic_response_box,
1077
+ basic_context_tabs,
1078
+ state,
1079
+ ],
1080
+ )
1081
+ for component in [basic_context_box, basic_query_box]:
1082
+ component.change(
1083
+ basic_update,
1084
+ [basic_context_box, basic_query_box, state],
1085
+ [
1086
+ basic_response_box,
1087
+ basic_context_tabs,
1088
+ state,
1089
+ ],
1090
+ )
1091
+ # Example button event handlers - now update both UI and state
1092
+ outputs_for_examples = [
1093
+ basic_context_box,
1094
+ basic_query_box,
1095
+ state,
1096
+ response_input_box,
1097
+ basic_response_box,
1098
+ basic_context_tabs,
1099
+ ]
1100
+ example_1_btn.click(
1101
+ fn=partial(load_an_example, run_example_1),
1102
+ inputs=[state],
1103
+ outputs=outputs_for_examples
1104
+ )
1105
+ example_2_btn.click(
1106
+ fn=partial(load_an_example, run_example_2),
1107
+ inputs=[state],
1108
+ outputs=outputs_for_examples
1109
+ )
1110
+ example_3_btn.click(
1111
+ fn=partial(load_an_example, run_example_3),
1112
+ inputs=[state],
1113
+ outputs=outputs_for_examples
1114
+ )
1115
+ example_4_btn.click(
1116
+ fn=partial(load_an_example, run_example_4),
1117
+ inputs=[state],
1118
+ outputs=outputs_for_examples
1119
+ )
1120
+ example_5_btn.click(
1121
+ fn=partial(load_an_example, run_example_5),
1122
+ inputs=[state],
1123
+ outputs=outputs_for_examples
1124
+ )
1125
+ example_6_btn.click(
1126
+ fn=partial(load_an_example, run_example_6),
1127
+ inputs=[state],
1128
+ outputs=outputs_for_examples
1129
+ )
1130
+
1131
+ unified_response_button.click(
1132
+ fn=lambda: None,
1133
+ inputs=[],
1134
+ outputs=[],
1135
+ js=get_scroll_js_code("basic_response_box"),
1136
+ )
1137
+ basic_response_box.change(
1138
+ fn=lambda: None,
1139
+ inputs=[],
1140
+ outputs=[],
1141
+ js=get_scroll_js_code("basic_sources_in_context_box"),
1142
+ )
1143
+ # Add immediate tab switch on response selection
1144
+ def immediate_tab_switch():
1145
+ return (
1146
+ gr.update(value=[("🔄 Processing traceback... Please wait...", None)]), # Show progress message
1147
+ gr.update(selected=1), # Switch to annotation tab immediately
1148
+ )
1149
+
1150
+ basic_response_box.select(
1151
+ fn=immediate_tab_switch,
1152
+ inputs=[],
1153
+ outputs=[basic_sources_in_context_box, basic_context_tabs],
1154
+ queue=False, # Execute immediately without queue
1155
+ )
1156
+
1157
+ basic_response_box.select(
1158
+ fn=basic_get_scores_and_sources,
1159
+ inputs=[basic_response_box, state],
1160
+ outputs=[
1161
+ basic_sources_in_context_box,
1162
+ basic_context_tabs,
1163
+ basic_sources_in_context_tab,
1164
+ dummy_basic_sources_box,
1165
+ basic_attribute_error_box,
1166
+ state,
1167
+ ],
1168
+ show_progress="full",
1169
+ )
1170
+ basic_response_box.select(
1171
+ fn=basic_update_highlighted_response,
1172
+ inputs=[state],
1173
+ outputs=[basic_response_box, state],
1174
+ )
1175
+
1176
+ # Full response traceback button
1177
+ full_response_traceback_button.click(
1178
+ fn=immediate_tab_switch,
1179
+ inputs=[],
1180
+ outputs=[basic_sources_in_context_box, basic_context_tabs],
1181
+ queue=False, # Execute immediately without queue
1182
+ )
1183
+
1184
+ full_response_traceback_button.click(
1185
+ fn=basic_get_scores_and_sources_full_response,
1186
+ inputs=[state],
1187
+ outputs=[
1188
+ basic_sources_in_context_box,
1189
+ basic_context_tabs,
1190
+ basic_sources_in_context_tab,
1191
+ dummy_basic_sources_box,
1192
+ basic_attribute_error_box,
1193
+ state,
1194
+ ],
1195
+ show_progress="full",
1196
+ )
1197
+
1198
+ dummy_basic_sources_box.change(
1199
+ fn=lambda: None,
1200
+ inputs=[],
1201
+ outputs=[],
1202
+ js=get_scroll_js_code("basic_sources_in_context_box"),
1203
+ )
1204
+
1205
+ # Unified response handler
1206
+ unified_response_button.click(
1207
+ fn=unified_response_handler,
1208
+ inputs=[response_input_box, state],
1209
+ outputs=[state, response_input_box, basic_response_box, basic_generate_error_box]
1210
+ )
1211
+
1212
+
1213
+ # gr.Markdown(
1214
+ # "Please do not interact with elements while generation/attribution is in progress. This may cause errors. You can refresh the page if you run into issues because of this."
1215
+ # )
1216
+
1217
+ demo.launch(show_api=False, share=True)
1218
+