import json
from typing import List, Dict, Any, Optional, Tuple
from .utils import clean_label, scale_weight_to_width, scale_weight_to_opacity
class SimpleSVGVisualizer:
    """Render token attention as a static SVG embedded in HTML (no D3).

    Input tokens are drawn as a column of nodes on the left and output tokens
    as a column on the right. An edge is drawn from input ``i`` to output
    ``j`` for every step ``j <= initial_step`` whose weight exceeds the
    threshold.
    """

    # Layout/styling constants. They are class attributes so a subclass or an
    # instance can override them; none of them is read from ``config``.
    MARGIN = 100
    NODE_RADIUS = 8
    LABEL_GAP = 6
    HIGHLIGHT_COLOR = "yellow"
    INACTIVE_COLOR = "#e6e6e6"
    EDGE_COLOR = "#555555"
    NODE_STROKE = "#333333"

    def __init__(self, config):
        """Store the plotting configuration.

        Args:
            config: Object exposing ``PLOT_WIDTH``, ``PLOT_HEIGHT``,
                ``INPUT_COLOR`` and ``OUTPUT_COLOR`` attributes, which are
                read by ``create_visualization_html``.
        """
        self.config = config

    @staticmethod
    def _column_positions(count: int, top: float, span: float) -> List[float]:
        """Return ``count`` y-coordinates evenly spaced from ``top`` over ``span``."""
        # max() keeps a single node (or none) from dividing by zero; a lone
        # node therefore sits at ``top``.
        step = span / max(1, count - 1)
        return [top + k * step for k in range(count)]

    def _node_markup(
        self, x: float, y: float, color: str, label: str, label_on_left: bool
    ) -> List[str]:
        """Return the circle + text markup for one token node.

        ``label`` must already be escaped for inclusion in SVG text content.
        """
        radius = self.NODE_RADIUS
        if label_on_left:
            text_x, anchor = x - radius - self.LABEL_GAP, "end"
        else:
            text_x, anchor = x + radius + self.LABEL_GAP, "start"
        return [
            f'<circle cx="{x}" cy="{y:.2f}" r="{radius}" fill="{color}" '
            f'stroke="{self.NODE_STROKE}" />',
            f'<text x="{text_x}" y="{y:.2f}" text-anchor="{anchor}" '
            f'dominant-baseline="middle" font-size="12">{label}</text>',
        ]

    def create_visualization_html(
        self,
        input_tokens: List[str],
        output_tokens: List[str],
        attention_matrices: List[Dict],
        threshold: float = 0.05,
        initial_step: int = 0,
        selected_token: Optional[int] = None,
        selected_type: Optional[str] = None,
    ) -> str:
        """Create a simple SVG visualization without D3.

        Args:
            input_tokens: Source tokens, drawn top-to-bottom on the left.
            output_tokens: Generated tokens, drawn top-to-bottom on the right.
            attention_matrices: Entry ``j`` is a mapping whose
                ``'input_attention'`` value is indexed by input position and
                gives the weight from input ``i`` to output ``j``. Elements may
                be plain numbers or one-element tensors/arrays (anything
                ``float()`` accepts). Presumably one entry per output step;
                steps without an entry simply get no edges.
            threshold: Only edges with weight strictly greater than this are drawn.
            initial_step: Current generation step. Outputs after it are greyed
                out and receive no edges.
            selected_token: Index of a highlighted token, or ``None``.
            selected_type: ``'input'`` or ``'output'``; says which column
                ``selected_token`` refers to. When set together with
                ``selected_token``, only that token's edges are drawn.

        Returns:
            An HTML fragment containing a single inline ``<svg>`` element.
        """
        from html import escape  # local import keeps the module's imports unchanged

        # Tokens (e.g. "<s>", "<|endoftext|>") must be escaped before they go
        # into markup. NOTE(review): assumes clean_label returns unescaped
        # text; if it already HTML-escapes, drop escape() to avoid
        # double-encoding.
        input_labels = [escape(str(clean_label(token))) for token in input_tokens]
        output_labels = [escape(str(clean_label(token))) for token in output_tokens]

        width = self.config.PLOT_WIDTH
        height = self.config.PLOT_HEIGHT
        margin = self.MARGIN
        radius = self.NODE_RADIUS
        input_x = margin
        output_x = width - margin
        usable_height = height - 2 * margin
        input_ys = self._column_positions(len(input_labels), margin, usable_height)
        output_ys = self._column_positions(len(output_labels), margin, usable_height)

        parts: List[str] = [
            f'<rect x="0" y="0" width="{width}" height="{height}" fill="white" />',
            f'<text x="{width / 2}" y="{margin / 2}" text-anchor="middle" '
            f'font-size="16" font-weight="bold">Token Attention Flow</text>',
        ]

        # Edges first so that nodes are painted on top of them.
        filter_inputs = selected_token is not None and selected_type == 'input'
        filter_outputs = selected_token is not None and selected_type == 'output'
        # Negative initial_step yields a non-positive bound, i.e. no edges.
        n_steps = min(initial_step + 1, len(output_labels), len(attention_matrices))
        for j in range(n_steps):
            if filter_outputs and j != selected_token:
                continue
            row = attention_matrices[j]['input_attention']
            y_out = output_ys[j]
            for i, y_in in enumerate(input_ys):
                if filter_inputs and i != selected_token:
                    continue
                weight = float(row[i])
                if weight <= threshold:
                    continue
                opacity = scale_weight_to_opacity(weight, threshold)
                stroke_width = scale_weight_to_width(weight)
                parts.append(
                    f'<line x1="{input_x + radius}" y1="{y_in:.2f}" '
                    f'x2="{output_x - radius}" y2="{y_out:.2f}" '
                    f'stroke="{self.EDGE_COLOR}" stroke-width="{stroke_width}" '
                    f'stroke-opacity="{opacity}" />'
                )

        for i, (label, y) in enumerate(zip(input_labels, input_ys)):
            if selected_type == 'input' and selected_token == i:
                color = self.HIGHLIGHT_COLOR
            else:
                color = self.config.INPUT_COLOR
            parts.extend(self._node_markup(input_x, y, color, label, label_on_left=True))

        for j, (label, y) in enumerate(zip(output_labels, output_ys)):
            if selected_type == 'output' and selected_token == j:
                color = self.HIGHLIGHT_COLOR
            elif j <= initial_step:
                color = self.config.OUTPUT_COLOR
            else:
                color = self.INACTIVE_COLOR
            parts.extend(self._node_markup(output_x, y, color, label, label_on_left=False))

        # Guard both ends: a negative index would otherwise wrap to the last token.
        if 0 <= initial_step < len(output_labels):
            current = output_labels[initial_step]
        else:
            current = ""
        parts.append(
            f'<text x="{width / 2}" y="{height - margin / 2}" text-anchor="middle" '
            f'font-size="14">'
            f'Step {initial_step} / {len(output_labels) - 1}: Generating "{current}"'
            '</text>'
        )

        svg_body = "\n".join(parts)
        return (
            "<div>\n"
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" '
            f'height="{height}" viewBox="0 0 {width} {height}">\n'
            f"{svg_body}\n"
            "</svg>\n"
            "</div>\n"
        )