import torch
import numpy as np
from typing import List, Dict, Any, Optional


class AttentionProcessor:
    @staticmethod
    def process_attention_separate(
        attention_data: Dict[str, Any],
        input_tokens: List[str],
        output_tokens: List[str]
    ) -> List[Dict[str, Optional[torch.Tensor]]]:
        """
        Process attention with separate normalization for input and output.
        This preserves the relative importance within each group.
        """
        attentions = attention_data['attentions']
        input_len_for_attention = attention_data['input_len_for_attention']
        output_len = attention_data['output_len']

        if not attentions:
            return [{'input_attention': torch.zeros(input_len_for_attention),
                     'output_attention': None} for _ in range(output_len)]

        attention_matrices = []
        num_steps = len(attentions)

        if num_steps == 0:
            print("Warning: No attention steps found in output.")
            return [{'input_attention': torch.zeros(input_len_for_attention),
                     'output_attention': None} for _ in range(output_len)]

        steps_to_process = min(num_steps, output_len)

        for i in range(steps_to_process):
            step_attentions = attentions[i]
            input_attention_layers = []
            output_attention_layers = []

            for layer_idx, layer_attn in enumerate(step_attentions):
                try:
                    # Key positions: index 0 is the BOS token, followed by the
                    # input (prompt) tokens, then any previously generated tokens.
                    input_indices = slice(1, 1 + input_len_for_attention)
                    if layer_attn.shape[3] >= input_indices.stop:
                        # Batch 0, all heads, the single query position of this step.
                        input_attn = layer_attn[0, :, 0, input_indices]
                        input_attention_layers.append(input_attn)

                        # Attention to previously generated tokens (only exists after step 0).
                        if i > 0:
                            output_indices = slice(1 + input_len_for_attention, 1 + input_len_for_attention + i)
                            if layer_attn.shape[3] >= output_indices.stop:
                                output_attn = layer_attn[0, :, 0, output_indices]
                                output_attention_layers.append(output_attn)
                            else:
                                output_attention_layers.append(
                                    torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
                                )
                    else:
                        # Attention matrix is shorter than expected; fall back to zeros.
                        input_attention_layers.append(
                            torch.zeros((layer_attn.shape[1], input_len_for_attention), device=layer_attn.device)
                        )
                        if i > 0:
                            output_attention_layers.append(
                                torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
                            )

                except Exception as e:
                    print(f"Error processing attention at step {i}, layer {layer_idx}: {e}")
                    input_attention_layers.append(
                        torch.zeros((layer_attn.shape[1], input_len_for_attention), device=layer_attn.device)
                    )
                    if i > 0:
                        output_attention_layers.append(
                            torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
                        )

            # Average over layers and heads.
            if input_attention_layers:
                avg_input_attn = torch.mean(torch.stack(input_attention_layers).float(), dim=[0, 1])
            else:
                avg_input_attn = torch.zeros(input_len_for_attention)

            avg_output_attn = None
            if i > 0 and output_attention_layers:
                avg_output_attn = torch.mean(torch.stack(output_attention_layers).float(), dim=[0, 1])
            elif i > 0:
                avg_output_attn = torch.zeros(i)

            # Normalize each group independently so input and output weights each sum to 1.
            epsilon = 1e-8
            input_sum = avg_input_attn.sum() + epsilon
            normalized_input_attn = avg_input_attn / input_sum

            normalized_output_attn = None
            if i > 0 and avg_output_attn is not None:
                output_sum = avg_output_attn.sum() + epsilon
                normalized_output_attn = avg_output_attn / output_sum

            attention_matrices.append({
                'input_attention': normalized_input_attn.cpu(),
                'output_attention': normalized_output_attn.cpu() if normalized_output_attn is not None else None,
                'raw_input_attention': avg_input_attn.cpu(),
                'raw_output_attention': avg_output_attn.cpu() if avg_output_attn is not None else None
            })

        # Pad with empty entries if fewer attention steps than output tokens were returned.
        while len(attention_matrices) < output_len:
            attention_matrices.append({
                'input_attention': torch.zeros(input_len_for_attention),
                'output_attention': None,
                'raw_input_attention': torch.zeros(input_len_for_attention),
                'raw_output_attention': None
            })

        return attention_matrices

    @staticmethod
    def process_attention_joint(
        attention_data: Dict[str, Any],
        input_tokens: List[str],
        output_tokens: List[str]
    ) -> List[Dict[str, Optional[torch.Tensor]]]:
        """
        Process attention with joint normalization across input and output.
        This preserves the relative importance across all tokens.
        """
        attentions = attention_data['attentions']
        input_len_for_attention = attention_data['input_len_for_attention']
        output_len = attention_data['output_len']

        if not attentions:
            return [{'input_attention': torch.zeros(input_len_for_attention),
                     'output_attention': None} for _ in range(output_len)]

        attention_matrices = []
        num_steps = len(attentions)

        if num_steps == 0:
            print("Warning: No attention steps found in output.")
            return [{'input_attention': torch.zeros(input_len_for_attention),
                     'output_attention': None} for _ in range(output_len)]

        steps_to_process = min(num_steps, output_len)

        for i in range(steps_to_process):
            step_attentions = attentions[i]
            input_attention_layers = []
            output_attention_layers = []

            for layer_idx, layer_attn in enumerate(step_attentions):
                try:
                    # Key positions: index 0 is the BOS token, followed by the
                    # input (prompt) tokens, then any previously generated tokens.
                    input_indices = slice(1, 1 + input_len_for_attention)
                    if layer_attn.shape[3] >= input_indices.stop:
                        input_attn = layer_attn[0, :, 0, input_indices]
                        input_attention_layers.append(input_attn)

                        # Attention to previously generated tokens (only exists after step 0).
                        if i > 0:
                            output_indices = slice(1 + input_len_for_attention, 1 + input_len_for_attention + i)
                            if layer_attn.shape[3] >= output_indices.stop:
                                output_attn = layer_attn[0, :, 0, output_indices]
                                output_attention_layers.append(output_attn)
                            else:
                                output_attention_layers.append(
                                    torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
                                )
                    else:
                        # Attention matrix is shorter than expected; fall back to zeros.
                        input_attention_layers.append(
                            torch.zeros((layer_attn.shape[1], input_len_for_attention), device=layer_attn.device)
                        )
                        if i > 0:
                            output_attention_layers.append(
                                torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
                            )

                except Exception as e:
                    print(f"Error processing attention at step {i}, layer {layer_idx}: {e}")
                    input_attention_layers.append(
                        torch.zeros((layer_attn.shape[1], input_len_for_attention), device=layer_attn.device)
                    )
                    if i > 0:
                        output_attention_layers.append(
                            torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
                        )

            # Average over layers and heads.
            if input_attention_layers:
                avg_input_attn = torch.mean(torch.stack(input_attention_layers).float(), dim=[0, 1])
            else:
                avg_input_attn = torch.zeros(input_len_for_attention)

            avg_output_attn = None
            if i > 0 and output_attention_layers:
                avg_output_attn = torch.mean(torch.stack(output_attention_layers).float(), dim=[0, 1])
            elif i > 0:
                # Keep the fallback on the same device as the input attention so the concat below is valid.
                avg_output_attn = torch.zeros(i, device=avg_input_attn.device)

            # Normalize input and output weights jointly so they sum to 1 together.
            epsilon = 1e-8
            if i > 0 and avg_output_attn is not None:
                combined_attn = torch.cat([avg_input_attn, avg_output_attn])
                sum_attn = combined_attn.sum() + epsilon
                normalized_combined = combined_attn / sum_attn
                normalized_input_attn = normalized_combined[:input_len_for_attention]
                normalized_output_attn = normalized_combined[input_len_for_attention:]
            else:
                # Step 0 has no generated tokens yet, so only the input is normalized.
                sum_attn = avg_input_attn.sum() + epsilon
                normalized_input_attn = avg_input_attn / sum_attn
                normalized_output_attn = None

            attention_matrices.append({
                'input_attention': normalized_input_attn.cpu(),
                'output_attention': normalized_output_attn.cpu() if normalized_output_attn is not None else None
            })

        # Pad with empty entries if fewer attention steps than output tokens were returned.
        while len(attention_matrices) < output_len:
            attention_matrices.append({
                'input_attention': torch.zeros(input_len_for_attention),
                'output_attention': None
            })

        return attention_matrices

    @staticmethod
    def extract_attention_for_step(
        attention_data: Dict[str, Any],
        step: int,
        input_len: int
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Extract attention weights for a specific generation step.
        Optimized to only process the needed step.
        """
        attentions = attention_data['attentions']

        if step >= len(attentions):
            return {
                'input_attention': torch.zeros(input_len),
                'output_attention': None
            }

        step_attentions = attentions[step]
        input_attention_layers = []
        output_attention_layers = []

        for layer_attn in step_attentions:
            # Key positions: index 0 is the BOS token, followed by the input
            # tokens, then any previously generated tokens.
            input_indices = slice(1, 1 + input_len)
            if layer_attn.shape[3] >= input_indices.stop:
                input_attn = layer_attn[0, :, 0, input_indices]
                input_attention_layers.append(input_attn)

                # Attention to previously generated tokens (only exists after step 0).
                if step > 0:
                    output_indices = slice(1 + input_len, 1 + input_len + step)
                    if layer_attn.shape[3] >= output_indices.stop:
                        output_attn = layer_attn[0, :, 0, output_indices]
                        output_attention_layers.append(output_attn)

        # Average over layers and heads, then normalize each group to sum to 1.
        if input_attention_layers:
            avg_input = torch.mean(torch.stack(input_attention_layers).float(), dim=[0, 1])
            normalized_input = avg_input / (avg_input.sum() + 1e-8)
        else:
            normalized_input = torch.zeros(input_len)

        normalized_output = None
        if step > 0 and output_attention_layers:
            avg_output = torch.mean(torch.stack(output_attention_layers).float(), dim=[0, 1])
            normalized_output = avg_output / (avg_output.sum() + 1e-8)

        return {
            'input_attention': normalized_input.cpu(),
            'output_attention': normalized_output.cpu() if normalized_output is not None else None
        }
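

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original API).
# It assumes the attention_data layout that the methods above index into:
#   'attentions': one entry per generation step, each an iterable of per-layer
#       tensors shaped [batch, heads, query_len, key_len], with the key axis
#       laid out as [BOS] + input tokens + previously generated tokens;
#   'input_len_for_attention': number of input (prompt) tokens after the BOS;
#   'output_len': number of generated tokens.
# All names below (num_layers, num_heads, etc.) are made up for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    num_layers, num_heads, input_len, output_len = 2, 4, 5, 3

    # Build synthetic per-step, per-layer attention tensors with softmax rows.
    attentions = []
    for step in range(output_len):
        key_len = 1 + input_len + step  # BOS + input + tokens generated so far
        attentions.append(tuple(
            torch.softmax(torch.randn(1, num_heads, 1, key_len), dim=-1)
            for _ in range(num_layers)
        ))

    attention_data = {
        'attentions': attentions,
        'input_len_for_attention': input_len,
        'output_len': output_len,
    }
    input_tokens = [f"in_{j}" for j in range(input_len)]
    output_tokens = [f"out_{j}" for j in range(output_len)]

    separate = AttentionProcessor.process_attention_separate(
        attention_data, input_tokens, output_tokens)
    joint = AttentionProcessor.process_attention_joint(
        attention_data, input_tokens, output_tokens)
    last_step = AttentionProcessor.extract_attention_for_step(
        attention_data, step=output_len - 1, input_len=input_len)

    # Separate normalization: each group sums to ~1 on its own.
    # Joint normalization: input and output weights sum to ~1 together.
    print("separate, step 2, input sum:", separate[2]['input_attention'].sum().item())
    print("separate, step 2, output sum:", separate[2]['output_attention'].sum().item())
    print("joint, step 2, combined sum:",
          (joint[2]['input_attention'].sum() + joint[2]['output_attention'].sum()).item())
    print("extract_attention_for_step, input shape:", tuple(last_step['input_attention'].shape))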