import torch
from typing import Any, Dict, List, Optional
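
# Expected structure of `attention_data` (inferred from how this module
# indexes it; the tensor layout matches what Hugging Face generate() returns
# with output_attentions=True and return_dict_in_generate=True):
#   'attentions': one entry per generated token, each a tuple of per-layer
#       tensors shaped (batch, num_heads, query_len, key_len)
#   'input_len_for_attention': number of input tokens, excluding the BOS
#       token at key position 0
#   'output_len': number of generated tokens
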
class AttentionProcessor:
@staticmethod
    def process_attention_separate(
            attention_data: Dict[str, Any],
            input_tokens: List[str],
            output_tokens: List[str]
    ) -> List[Dict[str, Optional[torch.Tensor]]]:
        """
        Process attention with separate normalization for the input and
        output token groups, preserving the relative importance within
        each group. input_tokens and output_tokens are currently unused.
        """
attentions = attention_data['attentions']
input_len_for_attention = attention_data['input_len_for_attention']
output_len = attention_data['output_len']
        if not attentions:
            print("Warning: No attention steps found in output.")
            return [{'input_attention': torch.zeros(input_len_for_attention),
                     'output_attention': None,
                     'raw_input_attention': torch.zeros(input_len_for_attention),
                     'raw_output_attention': None} for _ in range(output_len)]
        attention_matrices = []
        num_steps = len(attentions)
steps_to_process = min(num_steps, output_len)
for i in range(steps_to_process):
step_attentions = attentions[i]
input_attention_layers = []
output_attention_layers = []
for layer_idx, layer_attn in enumerate(step_attentions):
try:
# Extract attention to input tokens (skip BOS token at position 0)
input_indices = slice(1, 1 + input_len_for_attention)
if layer_attn.shape[3] >= input_indices.stop:
                        # Attention from the current query position (a
                        # single-token query during generation) to the input
                        input_attn = layer_attn[0, :, 0, input_indices]
input_attention_layers.append(input_attn)
# Extract attention to previous output tokens
if i > 0:
output_indices = slice(1 + input_len_for_attention, 1 + input_len_for_attention + i)
if layer_attn.shape[3] >= output_indices.stop:
output_attn = layer_attn[0, :, 0, output_indices]
output_attention_layers.append(output_attn)
else:
output_attention_layers.append(
torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
)
else:
input_attention_layers.append(
torch.zeros((layer_attn.shape[1], input_len_for_attention), device=layer_attn.device)
)
if i > 0:
output_attention_layers.append(
torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
)
except Exception as e:
print(f"Error processing attention at step {i}, layer {layer_idx}: {e}")
input_attention_layers.append(
torch.zeros((layer_attn.shape[1], input_len_for_attention), device=layer_attn.device)
)
if i > 0:
output_attention_layers.append(
torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
)
# Average across layers and heads
if input_attention_layers:
avg_input_attn = torch.mean(torch.stack(input_attention_layers).float(), dim=[0, 1])
else:
avg_input_attn = torch.zeros(input_len_for_attention)
avg_output_attn = None
if i > 0 and output_attention_layers:
avg_output_attn = torch.mean(torch.stack(output_attention_layers).float(), dim=[0, 1])
elif i > 0:
avg_output_attn = torch.zeros(i)
# Normalize separately with epsilon for numerical stability
epsilon = 1e-8
input_sum = avg_input_attn.sum() + epsilon
normalized_input_attn = avg_input_attn / input_sum
normalized_output_attn = None
if i > 0 and avg_output_attn is not None:
output_sum = avg_output_attn.sum() + epsilon
normalized_output_attn = avg_output_attn / output_sum
attention_matrices.append({
'input_attention': normalized_input_attn.cpu(),
'output_attention': normalized_output_attn.cpu() if normalized_output_attn is not None else None,
'raw_input_attention': avg_input_attn.cpu(), # Keep raw for analysis
'raw_output_attention': avg_output_attn.cpu() if avg_output_attn is not None else None
})
# Fill remaining steps with zeros if needed
while len(attention_matrices) < output_len:
attention_matrices.append({
'input_attention': torch.zeros(input_len_for_attention),
'output_attention': None,
'raw_input_attention': torch.zeros(input_len_for_attention),
'raw_output_attention': None
})
return attention_matrices
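
    # Usage sketch (illustrative, not original documentation). Assuming
    # attention_data was built from Hugging Face generate() with
    # output_attentions=True and return_dict_in_generate=True (see the
    # runnable demo at the bottom of this file):
    #
    #   matrices = AttentionProcessor.process_attention_separate(
    #       attention_data, input_tokens, output_tokens)
    #   matrices[2]['input_attention'].sum()   # ~1.0 within the input group
    #   matrices[2]['output_attention'].sum()  # ~1.0 within the output group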
@staticmethod
    def process_attention_joint(
            attention_data: Dict[str, Any],
            input_tokens: List[str],
            output_tokens: List[str]
    ) -> List[Dict[str, Optional[torch.Tensor]]]:
        """
        Process attention with joint normalization across input and output,
        preserving the relative importance across all attended tokens.
        input_tokens and output_tokens are currently unused.
        """
attentions = attention_data['attentions']
input_len_for_attention = attention_data['input_len_for_attention']
output_len = attention_data['output_len']
        if not attentions:
            print("Warning: No attention steps found in output.")
            return [{'input_attention': torch.zeros(input_len_for_attention),
                     'output_attention': None} for _ in range(output_len)]
        attention_matrices = []
        num_steps = len(attentions)
steps_to_process = min(num_steps, output_len)
for i in range(steps_to_process):
step_attentions = attentions[i]
input_attention_layers = []
output_attention_layers = []
for layer_idx, layer_attn in enumerate(step_attentions):
try:
                    # Extract attention to input tokens (skip BOS token at position 0)
input_indices = slice(1, 1 + input_len_for_attention)
if layer_attn.shape[3] >= input_indices.stop:
input_attn = layer_attn[0, :, 0, input_indices]
input_attention_layers.append(input_attn)
# Extract attention to previous output tokens
if i > 0:
output_indices = slice(1 + input_len_for_attention, 1 + input_len_for_attention + i)
if layer_attn.shape[3] >= output_indices.stop:
output_attn = layer_attn[0, :, 0, output_indices]
output_attention_layers.append(output_attn)
else:
output_attention_layers.append(
torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
)
else:
input_attention_layers.append(
torch.zeros((layer_attn.shape[1], input_len_for_attention), device=layer_attn.device)
)
if i > 0:
output_attention_layers.append(
torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
)
except Exception as e:
print(f"Error processing attention at step {i}, layer {layer_idx}: {e}")
input_attention_layers.append(
torch.zeros((layer_attn.shape[1], input_len_for_attention), device=layer_attn.device)
)
if i > 0:
output_attention_layers.append(
torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
)
# Average across layers and heads
if input_attention_layers:
avg_input_attn = torch.mean(torch.stack(input_attention_layers).float(), dim=[0, 1])
else:
avg_input_attn = torch.zeros(input_len_for_attention)
avg_output_attn = None
if i > 0 and output_attention_layers:
avg_output_attn = torch.mean(torch.stack(output_attention_layers).float(), dim=[0, 1])
elif i > 0:
avg_output_attn = torch.zeros(i)
# Joint normalization
epsilon = 1e-8
if i > 0 and avg_output_attn is not None:
# Concatenate and normalize together
combined_attn = torch.cat([avg_input_attn, avg_output_attn])
sum_attn = combined_attn.sum() + epsilon
normalized_combined = combined_attn / sum_attn
normalized_input_attn = normalized_combined[:input_len_for_attention]
normalized_output_attn = normalized_combined[input_len_for_attention:]
else:
# Only input attention available
sum_attn = avg_input_attn.sum() + epsilon
normalized_input_attn = avg_input_attn / sum_attn
normalized_output_attn = None
attention_matrices.append({
'input_attention': normalized_input_attn.cpu(),
'output_attention': normalized_output_attn.cpu() if normalized_output_attn is not None else None
})
# Fill remaining steps with zeros if needed
while len(attention_matrices) < output_len:
attention_matrices.append({
'input_attention': torch.zeros(input_len_for_attention),
'output_attention': None
})
return attention_matrices
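
    # Contrast with process_attention_separate: a single renormalization is
    # applied over the concatenated [input | output] vector, so the two
    # groups are directly comparable and sum to ~1.0 together:
    #
    #   matrices = AttentionProcessor.process_attention_joint(
    #       attention_data, input_tokens, output_tokens)
    #   (matrices[2]['input_attention'].sum()
    #    + matrices[2]['output_attention'].sum())  # ~1.0 combined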
@staticmethod
    def extract_attention_for_step(
            attention_data: Dict[str, Any],
            step: int,
            input_len: int
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Extract attention weights for a single generation step.
        Optimized to process only the requested step.
        """
attentions = attention_data['attentions']
if step >= len(attentions):
return {
'input_attention': torch.zeros(input_len),
'output_attention': None
}
step_attentions = attentions[step]
input_attention_layers = []
output_attention_layers = []
for layer_attn in step_attentions:
# Extract input attention
input_indices = slice(1, 1 + input_len)
if layer_attn.shape[3] >= input_indices.stop:
input_attn = layer_attn[0, :, 0, input_indices]
input_attention_layers.append(input_attn)
# Extract output attention if there are previous outputs
if step > 0:
output_indices = slice(1 + input_len, 1 + input_len + step)
if layer_attn.shape[3] >= output_indices.stop:
output_attn = layer_attn[0, :, 0, output_indices]
output_attention_layers.append(output_attn)
# Average and normalize
if input_attention_layers:
avg_input = torch.mean(torch.stack(input_attention_layers).float(), dim=[0, 1])
normalized_input = avg_input / (avg_input.sum() + 1e-8)
else:
normalized_input = torch.zeros(input_len)
normalized_output = None
if step > 0 and output_attention_layers:
avg_output = torch.mean(torch.stack(output_attention_layers).float(), dim=[0, 1])
normalized_output = avg_output / (avg_output.sum() + 1e-8)
return {
'input_attention': normalized_input.cpu(),
'output_attention': normalized_output.cpu() if normalized_output is not None else None
}
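

# ---------------------------------------------------------------------------
# Minimal self-contained demo (a sketch, not part of the original module). It
# fabricates attention tensors in the layout this module's indexing assumes:
# one tuple per generated token, each holding one (batch, num_heads,
# query_len, key_len) tensor per layer with a single query position. Real
# Hugging Face generate() output (output_attentions=True,
# return_dict_in_generate=True) has the same nesting, though its first step
# covers the whole prompt rather than a single query position.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_layers, num_heads = 2, 4
    input_len, output_len = 5, 3  # input length excludes the BOS token

    attentions = []
    for step in range(output_len):
        # Keys grow by one per step: BOS + input + prior outputs + current token
        key_len = 1 + input_len + step + 1
        attentions.append(tuple(
            torch.rand(1, num_heads, 1, key_len).softmax(dim=-1)
            for _ in range(num_layers)
        ))

    attention_data = {
        'attentions': attentions,
        'input_len_for_attention': input_len,
        'output_len': output_len,
    }

    separate = AttentionProcessor.process_attention_separate(
        attention_data, input_tokens=[], output_tokens=[])
    joint = AttentionProcessor.process_attention_joint(
        attention_data, input_tokens=[], output_tokens=[])
    step2 = AttentionProcessor.extract_attention_for_step(
        attention_data, step=2, input_len=input_len)

    # Separate: each group sums to ~1 on its own; joint: both groups together.
    print("separate: input sum", separate[2]['input_attention'].sum().item(),
          "output sum", separate[2]['output_attention'].sum().item())
    print("joint: combined sum",
          (joint[2]['input_attention'].sum()
           + joint[2]['output_attention'].sum()).item())
    print("per-step input attention:", step2['input_attention'])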