lucalp committed
Commit b074257 · Parent(s): a528449

tiktoken & llama both plotted

Files changed (1):
  1. app.py (+250 -157)
app.py CHANGED
@@ -1,8 +1,11 @@
  import os
  import gradio as gr
  import torch
- import itertools  # Import itertools for color cycling

  from bytelatent.data.file_util import get_fs
  from bytelatent.generate_patcher import patcher_nocache
  from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
@@ -10,221 +13,311 @@ from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
  from bytelatent.args import TrainArgs
  from download_blt_weights import main as ensure_present

- # --- Global Setup (Consider loading models outside if necessary) ---
- # Kept inside the function for simplicity as before.

- # Define colors for patches (similar to the image style)
- # Using colors from a qualitative colormap (e.g., Colorbrewer Set3 or Paired)
- PATCH_COLORS = [
      "#a6cee3", "#1f78b4", "#b2df8a", "#33a02c", "#fb9a99", "#e31a1c",
      "#fdbf6f", "#ff7f00", "#cab2d6", "#6a3d9a", "#ffff99", "#b15928"
- ]  # Add more if you expect many patches


- def create_highlighted_text_data(tokenizer, patch_lengths_tensor, tokens_tensor, colors):
-     """
-     Generates the data structure needed for gr.HighlightedText based on patches.
-
-     Args:
-         tokenizer: The BltTokenizer instance.
-         patch_lengths_tensor: Tensor containing the length of each patch (in tokens).
-         tokens_tensor: Tensor containing the token IDs for the entire sequence.
-         colors: A list of color hex codes to cycle through.
-
-     Returns:
-         A list of tuples for gr.HighlightedText, e.g., [(text, label), ...].
-         Returns None if input tensors are invalid.
-     """
      if patch_lengths_tensor is None or tokens_tensor is None or patch_lengths_tensor.numel() == 0:
          return None
-
      patch_lengths = patch_lengths_tensor.tolist()
      all_tokens = tokens_tensor.tolist()
      highlighted_data = []
      current_token_index = 0
-     color_cycler = itertools.cycle(colors)  # Use itertools to cycle through colors
-
      for i, length in enumerate(patch_lengths):
-         if length <= 0:  # Skip empty patches if they somehow occur
-             continue
          patch_token_ids = all_tokens[current_token_index : current_token_index + length]
-         if not patch_token_ids:  # Should not happen if length > 0, but good practice
-             continue
-
-         patch_text = tokenizer.decode(patch_token_ids)
-         patch_label = f"Patch {i+1}"  # Unique label for each patch
-         patch_color = next(color_cycler)  # Get the next color
-
-         # Add to highlighted_data: (text, label_for_coloring)
          highlighted_data.append((patch_text, patch_label))
          current_token_index += length
-
-     # Check if all tokens were consumed (optional sanity check)
      if current_token_index != len(all_tokens):
-         print(f"Warning: Token mismatch. Consumed {current_token_index}, total {len(all_tokens)}")
-         # Decode any remaining tokens if necessary, though this indicates a logic issue
          remaining_tokens = all_tokens[current_token_index:]
          if remaining_tokens:
-             remaining_text = tokenizer.decode(remaining_tokens)
-             highlighted_data.append((remaining_text, "Remainder"))  # Assign a generic label
-
      return highlighted_data


  def process_text(prompt: str, model_name: str = "blt-1b"):
      """
-     Processes the input prompt using the ByteLatent model and returns
-     an entropy plot and color-coded text data.

      Args:
          prompt: The input text string from the Gradio interface.
-         model_name: The name of the model to use.

      Returns:
          A tuple containing:
-         - Matplotlib Figure for the entropy plot (or None on error).
-         - List of tuples for gr.HighlightedText (or None on error/no results).
-         - Error message string (or None if successful).
      """
      try:
-         # --- Model and Tokenizer Loading ---
          consolidated_path = os.path.join("hf-weights", model_name)
          train_args_path = os.path.join(consolidated_path, "params.json")
-
-         if not os.path.exists(train_args_path):
-             raise FileNotFoundError(f"Training args not found at {train_args_path}. "
-                                     f"Ensure model '{model_name}' is downloaded/available.")
-
-         fs = get_fs(train_args_path)
-         train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
-
-         tokenizer = train_args.data.tokenizer_args.build()
-         assert isinstance(tokenizer, BltTokenizer)
-
-         patcher_args = train_args.data.patcher_args.model_copy(deep=True)
-         patcher_args.realtime_patching = True
-         device = "cuda" if torch.cuda.is_available() else "cpu"
-         print(f"Using device: {device}")
-         patcher_args.patching_device = device
-         patcher_args.device = device
-
-         print("Loading entropy model and patcher...")
          entropy_model_dir = os.path.join(consolidated_path, "entropy_model")
-         if not os.path.exists(entropy_model_dir):
-             raise FileNotFoundError(f"Entropy model directory not found at {entropy_model_dir}.")
-
-         patcher_args.entropy_model_checkpoint_dir = entropy_model_dir
-         patcher = patcher_args.build()
-         # --- End Loading ---

          # --- Processing ---
-         prompts = [prompt]
-         print(f"Processing prompt: '{prompt}'")
-         results = patcher_nocache(
-             prompts, tokenizer=tokenizer, patcher=patcher
-         )

          if not results:
-             print("Processing returned no results.")
-             return None, None, "Processing completed, but no results were generated."

-         batch_patch_lengths, batch_scores, batch_tokens = results

-         # Process the first (and only) result in the batch
-         patch_lengths = batch_patch_lengths[0]
-         scores = batch_scores[0]
-         tokens = batch_tokens[0]

-         # Decode the full output once for the plot labels (if needed by plot_entropies)
-         # Note: BltTokenizer might decode directly to bytes, then utf-8. Ensure it handles errors.
-         try:
-             # Using the raw tokens tensor for decoding consistency
-             decoded_output_for_plot = tokenizer.decode(tokens.tolist())
-         except Exception as decode_err:
-             print(f"Warning: Error decoding full sequence for plot: {decode_err}")
-             # Fallback: attempt to decode the original prompt if possible, or use generic labels
-             decoded_output_for_plot = prompt  # Use original prompt as fallback

-         # Generate the plot
-         fig = plot_entropies(
-             patch_lengths,
-             scores,
-             decoded_output_for_plot,  # Pass the decoded string for plot labels
-             threshold=patcher.threshold
-         )

-         # Generate data for HighlightedText
-         highlighted_data = create_highlighted_text_data(
-             tokenizer, patch_lengths, tokens, PATCH_COLORS
-         )

-         print("Processing and visualization data generation complete.")
-         # --- End Processing ---

-         return fig, highlighted_data, None  # Return plot, highlighted text data, no error

-     except FileNotFoundError as e:
-         print(f"Error: {e}")
-         return None, None, f"Error: {str(e)}"  # Return None for plot/text, error message
-     except Exception as e:
-         print(f"An unexpected error occurred: {e}")
-         import traceback
-         traceback.print_exc()
-         return None, None, f"An unexpected error occurred: {e}"  # Return None for plot/text, error message

- # --- Gradio Interface ---

- # Create the color map for HighlightedText dynamically
- # Generate enough patch labels and map them to the cycled colors
- MAX_EXPECTED_PATCHES = 50  # Estimate a reasonable maximum
- color_map = {
-     f"Patch {i+1}": color
-     for i, color in zip(range(MAX_EXPECTED_PATCHES), itertools.cycle(PATCH_COLORS))
- }
- # Add a color for the potential 'Remainder' label from create_highlighted_text_data
- color_map["Remainder"] = "#808080"  # Grey for any leftovers
-
- with gr.Blocks() as iface:
-     gr.Markdown("# ByteLatent Entropy Visualizer")  # Title
      gr.Markdown(
-         "Process any prompt (limited to 512 bytes) with the 100M entropy patcher model "
-         "and visualize the token entropies plot and color-coded patches below.<br><br>"  # Updated description
-         "NOTE: this implementation differs slightly by excluding local attention so we limit "
-         "the characters limit to 512 to avoid any deviation.",
-         line_breaks=True
      )

-     with gr.Column():
-         prompt_input = gr.Textbox(
-             label="Input Prompt",
-             value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
-             placeholder="Enter text here...",
-             max_length=512,
-             lines=3
-         )
-         submit_button = gr.Button("Generate Visualization")  # Update button text
-
-         # Output for error messages or status
-         status_output = gr.Textbox(label="Status", interactive=False)
-
-         # Output component for the color-coded text
-         highlighted_output = gr.HighlightedText(
-             label="Patched Text Visualization",
-             color_map=color_map,
-             show_legend=False  # Show the patch labels and colors
-         )
-
-         # Output component for the plot
-         plot_output = gr.Plot(label="Entropy vs. Token Index (with Patch Threshold)")

      # Define the action for the button click
      submit_button.click(
          fn=process_text,
          inputs=prompt_input,
-         outputs=[plot_output, highlighted_output, status_output]  # Order matters!
      )

  # --- Launch the Gradio App ---
  if __name__ == "__main__":
-     ensure_present(["blt-1b"])  # Ensure model is present before launching
      iface.launch()

  import os
  import gradio as gr
  import torch
+ import itertools  # For color cycling
+ import tiktoken  # For GPT-4 tokenizer
+ from transformers import AutoTokenizer, AutoModel  # For Llama3 tokenizer

+ # Bytelatent imports (assuming they are in the python path)
  from bytelatent.data.file_util import get_fs
  from bytelatent.generate_patcher import patcher_nocache
  from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
  from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
  from bytelatent.args import TrainArgs
  from download_blt_weights import main as ensure_present

+ # --- Global Setup ---

+ # Define colors for patches/tokens
+ VIZ_COLORS = [
      "#a6cee3", "#1f78b4", "#b2df8a", "#33a02c", "#fb9a99", "#e31a1c",
      "#fdbf6f", "#ff7f00", "#cab2d6", "#6a3d9a", "#ffff99", "#b15928"
+ ]  # Add more if you expect many segments

+ LLAMA3_MODEL_NAME = "meta-llama/Meta-Llama-3-8B"  # Or choose another variant like Instruct

+ # --- Helper Functions ---

+ def create_bytelatent_highlight_data(tokenizer, patch_lengths_tensor, tokens_tensor, colors):
+     """Generates data for gr.HighlightedText based on bytelatent patches."""
+     # (Keep the function from the previous version - no changes needed)
      if patch_lengths_tensor is None or tokens_tensor is None or patch_lengths_tensor.numel() == 0:
          return None
      patch_lengths = patch_lengths_tensor.tolist()
      all_tokens = tokens_tensor.tolist()
      highlighted_data = []
      current_token_index = 0
+     color_cycler = itertools.cycle(colors)
      for i, length in enumerate(patch_lengths):
+         if length <= 0: continue
          patch_token_ids = all_tokens[current_token_index : current_token_index + length]
+         if not patch_token_ids: continue
+         try: patch_text = tokenizer.decode(patch_token_ids)
+         except Exception as decode_err:
+             print(f"Warning: Bytelatent patch decoding failed: {decode_err}")
+             patch_text = f"[Decode Error: {len(patch_token_ids)} tokens]"
+         patch_label = f"BL Patch {i+1}"
          highlighted_data.append((patch_text, patch_label))
          current_token_index += length
      if current_token_index != len(all_tokens):
+         print(f"Warning: Bytelatent token mismatch. Consumed {current_token_index}, total {len(all_tokens)}")
          remaining_tokens = all_tokens[current_token_index:]
          if remaining_tokens:
+             try: remaining_text = tokenizer.decode(remaining_tokens)
+             except Exception: remaining_text = f"[Decode Error: {len(remaining_tokens)} remaining tokens]"
+             highlighted_data.append((remaining_text, "BL Remainder"))
      return highlighted_data


+ def create_tiktoken_highlight_data(prompt, colors):
+     """Generates data for gr.HighlightedText based on tiktoken (gpt-4) tokens."""
+     # (Keep the function from the previous version - no changes needed)
+     try:
+         enc = tiktoken.get_encoding("cl100k_base")
+         tiktoken_ids = enc.encode(prompt)
+         highlighted_data = []
+         color_cycler = itertools.cycle(colors)
+         for i, token_id in enumerate(tiktoken_ids):
+             try: token_text = enc.decode([token_id])
+             except UnicodeDecodeError:
+                 try:
+                     token_bytes = enc.decode_single_token_bytes(token_id)
+                     token_text = f"[Bytes: {token_bytes.hex()}]"
+                 except Exception: token_text = "[Decode Error]"
+             except Exception as e:
+                 print(f"Unexpected tiktoken decode error: {e}")
+                 token_text = "[Decode Error]"
+             token_label = f"GPT4 Tk {i+1}"
+             highlighted_data.append((token_text, token_label))
+         print(f"Tiktoken processing complete. Found {len(tiktoken_ids)} tokens.")
+         return highlighted_data
+     except ImportError:
+         print("Error: tiktoken library not found. Please install it: pip install tiktoken")
+         return [("tiktoken library not installed.", "Error")]
+     except Exception as tiktoken_err:
+         print(f"Error during tiktoken processing: {tiktoken_err}")
+         return [(f"Error processing with tiktoken: {str(tiktoken_err)}", "Error")]


+ def create_llama3_highlight_data(prompt, colors, model_name=LLAMA3_MODEL_NAME):
+     """Generates data for gr.HighlightedText based on Llama 3 tokenizer."""
+     try:
+         # Load Llama 3 tokenizer from Hugging Face Hub
+         # This might download the tokenizer files on the first run
+         # May require `huggingface-cli login` if model is private or gated
+         print(f"Loading Llama 3 tokenizer: {model_name}")
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         print("Llama 3 tokenizer loaded.")
+
+         # Encode the prompt
+         llama_token_ids = tokenizer.encode(prompt)
+
+         highlighted_data = []
+         color_cycler = itertools.cycle(colors)
+
+         for i, token_id in enumerate(llama_token_ids):
+             try:
+                 # Decode individual token. Llama/SentencePiece tokenizers usually handle this well.
+                 token_text = tokenizer.decode([token_id])
+                 # Special case: Handle potential leading space added by sentencepiece during decode
+                 # if token_text.startswith(' '):  # Check if this improves visualization
+                 #     token_text = token_text[1:]  # Remove leading space visual artifact? Test this.
+             except Exception as e:
+                 print(f"Unexpected Llama 3 decode error for token {token_id}: {e}")
+                 token_text = "[Decode Error]"
+
+             token_label = f"Llama3 Tk {i+1}"  # Clearer label prefix
+             highlighted_data.append((token_text, token_label))
+
+         print(f"Llama 3 processing complete. Found {len(llama_token_ids)} tokens.")
+         return highlighted_data
+
+     except ImportError:
+         print("Error: transformers or sentencepiece library not found. Please install them: pip install transformers sentencepiece")
+         return [("transformers/sentencepiece library not installed.", "Error")]
+     except OSError as e:
+         # Handle errors like model not found, network issues, authentication needed
+         print(f"Error loading Llama 3 tokenizer '{model_name}': {e}")
+         if "authentication" in str(e).lower():
+             return [(f"Authentication required for Llama 3 tokenizer '{model_name}'. Use `huggingface-cli login`.", "Error")]
+         else:
+             return [(f"Could not load Llama 3 tokenizer '{model_name}'. Check model name and network. Error: {e}", "Error")]
+     except Exception as llama_err:
+         print(f"Error during Llama 3 processing: {llama_err}")
+         import traceback
+         traceback.print_exc()  # Print full traceback for debugging
+         return [(f"Error processing with Llama 3: {str(llama_err)}", "Error")]


+ # --- Main Processing Function ---

  def process_text(prompt: str, model_name: str = "blt-1b"):
      """
+     Processes the input prompt using ByteLatent, Tiktoken, and Llama 3,
+     returning visualizations and status.

      Args:
          prompt: The input text string from the Gradio interface.
+         model_name: The name of the bytelatent model to use.

      Returns:
          A tuple containing:
+         - Matplotlib Figure for the entropy plot (or None).
+         - List of tuples for bytelatent gr.HighlightedText (or None).
+         - List of tuples for tiktoken gr.HighlightedText (or None).
+         - List of tuples for Llama 3 gr.HighlightedText (or None).
+         - Status/Error message string.
      """
+     fig = None
+     bl_highlighted_data = None
+     tk_highlighted_data = None
+     llama_highlighted_data = None
+     status_message = "Starting processing..."
+
+     # --- 1. Tiktoken Processing (Independent) ---
+     status_message += "\nProcessing with Tiktoken (gpt-4)..."
+     tk_highlighted_data = create_tiktoken_highlight_data(prompt, VIZ_COLORS)
+     if tk_highlighted_data and tk_highlighted_data[0][1] == "Error":
+         status_message += f"\nTiktoken Error: {tk_highlighted_data[0][0]}"
+     else:
+         status_message += "\nTiktoken processing successful."
+
+     # --- 2. Llama 3 Processing (Independent) ---
+     status_message += "\nProcessing with Llama 3 tokenizer..."
+     llama_highlighted_data = create_llama3_highlight_data(prompt, VIZ_COLORS)
+     if llama_highlighted_data and llama_highlighted_data[0][1] == "Error":
+         status_message += f"\nLlama 3 Error: {llama_highlighted_data[0][0]}"
+     else:
+         status_message += "\nLlama 3 processing successful."
+
+     # --- 3. Bytelatent Processing ---
      try:
+         status_message += f"\nLoading entropy model for '{model_name}'..."
+         # (Bytelatent loading code remains the same as previous version)
          consolidated_path = os.path.join("hf-weights", model_name)
          train_args_path = os.path.join(consolidated_path, "params.json")
+         if not os.path.exists(train_args_path): raise FileNotFoundError(f"Bytelatent training args not found at {train_args_path}.")
+         fs = get_fs(train_args_path); train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
+         bl_tokenizer = train_args.data.tokenizer_args.build(); assert isinstance(bl_tokenizer, BltTokenizer)
+         patcher_args = train_args.data.patcher_args.model_copy(deep=True); patcher_args.realtime_patching = True
+         device = "cuda" if torch.cuda.is_available() else "cpu"; print(f"Using Bytelatent device: {device}")
+         patcher_args.patching_device = device; patcher_args.device = device
          entropy_model_dir = os.path.join(consolidated_path, "entropy_model")
+         if not os.path.exists(entropy_model_dir): raise FileNotFoundError(f"Bytelatent entropy model directory not found at {entropy_model_dir}.")
+         patcher_args.entropy_model_checkpoint_dir = entropy_model_dir; bl_patcher = patcher_args.build()
+         status_message += "\nBytelatent model loaded."

          # --- Processing ---
+         status_message += "\nRunning Bytelatent patching..."
+         print(f"Processing prompt with Bytelatent: '{prompt}'")
+         # Limit prompt length for bytelatent if necessary
+         prompt_bytes = prompt.encode('utf-8')
+         if len(prompt_bytes) > 512:
+             print(f"Warning: Prompt exceeds 512 bytes ({len(prompt_bytes)}). Truncating for Bytelatent.")
+             prompt_bl = prompt_bytes[:512].decode('utf-8', errors='ignore')
+             status_message += "\nWarning: Prompt truncated to 512 bytes for Bytelatent."
+         else:
+             prompt_bl = prompt
+
+         results = patcher_nocache([prompt_bl], tokenizer=bl_tokenizer, patcher=bl_patcher)

          if not results:
+             print("Bytelatent processing returned no results.")
+             status_message += "\nBytelatent Warning: Processing completed, but no results were generated."
+         else:
+             batch_patch_lengths, batch_scores, batch_tokens = results
+             patch_lengths, scores, tokens = batch_patch_lengths[0], batch_scores[0], batch_tokens[0]
+             # --- Visualization Data Generation ---
+             try: decoded_output_for_plot = bl_tokenizer.decode(tokens.tolist())
+             except Exception as decode_err:
+                 print(f"Warning: Error decoding full sequence for plot: {decode_err}")
+                 decoded_output_for_plot = prompt_bl  # Use truncated prompt for plot if decode fails
+             fig = plot_entropies(patch_lengths, scores, decoded_output_for_plot, threshold=bl_patcher.threshold)
+             bl_highlighted_data = create_bytelatent_highlight_data(bl_tokenizer, patch_lengths, tokens, VIZ_COLORS)
+             status_message += "\nBytelatent processing and visualization successful."
+             print("Bytelatent processing and decoding complete.")

+     except FileNotFoundError as e:
+         print(f"Bytelatent Error: {e}")
+         status_message += f"\nBytelatent FileNotFoundError: {str(e)}"
+     except Exception as e:
+         print(f"An unexpected Bytelatent error occurred: {e}")
+         import traceback
+         traceback.print_exc()
+         status_message += f"\nBytelatent Unexpected Error: {str(e)}"

+     # Return all generated data and the final status message
+     return fig, bl_highlighted_data, tk_highlighted_data, llama_highlighted_data, status_message


+ # --- Gradio Interface ---

+ # Create color maps for HighlightedText dynamically
+ MAX_EXPECTED_SEGMENTS = 1000  # Increase max expected segments further
+ common_error_map = {"Error": "#FF0000"}  # Red for errors

+ bytelatent_color_map = {f"BL Patch {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
+ bytelatent_color_map["BL Remainder"] = "#808080"; bytelatent_color_map.update(common_error_map)

+ tiktoken_color_map = {f"GPT4 Tk {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
+ tiktoken_color_map.update(common_error_map)

+ llama3_color_map = {f"Llama3 Tk {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
+ llama3_color_map.update(common_error_map)

+ with gr.Blocks(theme=gr.themes.Soft()) as iface:
+     gr.Markdown("# BLT's Entropy Patcher Visualisation")  # Updated Title
      gr.Markdown(
+         "Enter text to visualize its segmentation according to different tokenizers:\n"
+         "1. **BLT:** Entropy plot and text segmented by dynamic patches (Input limited to 512 bytes).\n"
+         "2. **Tiktoken (GPT-4):** Text segmented by `cl100k_base` tokens.\n"
+         "3. **Llama 3:** Text segmented by the `meta-llama/Meta-Llama-3-8B` tokenizer."
      )

+     with gr.Row():
+         with gr.Column(scale=1):  # Input Column
+             prompt_input = gr.Textbox(
+                 label="Input Prompt",
+                 value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
+                 placeholder="Enter text here...",
+                 max_length=2048,  # Allow even longer input, Bytelatent will truncate
+                 lines=5,
+                 info="Processing is limited to the first 512 bytes of the input."
+             )
+             submit_button = gr.Button("Generate Visualizations", variant="primary")
+             status_output = gr.Textbox(label="Processing Status", interactive=False, lines=5)
+
+         with gr.Column(scale=2):  # Output Column
+             gr.Markdown("### BLT's Entropy Patcher Output (`100m`)")
+             highlighted_output_bl = gr.HighlightedText(
+                 label="Bytelatent Patched Text",
+                 color_map=bytelatent_color_map,
+                 show_legend=False,  # Legend can get very long, disable for compactness
+                 show_inline_category=False,
+             )
+             plot_output = gr.Plot(label="Bytelatent Entropy vs. Token Index")
+
+             gr.Markdown("### Tiktoken Output (`cl100k_base` for GPT-4)")
+             highlighted_output_tk = gr.HighlightedText(
+                 label="Tiktoken Segmented Text",
+                 color_map=tiktoken_color_map,
+                 show_legend=False,
+                 show_inline_category=False,
+             )
+
+             gr.Markdown(f"### Llama 3 Output (`{LLAMA3_MODEL_NAME}`)")
+             highlighted_output_llama = gr.HighlightedText(
+                 label="Llama 3 Segmented Text",
+                 color_map=llama3_color_map,
+                 show_legend=False,
+                 show_inline_category=False,
+             )

      # Define the action for the button click
      submit_button.click(
          fn=process_text,
          inputs=prompt_input,
+         # Ensure order matches the 5 return values of process_text
+         outputs=[
+             plot_output,
+             highlighted_output_bl,
+             highlighted_output_tk,
+             highlighted_output_llama,
+             status_output
+         ]
      )

  # --- Launch the Gradio App ---
  if __name__ == "__main__":
+     print("Please ensure 'tiktoken', 'transformers', and 'sentencepiece' are installed (`pip install tiktoken transformers sentencepiece`)")
+     print(f"Attempting to use Llama 3 Tokenizer: {LLAMA3_MODEL_NAME}. Ensure you have access (e.g., via `huggingface-cli login` if needed).")
+     ensure_present(["blt-1b"])  # Ensure bytelatent model is present
      iface.launch()
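
A quick way to sanity-check the two new tokenizer paths outside Gradio, as a minimal sketch (assumes `tiktoken` and `transformers` are installed; the `meta-llama/Meta-Llama-3-8B` tokenizer is gated and may require `huggingface-cli login` first):

import tiktoken
from transformers import AutoTokenizer

text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."

# GPT-4 path: same cl100k_base encoding used by create_tiktoken_highlight_data
enc = tiktoken.get_encoding("cl100k_base")
gpt4_ids = enc.encode(text)
print(len(gpt4_ids), [enc.decode([i]) for i in gpt4_ids[:5]])

# Llama 3 path: same Hub tokenizer loaded by create_llama3_highlight_data
tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
llama_ids = tok.encode(text)
print(len(llama_ids), [tok.decode([i]) for i in llama_ids[:5]])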