lucalp committed
Commit 545bc06 · 1 Parent(s): b074257

adding patch counts and cleaning up

Files changed (1):
  1. app.py +253 -154
app.py CHANGED
@@ -3,15 +3,27 @@ import gradio as gr
3
  import torch
4
  import itertools # For color cycling
5
  import tiktoken # For GPT-4 tokenizer
6
- from transformers import AutoTokenizer, AutoModel # For Llama3 tokenizer
7
 
8
  # Bytelatent imports (assuming they are in the python path)
9
- from bytelatent.data.file_util import get_fs
10
- from bytelatent.generate_patcher import patcher_nocache
11
- from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
12
- from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
13
- from bytelatent.args import TrainArgs
14
- from download_blt_weights import main as ensure_present
15
 
16
  # --- Global Setup ---
17
 
@@ -27,113 +39,117 @@ LLAMA3_MODEL_NAME = "meta-llama/Meta-Llama-3-8B" # Or choose another variant lik
27
 
28
  def create_bytelatent_highlight_data(tokenizer, patch_lengths_tensor, tokens_tensor, colors):
29
  """Generates data for gr.HighlightedText based on bytelatent patches."""
30
- # (Keep the function from the previous version - no changes needed)
 
31
  if patch_lengths_tensor is None or tokens_tensor is None or patch_lengths_tensor.numel() == 0:
32
  return None
33
  patch_lengths = patch_lengths_tensor.tolist()
34
  all_tokens = tokens_tensor.tolist()
35
  highlighted_data = []
36
  current_token_index = 0
37
- color_cycler = itertools.cycle(colors)
 
38
  for i, length in enumerate(patch_lengths):
39
  if length <= 0: continue
40
  patch_token_ids = all_tokens[current_token_index : current_token_index + length]
41
  if not patch_token_ids: continue
42
  try: patch_text = tokenizer.decode(patch_token_ids)
43
  except Exception as decode_err:
44
- print(f"Warning: Bytelatent patch decoding failed: {decode_err}")
45
- patch_text = f"[Decode Error: {len(patch_token_ids)} tokens]"
46
  patch_label = f"BL Patch {i+1}"
47
  highlighted_data.append((patch_text, patch_label))
 
48
  current_token_index += length
 
 
49
  if current_token_index != len(all_tokens):
50
  print(f"Warning: Bytelatent token mismatch. Consumed {current_token_index}, total {len(all_tokens)}")
51
  remaining_tokens = all_tokens[current_token_index:]
52
  if remaining_tokens:
53
- try: remaining_text = tokenizer.decode(remaining_tokens)
54
- except Exception: remaining_text = f"[Decode Error: {len(remaining_tokens)} remaining tokens]"
55
- highlighted_data.append((remaining_text, "BL Remainder"))
56
- return highlighted_data
 
 
57
 
58
 
59
  def create_tiktoken_highlight_data(prompt, colors):
60
  """Generates data for gr.HighlightedText based on tiktoken (gpt-4) tokens."""
61
- # (Keep the function from the previous version - no changes needed)
62
  try:
63
  enc = tiktoken.get_encoding("cl100k_base")
64
  tiktoken_ids = enc.encode(prompt)
65
  highlighted_data = []
66
- color_cycler = itertools.cycle(colors)
67
  for i, token_id in enumerate(tiktoken_ids):
68
  try: token_text = enc.decode([token_id])
69
  except UnicodeDecodeError:
70
- try:
71
- token_bytes = enc.decode_single_token_bytes(token_id)
72
- token_text = f"[Bytes: {token_bytes.hex()}]"
73
- except Exception: token_text = "[Decode Error]"
74
  except Exception as e:
75
- print(f"Unexpected tiktoken decode error: {e}")
76
- token_text = "[Decode Error]"
77
  token_label = f"GPT4 Tk {i+1}"
78
  highlighted_data.append((token_text, token_label))
79
- print(f"Tiktoken processing complete. Found {len(tiktoken_ids)} tokens.")
80
- return highlighted_data
 
81
  except ImportError:
82
- print("Error: tiktoken library not found. Please install it: pip install tiktoken")
83
- return [("tiktoken library not installed.", "Error")]
84
  except Exception as tiktoken_err:
85
  print(f"Error during tiktoken processing: {tiktoken_err}")
86
- return [(f"Error processing with tiktoken: {str(tiktoken_err)}", "Error")]
87
 
88
 
89
  def create_llama3_highlight_data(prompt, colors, model_name=LLAMA3_MODEL_NAME):
90
  """Generates data for gr.HighlightedText based on Llama 3 tokenizer."""
91
  try:
92
  # Load Llama 3 tokenizer from Hugging Face Hub
93
- # This might download the tokenizer files on the first run
94
- # May require `huggingface-cli login` if model is private or gated
95
  print(f"Loading Llama 3 tokenizer: {model_name}")
96
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
97
  print("Llama 3 tokenizer loaded.")
98
 
99
  # Encode the prompt
100
  llama_token_ids = tokenizer.encode(prompt)
101
 
102
  highlighted_data = []
103
- color_cycler = itertools.cycle(colors)
104
 
105
  for i, token_id in enumerate(llama_token_ids):
106
  try:
107
- # Decode individual token. Llama/SentencePiece tokenizers usually handle this well.
108
  token_text = tokenizer.decode([token_id])
109
- # Special case: Handle potential leading space added by sentencepiece during decode
110
- # if token_text.startswith(' '): # Check if this improves visualization
111
- # token_text = token_text[1:] # Remove leading space visual artifact? Test this.
112
  except Exception as e:
113
- print(f"Unexpected Llama 3 decode error for token {token_id}: {e}")
114
- token_text = "[Decode Error]"
115
 
116
  token_label = f"Llama3 Tk {i+1}" # Clearer label prefix
117
  highlighted_data.append((token_text, token_label))
118
 
119
- print(f"Llama 3 processing complete. Found {len(llama_token_ids)} tokens.")
120
- return highlighted_data
 
121
 
122
  except ImportError:
123
- print("Error: transformers or sentencepiece library not found. Please install them: pip install transformers sentencepiece")
124
- return [("transformers/sentencepiece library not installed.", "Error")]
125
  except OSError as e:
126
  # Handle errors like model not found, network issues, authentication needed
127
  print(f"Error loading Llama 3 tokenizer '{model_name}': {e}")
 
128
  if "authentication" in str(e).lower():
129
- return [(f"Authentication required for Llama 3 tokenizer '{model_name}'. Use `huggingface-cli login`.", "Error")]
130
- else:
131
- return [(f"Could not load Llama 3 tokenizer '{model_name}'. Check model name and network. Error: {e}", "Error")]
132
  except Exception as llama_err:
133
  print(f"Error during Llama 3 processing: {llama_err}")
134
  import traceback
135
  traceback.print_exc() # Print full traceback for debugging
136
- return [(f"Error processing with Llama 3: {str(llama_err)}", "Error")]
137
 
138
 
139
  # --- Main Processing Function ---
@@ -141,7 +157,7 @@ def create_llama3_highlight_data(prompt, colors, model_name=LLAMA3_MODEL_NAME):
141
  def process_text(prompt: str, model_name: str = "blt-1b"):
142
  """
143
  Processes the input prompt using ByteLatent, Tiktoken, and Llama 3,
144
- returning visualizations and status.
145
 
146
  Args:
147
  prompt: The input text string from the Gradio interface.
@@ -151,100 +167,136 @@ def process_text(prompt: str, model_name: str = "blt-1b"):
151
  A tuple containing:
152
  - Matplotlib Figure for the entropy plot (or None).
153
  - List of tuples for bytelatent gr.HighlightedText (or None).
 
154
  - List of tuples for tiktoken gr.HighlightedText (or None).
 
155
  - List of tuples for Llama 3 gr.HighlightedText (or None).
 
156
  - Status/Error message string.
157
  """
158
  fig = None
159
  bl_highlighted_data = None
160
  tk_highlighted_data = None
161
  llama_highlighted_data = None
162
  status_message = "Starting processing..."
163
 
164
  # --- 1. Tiktoken Processing (Independent) ---
165
  status_message += "\nProcessing with Tiktoken (gpt-4)..."
166
- tk_highlighted_data = create_tiktoken_highlight_data(prompt, VIZ_COLORS)
167
  if tk_highlighted_data and tk_highlighted_data[0][1] == "Error":
168
- status_message += f"\nTiktoken Error: {tk_highlighted_data[0][0]}"
 
169
  else:
170
- status_message += "\nTiktoken processing successful."
 
171
 
172
  # --- 2. Llama 3 Processing (Independent) ---
173
  status_message += "\nProcessing with Llama 3 tokenizer..."
174
- llama_highlighted_data = create_llama3_highlight_data(prompt, VIZ_COLORS)
175
  if llama_highlighted_data and llama_highlighted_data[0][1] == "Error":
176
- status_message += f"\nLlama 3 Error: {llama_highlighted_data[0][0]}"
 
177
  else:
178
- status_message += "\nLlama 3 processing successful."
 
179
 
180
  # --- 3. Bytelatent Processing ---
181
- try:
182
- status_message += f"\nLoading entropy model for '{model_name}'..."
183
- # (Bytelatent loading code remains the same as previous version)
184
- consolidated_path = os.path.join("hf-weights", model_name)
185
- train_args_path = os.path.join(consolidated_path, "params.json")
186
- if not os.path.exists(train_args_path): raise FileNotFoundError(f"Bytelatent training args not found at {train_args_path}.")
187
- fs = get_fs(train_args_path); train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
188
- bl_tokenizer = train_args.data.tokenizer_args.build(); assert isinstance(bl_tokenizer, BltTokenizer)
189
- patcher_args = train_args.data.patcher_args.model_copy(deep=True); patcher_args.realtime_patching = True
190
- device = "cuda" if torch.cuda.is_available() else "cpu"; print(f"Using Bytelatent device: {device}")
191
- patcher_args.patching_device = device; patcher_args.device = device
192
- entropy_model_dir = os.path.join(consolidated_path, "entropy_model")
193
- if not os.path.exists(entropy_model_dir): raise FileNotFoundError(f"Bytelatent entropy model directory not found at {entropy_model_dir}.")
194
- patcher_args.entropy_model_checkpoint_dir = entropy_model_dir; bl_patcher = patcher_args.build()
195
- status_message += "\nBytelatent model loaded."
196
-
197
- # --- Processing ---
198
- status_message += "\nRunning Bytelatent patching..."
199
- print(f"Processing prompt with Bytelatent: '{prompt}'")
200
- # Limit prompt length for bytelatent if necessary
201
- prompt_bytes = prompt.encode('utf-8')
202
- if len(prompt_bytes) > 512:
203
- print(f"Warning: Prompt exceeds 512 bytes ({len(prompt_bytes)}). Truncating for Bytelatent.")
204
- prompt_bl = prompt_bytes[:512].decode('utf-8', errors='ignore')
205
- status_message += "\nWarning: Prompt truncated to 512 bytes for Bytelatent."
206
- else:
207
- prompt_bl = prompt
208
-
209
- results = patcher_nocache([prompt_bl], tokenizer=bl_tokenizer, patcher=bl_patcher)
210
-
211
- if not results:
212
- print("Bytelatent processing returned no results.")
213
- status_message += "\nBytelatent Warning: Processing completed, but no results were generated."
214
- else:
215
- batch_patch_lengths, batch_scores, batch_tokens = results
216
- patch_lengths, scores, tokens = batch_patch_lengths[0], batch_scores[0], batch_tokens[0]
217
- # --- Visualization Data Generation ---
218
- try: decoded_output_for_plot = bl_tokenizer.decode(tokens.tolist())
219
- except Exception as decode_err:
220
- print(f"Warning: Error decoding full sequence for plot: {decode_err}")
221
- decoded_output_for_plot = prompt_bl # Use truncated prompt for plot if decode fails
222
- fig = plot_entropies(patch_lengths, scores, decoded_output_for_plot, threshold=bl_patcher.threshold)
223
- bl_highlighted_data = create_bytelatent_highlight_data(bl_tokenizer, patch_lengths, tokens, VIZ_COLORS)
224
- status_message += "\nBytelatent processing and visualization successful."
225
- print("Bytelatent processing and decoding complete.")
226
-
227
- except FileNotFoundError as e:
228
- print(f"Bytelatent Error: {e}")
229
- status_message += f"\nBytelatent FileNotFoundError: {str(e)}"
230
- except Exception as e:
231
- print(f"An unexpected Bytelatent error occurred: {e}")
232
- import traceback
233
- traceback.print_exc()
234
- status_message += f"\nBytelatent Unexpected Error: {str(e)}"
235
 
236
  # Return all generated data and the final status message
237
- return fig, bl_highlighted_data, tk_highlighted_data, llama_highlighted_data, status_message
238
 
239
 
240
  # --- Gradio Interface ---
241
 
242
  # Create color maps for HighlightedText dynamically
243
- MAX_EXPECTED_SEGMENTS = 1000 # Increase max expected segments further
244
- common_error_map = {"Error": "#FF0000"} # Red for errors
245
 
246
  bytelatent_color_map = {f"BL Patch {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
247
- bytelatent_color_map["BL Remainder"] = "#808080"; bytelatent_color_map.update(common_error_map)
248
 
249
  tiktoken_color_map = {f"GPT4 Tk {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
250
  tiktoken_color_map.update(common_error_map)
@@ -253,71 +305,118 @@ llama3_color_map = {f"Llama3 Tk {i+1}": color for i, color in zip(range(MAX_EXPE
253
  llama3_color_map.update(common_error_map)
254
 
255
 
256
- with gr.Blocks(theme=gr.themes.Soft()) as iface:
257
- gr.Markdown("# BLT's Entropy Patcher Visualisation") # Updated Title
258
  gr.Markdown(
259
- "Enter text to visualize its segmentation according to different tokenizers:\n"
260
- "1. **BLT:** Entropy plot and text segmented by dynamic patches (Input limited to 512 bytes).\n"
261
  "2. **Tiktoken (GPT-4):** Text segmented by `cl100k_base` tokens.\n"
262
- "3. **Llama 3:** Text segmented by the `meta-llama/Meta-Llama-3-8B` tokenizer."
263
  )
264
 
265
  with gr.Row():
266
- with gr.Column(scale=1): # Input Column
267
  prompt_input = gr.Textbox(
268
  label="Input Prompt",
269
  value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
270
  placeholder="Enter text here...",
271
- max_length=2048, # Allow even longer input, Bytelatent will truncate
272
  lines=5,
273
- info="Processing is limited to the first 512 bytes of the input."
274
  )
275
  submit_button = gr.Button("Generate Visualizations", variant="primary")
276
- status_output = gr.Textbox(label="Processing Status", interactive=False, lines=5)
277
-
278
- with gr.Column(scale=2): # Output Column
279
- gr.Markdown("### BLT's Entropy Patcher Output (`100m`)")
280
- highlighted_output_bl = gr.HighlightedText(
281
- label="Bytelatent Patched Text",
282
- color_map=bytelatent_color_map,
283
- show_legend=False, # Legend can get very long, disable for compactness
284
- show_inline_category=False,
285
- )
286
- plot_output = gr.Plot(label="Bytelatent Entropy vs. Token Index")
287
-
288
- gr.Markdown("### Tiktoken Output (`cl100k_base` for GPT-4)")
289
- highlighted_output_tk = gr.HighlightedText(
290
- label="Tiktoken Segmented Text",
291
- color_map=tiktoken_color_map,
292
- show_legend=False,
293
- show_inline_category=False,
294
- )
295
-
296
- gr.Markdown(f"### Llama 3 Output (`{LLAMA3_MODEL_NAME}`)")
297
- highlighted_output_llama = gr.HighlightedText(
298
- label="Llama 3 Segmented Text",
299
- color_map=llama3_color_map,
300
- show_legend=False,
301
- show_inline_category=False,
302
- )
303
 
304
  # Define the action for the button click
305
  submit_button.click(
306
  fn=process_text,
307
  inputs=prompt_input,
308
- # Ensure order matches the 5 return values of process_text
309
  outputs=[
310
- plot_output,
311
- highlighted_output_bl,
312
- highlighted_output_tk,
313
- highlighted_output_llama,
314
- status_output
315
  ]
316
  )
317
 
318
  # --- Launch the Gradio App ---
319
  if __name__ == "__main__":
320
- print("Please ensure 'tiktoken', 'transformers', and 'sentencepiece' are installed (`pip install tiktoken transformers sentencepiece`)")
321
  print(f"Attempting to use Llama 3 Tokenizer: {LLAMA3_MODEL_NAME}. Ensure you have access (e.g., via `huggingface-cli login` if needed).")
322
- ensure_present(["blt-1b"]) # Ensure bytelatent model is present
323
  iface.launch()
 
3
  import torch
4
  import itertools # For color cycling
5
  import tiktoken # For GPT-4 tokenizer
6
+ from transformers import AutoTokenizer # For Llama3 tokenizer - AutoModel usually not needed just for tokenizer
7
 
8
  # Bytelatent imports (assuming they are in the python path)
9
+ try:
10
+ from bytelatent.data.file_util import get_fs
11
+ from bytelatent.generate_patcher import patcher_nocache
12
+ from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
13
+ from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
14
+ from bytelatent.args import TrainArgs
15
+ from download_blt_weights import main as ensure_present
16
+ BLT_AVAILABLE = True
17
+ except ImportError as e:
18
+ print(f"Warning: Bytelatent libraries not found. Bytelatent functionality will be disabled. Error: {e}")
19
+ BLT_AVAILABLE = False
20
+ # Define dummy classes/functions if BLT is not available to avoid NameErrors later
21
+ class BltTokenizer: pass
22
+ class TrainArgs: pass
23
+ def patcher_nocache(*args, **kwargs): return None
24
+ def plot_entropies(*args, **kwargs): return None
25
+ def ensure_present(*args, **kwargs): pass
26
+
27
 
28
  # --- Global Setup ---
29
 
 
39
 
40
  def create_bytelatent_highlight_data(tokenizer, patch_lengths_tensor, tokens_tensor, colors):
41
  """Generates data for gr.HighlightedText based on bytelatent patches."""
42
+ if not BLT_AVAILABLE:
43
+ return [("Bytelatent library not available.", "Error")]
44
  if patch_lengths_tensor is None or tokens_tensor is None or patch_lengths_tensor.numel() == 0:
45
  return None
46
  patch_lengths = patch_lengths_tensor.tolist()
47
  all_tokens = tokens_tensor.tolist()
48
  highlighted_data = []
49
  current_token_index = 0
50
+ patch_count = 0 # Initialize patch count
51
+ # color_cycler = itertools.cycle(colors) # Moved inside loop if needed per-patch
52
  for i, length in enumerate(patch_lengths):
53
  if length <= 0: continue
54
  patch_token_ids = all_tokens[current_token_index : current_token_index + length]
55
  if not patch_token_ids: continue
56
  try: patch_text = tokenizer.decode(patch_token_ids)
57
  except Exception as decode_err:
58
+ print(f"Warning: Bytelatent patch decoding failed: {decode_err}")
59
+ patch_text = f"[Decode Error: {len(patch_token_ids)} tokens]"
60
  patch_label = f"BL Patch {i+1}"
61
  highlighted_data.append((patch_text, patch_label))
62
+ patch_count += 1 # Increment count for each valid patch added
63
  current_token_index += length
64
+
65
+ # Handle remainder separately, don't count it as a 'patch'
66
  if current_token_index != len(all_tokens):
67
  print(f"Warning: Bytelatent token mismatch. Consumed {current_token_index}, total {len(all_tokens)}")
68
  remaining_tokens = all_tokens[current_token_index:]
69
  if remaining_tokens:
70
+ try: remaining_text = tokenizer.decode(remaining_tokens)
71
+ except Exception: remaining_text = f"[Decode Error: {len(remaining_tokens)} remaining tokens]"
72
+ highlighted_data.append((remaining_text, "BL Remainder"))
73
+
74
+ # Return both highlighted data and the calculated patch count
75
+ return highlighted_data, patch_count
76
 
77
 
78
  def create_tiktoken_highlight_data(prompt, colors):
79
  """Generates data for gr.HighlightedText based on tiktoken (gpt-4) tokens."""
 
80
  try:
81
  enc = tiktoken.get_encoding("cl100k_base")
82
  tiktoken_ids = enc.encode(prompt)
83
  highlighted_data = []
84
+ # color_cycler = itertools.cycle(colors) # Moved inside loop if needed per-token
85
  for i, token_id in enumerate(tiktoken_ids):
86
  try: token_text = enc.decode([token_id])
87
  except UnicodeDecodeError:
88
+ try:
89
+ token_bytes = enc.decode_single_token_bytes(token_id)
90
+ token_text = f"[Bytes: {token_bytes.hex()}]"
91
+ except Exception: token_text = "[Decode Error]"
92
  except Exception as e:
93
+ print(f"Unexpected tiktoken decode error: {e}")
94
+ token_text = "[Decode Error]"
95
  token_label = f"GPT4 Tk {i+1}"
96
  highlighted_data.append((token_text, token_label))
97
+ token_count = len(tiktoken_ids)
98
+ print(f"Tiktoken processing complete. Found {token_count} tokens.")
99
+ return highlighted_data, token_count
100
  except ImportError:
101
+ print("Error: tiktoken library not found. Please install it: pip install tiktoken")
102
+ return [("tiktoken library not installed.", "Error")], 0
103
  except Exception as tiktoken_err:
104
  print(f"Error during tiktoken processing: {tiktoken_err}")
105
+ return [(f"Error processing with tiktoken: {str(tiktoken_err)}", "Error")], 0
106
 
107
 
108
  def create_llama3_highlight_data(prompt, colors, model_name=LLAMA3_MODEL_NAME):
109
  """Generates data for gr.HighlightedText based on Llama 3 tokenizer."""
110
  try:
111
  # Load Llama 3 tokenizer from Hugging Face Hub
 
 
112
  print(f"Loading Llama 3 tokenizer: {model_name}")
113
+ # Use trust_remote_code=True if required by the specific model revision
114
+ tokenizer = AutoTokenizer.from_pretrained(model_name) #, trust_remote_code=True)
115
  print("Llama 3 tokenizer loaded.")
116
 
117
  # Encode the prompt
118
  llama_token_ids = tokenizer.encode(prompt)
119
 
120
  highlighted_data = []
121
+ # color_cycler = itertools.cycle(colors) # Moved inside loop if needed per-token
122
 
123
  for i, token_id in enumerate(llama_token_ids):
124
  try:
125
+ # Decode individual token.
126
  token_text = tokenizer.decode([token_id])
127
  except Exception as e:
128
+ print(f"Unexpected Llama 3 decode error for token {token_id}: {e}")
129
+ token_text = "[Decode Error]"
130
 
131
  token_label = f"Llama3 Tk {i+1}" # Clearer label prefix
132
  highlighted_data.append((token_text, token_label))
133
 
134
+ token_count = len(llama_token_ids)
135
+ print(f"Llama 3 processing complete. Found {token_count} tokens.")
136
+ return highlighted_data, token_count
137
 
138
  except ImportError:
139
+ print("Error: transformers or sentencepiece library not found. Please install them: pip install transformers sentencepiece")
140
+ return [("transformers/sentencepiece library not installed.", "Error")], 0
141
  except OSError as e:
142
  # Handle errors like model not found, network issues, authentication needed
143
  print(f"Error loading Llama 3 tokenizer '{model_name}': {e}")
144
+ error_msg = f"Could not load Llama 3 tokenizer '{model_name}'. Check model name and network."
145
  if "authentication" in str(e).lower():
146
+ error_msg = f"Authentication required for Llama 3 tokenizer '{model_name}'. Use `huggingface-cli login`."
147
+ return [(f"{error_msg} Error: {e}", "Error")], 0
 
148
  except Exception as llama_err:
149
  print(f"Error during Llama 3 processing: {llama_err}")
150
  import traceback
151
  traceback.print_exc() # Print full traceback for debugging
152
+ return [(f"Error processing with Llama 3: {str(llama_err)}", "Error")], 0
153
 
154
 
155
  # --- Main Processing Function ---
 
157
  def process_text(prompt: str, model_name: str = "blt-1b"):
158
  """
159
  Processes the input prompt using ByteLatent, Tiktoken, and Llama 3,
160
+ returning visualizations, counts, and status.
161
 
162
  Args:
163
  prompt: The input text string from the Gradio interface.
 
167
  A tuple containing:
168
  - Matplotlib Figure for the entropy plot (or None).
169
  - List of tuples for bytelatent gr.HighlightedText (or None).
170
+ - Integer count of bytelatent patches.
171
  - List of tuples for tiktoken gr.HighlightedText (or None).
172
+ - Integer count of tiktoken tokens.
173
  - List of tuples for Llama 3 gr.HighlightedText (or None).
174
+ - Integer count of Llama 3 tokens.
175
  - Status/Error message string.
176
  """
177
  fig = None
178
  bl_highlighted_data = None
179
  tk_highlighted_data = None
180
  llama_highlighted_data = None
181
+ bl_count = 0
182
+ tk_count = 0
183
+ llama_count = 0
184
  status_message = "Starting processing..."
185
 
186
  # --- 1. Tiktoken Processing (Independent) ---
187
  status_message += "\nProcessing with Tiktoken (gpt-4)..."
188
+ tk_highlighted_data, tk_count_calc = create_tiktoken_highlight_data(prompt, VIZ_COLORS)
189
  if tk_highlighted_data and tk_highlighted_data[0][1] == "Error":
190
+ status_message += f"\nTiktoken Error: {tk_highlighted_data[0][0]}"
191
+ tk_count = 0 # Ensure count is 0 on error
192
  else:
193
+ tk_count = tk_count_calc # Assign calculated count
194
+ status_message += f"\nTiktoken processing successful ({tk_count} tokens)."
195
 
196
  # --- 2. Llama 3 Processing (Independent) ---
197
  status_message += "\nProcessing with Llama 3 tokenizer..."
198
+ llama_highlighted_data, llama_count_calc = create_llama3_highlight_data(prompt, VIZ_COLORS)
199
  if llama_highlighted_data and llama_highlighted_data[0][1] == "Error":
200
+ status_message += f"\nLlama 3 Error: {llama_highlighted_data[0][0]}"
201
+ llama_count = 0 # Ensure count is 0 on error
202
  else:
203
+ llama_count = llama_count_calc # Assign calculated count
204
+ status_message += f"\nLlama 3 processing successful ({llama_count} tokens)."
205
 
206
  # --- 3. Bytelatent Processing ---
207
+ if BLT_AVAILABLE:
208
+ try:
209
+ status_message += f"\nLoading Bytelatent entropy model for '{model_name}'..."
210
+ # (Bytelatent loading code remains the same)
211
+ consolidated_path = os.path.join("hf-weights", model_name)
212
+ train_args_path = os.path.join(consolidated_path, "params.json")
213
+ if not os.path.exists(train_args_path): raise FileNotFoundError(f"BLT training args not found at {train_args_path}.")
214
+ fs = get_fs(train_args_path); train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
215
+ bl_tokenizer = train_args.data.tokenizer_args.build(); assert isinstance(bl_tokenizer, BltTokenizer)
216
+ patcher_args = train_args.data.patcher_args.model_copy(deep=True); patcher_args.realtime_patching = True
217
+ device = "cuda" if torch.cuda.is_available() else "cpu"; print(f"Using BLT device: {device}")
218
+ patcher_args.patching_device = device; patcher_args.device = device
219
+ entropy_model_dir = os.path.join(consolidated_path, "entropy_model")
220
+ if not os.path.exists(entropy_model_dir): raise FileNotFoundError(f"Entropy model directory not found at {entropy_model_dir}.")
221
+ patcher_args.entropy_model_checkpoint_dir = entropy_model_dir; bl_patcher = patcher_args.build()
222
+ status_message += "\nBytelatent entropy model loaded."
223
+
224
+ # --- Processing ---
225
+ status_message += "\nRunning Bytelatent entropy model patching..."
226
+ print(f"Processing prompt with entropy model: '{prompt}'")
227
+ prompt_bytes = prompt.encode('utf-8')
228
+ max_bytes = 512 # Define max bytes
229
+ if len(prompt_bytes) > max_bytes:
230
+ print(f"Warning: Prompt exceeds {max_bytes} bytes ({len(prompt_bytes)}). Truncating for entropy model.")
231
+ # Find the byte position that corresponds to the last full character within the limit
232
+ # This avoids splitting a multi-byte character
233
+ try:
234
+ last_char_pos = prompt_bytes[:max_bytes].rfind(b' ') # Cut at the last space within the limit; a simple heuristic, not perfect
235
+ if last_char_pos == -1: # If no space, truncate hard (less ideal)
236
+ prompt_bl = prompt_bytes[:max_bytes].decode('utf-8', errors='ignore')
237
+ else:
238
+ prompt_bl = prompt_bytes[:last_char_pos].decode('utf-8', errors='ignore')
239
+
240
+ except Exception: # Fallback to simple truncation on decode errors
241
+ prompt_bl = prompt_bytes[:max_bytes].decode('utf-8', errors='ignore')
242
+
243
+ status_message += f"\nWarning: Prompt truncated to approx {len(prompt_bl.encode('utf-8'))} bytes for Bytelatent entropy model."
244
+ else:
245
+ prompt_bl = prompt
246
+
247
+ results = patcher_nocache([prompt_bl], tokenizer=bl_tokenizer, patcher=bl_patcher)
248
+
249
+ if not results:
250
+ print("Bytelatent entropy processing returned no results.")
251
+ status_message += "\nBytelatent entropy model warning: Processing completed, but no results were generated."
252
+ bl_highlighted_data = [("No patches generated.", "Info")]
253
+ bl_count = 0
254
+ else:
255
+ batch_patch_lengths, batch_scores, batch_tokens = results
256
+ patch_lengths, scores, tokens = batch_patch_lengths[0], batch_scores[0], batch_tokens[0]
257
+ # --- Visualization Data Generation ---
258
+ try: decoded_output_for_plot = bl_tokenizer.decode(tokens.tolist())
259
+ except Exception as decode_err:
260
+ print(f"Warning: Error decoding full sequence for plot: {decode_err}")
261
+ decoded_output_for_plot = prompt_bl # Use truncated prompt for plot if decode fails
262
+
263
+ fig = plot_entropies(patch_lengths, scores, decoded_output_for_plot, threshold=bl_patcher.threshold)
264
+ bl_highlighted_data, bl_count_calc = create_bytelatent_highlight_data(bl_tokenizer, patch_lengths, tokens, VIZ_COLORS)
265
+ bl_count = bl_count_calc # Assign calculated count
266
+
267
+ status_message += f"\nBytelatent entropy model processing and visualization successful ({bl_count} patches)."
268
+ print("Bytelatent Entropy model processing and decoding complete.")
269
+
270
+ except FileNotFoundError as e:
271
+ print(f"Bytelatent Error: {e}")
272
+ status_message += f"\nBytelatent FileNotFoundError: {str(e)}"
273
+ bl_highlighted_data = [(f"Bytelatent Error: {e}", "Error")]
274
+ bl_count = 0
275
+ except Exception as e:
276
+ print(f"An unexpected Bytelatent error occurred: {e}")
277
+ import traceback
278
+ traceback.print_exc()
279
+ status_message += f"\nBytelatent Unexpected Error: {str(e)}"
280
+ bl_highlighted_data = [(f"Bytelatent Error: {e}", "Error")]
281
+ bl_count = 0
282
+ else:
283
+ status_message += "\nBytelatent processing skipped (library not found)."
284
+ bl_highlighted_data = [("Bytelatent library not available.", "Error")]
285
+ bl_count = 0
286
+ fig = None # Ensure fig is None if BLT is skipped
287
 
288
  # Return all generated data and the final status message
289
+ return fig, bl_highlighted_data, bl_count, tk_highlighted_data, tk_count, llama_highlighted_data, llama_count, status_message
290
 
291
 
292
  # --- Gradio Interface ---
293
 
294
  # Create color maps for HighlightedText dynamically
295
+ MAX_EXPECTED_SEGMENTS = 2000 # Increased max segments further just in case
296
+ common_error_map = {"Error": "#FF0000", "Info": "#808080"} # Red for errors, Gray for info
297
 
298
  bytelatent_color_map = {f"BL Patch {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
299
+ bytelatent_color_map["BL Remainder"] = "#AAAAAA"; bytelatent_color_map.update(common_error_map)
300
 
301
  tiktoken_color_map = {f"GPT4 Tk {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
302
  tiktoken_color_map.update(common_error_map)
 
305
  llama3_color_map.update(common_error_map)
306
 
307
 
308
+ with gr.Blocks(theme=gr.themes.Origin()) as iface:
309
+ gr.Markdown("# BLT's Entropy-based Patcher vs. Tokenizer Visualisation")
310
  gr.Markdown(
311
+ "Enter text to visualize its segmentation according to different methods:\n"
312
+ "1. **Byte Latent Transformer (BLT):** Entropy-based patching plot and patched text (_for this space ONLY_ - limited to ~512 bytes).\n"
313
  "2. **Tiktoken (GPT-4):** Text segmented by `cl100k_base` tokens.\n"
314
+ f"3. **Llama 3:** Text segmented by the `{LLAMA3_MODEL_NAME}` tokenizer."
315
  )
316
 
317
  with gr.Row():
318
+ with gr.Column(scale=1): # Input Column
319
  prompt_input = gr.Textbox(
320
  label="Input Prompt",
321
  value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
322
  placeholder="Enter text here...",
323
+ max_length=512, # Cap input at 512 characters; Bytelatent still truncates to ~512 bytes
324
  lines=5,
325
+ info="For this space ONLY, processing is limited to ~512 bytes."
326
  )
327
  submit_button = gr.Button("Generate Visualizations", variant="primary")
328
+ status_output = gr.Textbox(label="Processing Status", interactive=False, lines=7) # Increased lines slightly
329
+
330
+ with gr.Column(scale=2): # Output Column
331
+
332
+ # --- Bytelatent Output Area ---
333
+ with gr.Row(equal_height=False): # Use Row to place title and count together
334
+ gr.Markdown("### BLT Entropy Patcher Output (`blt_main_entropy_100m_512w`)")
335
+
336
+ bl_count_output = gr.Number(label="Patch Count", value=0, interactive=False, scale=1, step=1) # Added Number output
337
+ highlighted_output_bl = gr.HighlightedText(
338
+ label="BLT's Entropy-based Patches",
339
+ color_map=bytelatent_color_map,
340
+ show_legend=False, # Legend can get very long
341
+ # show_label=False, # Hide the HighlightedText label as we have the markdown title
342
+ show_inline_category=False,
343
+ # container=False, # Reduces vertical space slightly
344
+ )
345
+ plot_output = gr.Plot(label="Entropy vs. Token Index", show_label=True)
346
+
347
+ # --- Tiktoken Output Area ---
348
+ with gr.Row(equal_height=False):
349
+ gr.Markdown("### Tiktoken Output (`cl100k_base`)")
350
+
351
+ tk_count_output = gr.Number(label="Token Count", value=0, interactive=False, scale=1, step=1) # Added Number output
352
+ highlighted_output_tk = gr.HighlightedText(
353
+ label="Tiktoken Segmented Text",
354
+ color_map=tiktoken_color_map,
355
+ show_legend=False,
356
+ show_inline_category=False,
357
+ # show_label=False,
358
+ # container=False,
359
+ )
360
+
361
+ # --- Llama 3 Output Area ---
362
+ with gr.Row(equal_height=False):
363
+ gr.Markdown(f"### Llama 3 Output (`{LLAMA3_MODEL_NAME}`)")
364
+
365
+ llama_count_output = gr.Number(label="Token Count", value=0, interactive=False, scale=1, step=1) # Added Number output
366
+ highlighted_output_llama = gr.HighlightedText(
367
+ label="Llama 3 Segmented Text",
368
+ color_map=llama3_color_map,
369
+ show_legend=False,
370
+ show_inline_category=False,
371
+ # show_label=False,
372
+ # container=False,
373
+ )
374
 
375
  # Define the action for the button click
376
  submit_button.click(
377
  fn=process_text,
378
  inputs=prompt_input,
379
+ # Ensure order matches the 8 return values of process_text
380
  outputs=[
381
+ plot_output, # fig
382
+ highlighted_output_bl, # bl_highlighted_data
383
+ bl_count_output, # bl_count <-- New
384
+ highlighted_output_tk, # tk_highlighted_data
385
+ tk_count_output, # tk_count <-- New
386
+ highlighted_output_llama,# llama_highlighted_data
387
+ llama_count_output, # llama_count <-- New
388
+ status_output # status_message
389
  ]
390
  )
391
 
392
  # --- Launch the Gradio App ---
393
  if __name__ == "__main__":
394
+ print("Checking required libraries...")
395
+ try:
396
+ import tiktoken
397
+ print("- tiktoken found.")
398
+ except ImportError:
399
+ print("WARNING: 'tiktoken' not found. GPT-4 visualization will fail. Install with: pip install tiktoken")
400
+ try:
401
+ import transformers
402
+ import sentencepiece
403
+ print("- transformers found.")
404
+ print("- sentencepiece found.")
405
+ except ImportError:
406
+ print("WARNING: 'transformers' or 'sentencepiece' not found. Llama 3 visualization will fail. Install with: pip install transformers sentencepiece")
407
+
408
+ if BLT_AVAILABLE:
409
+ print("- Bytelatent libraries found.")
410
+ # Ensure bytelatent model is present only if library is available
411
+ try:
412
+ print("Ensuring Bytelatent model 'blt-1b' weights are present...")
413
+ ensure_present(["blt-1b"])
414
+ print("Bytelatent model check complete.")
415
+ except Exception as blt_dl_err:
416
+ print(f"WARNING: Failed to ensure Bytelatent model presence: {blt_dl_err}")
417
+ else:
418
+ print("INFO: Bytelatent libraries not found, skipping related functionality.")
419
+
420
  print(f"Attempting to use Llama 3 Tokenizer: {LLAMA3_MODEL_NAME}. Ensure you have access (e.g., via `huggingface-cli login` if needed).")
421
+ print("Launching Gradio interface...")
422
  iface.launch()
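Note (illustrative sketch, not part of this commit): the truncation step above cuts the prompt at the last space within 512 bytes and relies on decode(..., errors='ignore') to drop any partially split character. A minimal UTF-8-aware alternative, assuming only the standard library (the helper name and the 512-byte limit are taken from the diff's max_bytes, not from the repository), would back up to a character boundary explicitly:

def truncate_utf8(prompt: str, max_bytes: int = 512) -> str:
    """Truncate so the UTF-8 encoding fits in max_bytes without splitting a character."""
    prompt_bytes = prompt.encode("utf-8")
    if len(prompt_bytes) <= max_bytes:
        return prompt
    cut = prompt_bytes[:max_bytes]
    # Strip trailing continuation bytes (0b10xxxxxx) so the cut ends on a boundary.
    while cut and (cut[-1] & 0xC0) == 0x80:
        cut = cut[:-1]
    # Drop a dangling lead byte of a multi-byte sequence, if any remains.
    if cut and cut[-1] >= 0xC0:
        cut = cut[:-1]
    return cut.decode("utf-8")

For valid UTF-8 input this gives the same result as the errors='ignore' fallback used in the diff; the explicit boundary check just makes the intent visible.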