import html
import json
from typing import List, Dict, Any, Optional, Tuple

from .utils import clean_label, scale_weight_to_width, scale_weight_to_opacity


class SimpleSVGVisualizer:
    """Build a static (no D3) input-to-output token attention diagram.

    The visualizer reads ``PLOT_WIDTH``, ``PLOT_HEIGHT``, ``INPUT_COLOR`` and
    ``OUTPUT_COLOR`` from the configuration object it is constructed with.

    NOTE(review): in the reviewed copy of this file every markup literal in
    ``create_visualization_html`` is empty (``f''``), so the computed
    positions, colours, opacity and stroke width are never interpolated into
    the output. This looks like the SVG tags were stripped from the source --
    confirm against the original file and restore the literals.
    """

    def __init__(self, config):
        """Keep a reference to the plot configuration object."""
        self.config = config

    @staticmethod
    def _column_positions(count: int, margin: float, height: float) -> List[float]:
        """Return ``count`` y-coordinates spread evenly between the margins."""
        # max(1, ...) avoids a zero division for a single node; a count of
        # zero simply produces an empty list.
        spacing = (height - 2 * margin) / max(1, count - 1)
        return [margin + k * spacing for k in range(count)]

    def create_visualization_html(
        self,
        input_tokens: List[str],
        output_tokens: List[str],
        attention_matrices: List[Dict],
        threshold: float = 0.05,
        initial_step: int = 0,
        selected_token: Optional[int] = None,
        selected_type: Optional[str] = None
    ) -> str:
        """Create a simple SVG visualization without D3.

        Args:
            input_tokens: Tokens for the input column (placed at
                ``x = margin``). Each is passed through ``clean_label``.
            output_tokens: Tokens for the output column (placed at
                ``x = PLOT_WIDTH - margin``). Each is passed through
                ``clean_label``.
            attention_matrices: One mapping per output step.
                ``attention_matrices[j]['input_attention'][i].item()`` is
                read as the weight between output step ``j`` and input
                token ``i`` (presumably a tensor row -- confirm with the
                caller).
            threshold: A connection is emitted only when its weight is
                strictly greater than this value.
            initial_step: Index of the last output step treated as already
                generated. Connections are emitted for steps
                ``0..initial_step``; output nodes past it use the grey
                ``"#e6e6e6"`` colour.
            selected_token: Index of the highlighted token, or ``None`` for
                no selection. When set, only connections touching that
                token are emitted and its node colour is ``"yellow"``.
            selected_type: ``'input'`` or ``'output'``; tells which column
                ``selected_token`` indexes into.

        Returns:
            The joined markup elements wrapped in leading and trailing
            newlines.
        """
        # Labels end up inside markup, so escape them: tokenizer specials such
        # as "<s>" or a bare "&" would otherwise be parsed as tags/entities.
        # NOTE(review): assumes clean_label does not already HTML-escape --
        # confirm, otherwise this double-escapes.
        input_labels = [html.escape(str(clean_label(token))) for token in input_tokens]
        output_labels = [html.escape(str(clean_label(token))) for token in output_tokens]

        # Calculate positions
        width = self.config.PLOT_WIDTH
        height = self.config.PLOT_HEIGHT
        margin = 100
        input_x = margin
        output_x = width - margin

        # Create SVG elements
        svg_elements = []

        # Background
        svg_elements.append(f'')

        # Title
        svg_elements.append(f'Token Attention Flow')

        # Calculate vertical positions
        input_y_positions = self._column_positions(len(input_labels), margin, height)
        output_y_positions = self._column_positions(len(output_labels), margin, height)

        # Draw connections. Only steps that are generated, have a label and
        # have an attention entry can contribute.
        drawable_steps = min(
            initial_step + 1, len(output_labels), len(attention_matrices)
        )
        has_selection = selected_token is not None
        only_input = has_selection and selected_type == 'input'
        only_output = has_selection and selected_type == 'output'

        for j in range(drawable_steps):
            # The output-side filter does not depend on i: skip the whole step.
            if only_output and j != selected_token:
                continue
            for i in range(len(input_labels)):
                # Filter before touching the weight so that skipped pairs
                # never pay for the .item() read.
                if only_input and i != selected_token:
                    continue
                weight = attention_matrices[j]['input_attention'][i].item()
                if weight > threshold:
                    opacity = scale_weight_to_opacity(weight, threshold)
                    width_val = scale_weight_to_width(weight)
                    svg_elements.append(
                        f''
                    )

        # Draw input nodes
        for i, (label, y) in enumerate(zip(input_labels, input_y_positions)):
            if selected_token == i and selected_type == 'input':
                color = "yellow"
            else:
                color = self.config.INPUT_COLOR
            svg_elements.append(
                f''
            )
            svg_elements.append(
                f'{label}'
            )

        # Draw output nodes
        for j, (label, y) in enumerate(zip(output_labels, output_y_positions)):
            if selected_token == j and selected_type == 'output':
                color = "yellow"
            elif j <= initial_step:
                color = self.config.OUTPUT_COLOR
            else:
                # Not generated yet at this step.
                color = "#e6e6e6"
            svg_elements.append(
                f''
            )
            svg_elements.append(
                f'{label}'
            )

        # Step info. Clamp the denominator so an empty output list does not
        # print "/ -1", and bounds-check both ends so a negative step does
        # not wrap around to the last token.
        last_step = max(len(output_labels) - 1, 0)
        if 0 <= initial_step < len(output_labels):
            current_label = output_labels[initial_step]
        else:
            current_label = ""
        svg_elements.append(
            f''
            f'Step {initial_step} / {last_step}: Generating "{current_label}"'
            f''
        )

        # Create HTML
        body = ''.join(svg_elements)
        markup = f"\n{body}\n"
        return markup