# --- START OF FILE architecture.py ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Phi3Config, Phi3ForCausalLM
from typing import Optional, Dict
# --- BUILDING BLOCK 1: VectorMemoryHead (No changes needed here, it inherits dtype correctly) ---
class VectorMemoryHead(nn.Module):
    """Compresses a short sequence of vectors into a fixed number of memory slots
    and reconstructs the original sequence from that compressed memory."""

    def __init__(self, hidden_dim: int, num_memory_slots: int, num_heads: int, ff_dim: int, device=None, dtype=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.memory_queries = nn.Parameter(torch.randn(1, num_memory_slots, hidden_dim, device=device, dtype=dtype))
        self.memory_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.memory_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=dtype
        )
        self.decoder_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
        self.decoder_ffn = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim, device=device, dtype=dtype)
        )

    def forward(self, memory_input_sequence: torch.Tensor):
        # memory_input_sequence: (batch, seq_len, hidden_dim)
        batch_size = memory_input_sequence.shape[0]
        encoded_vectors = self.encoder(memory_input_sequence)
        # Compress: learned queries cross-attend over the encoded sequence,
        # yielding num_memory_slots vectors per batch element.
        queries = self.memory_queries.expand(batch_size, -1, -1)
        compressed_memory, _ = self.memory_attention(query=queries, key=encoded_vectors, value=encoded_vectors)
        compressed_memory = self.memory_layernorm(compressed_memory + queries)
        # Reconstruct: the encoded sequence attends back over the compressed memory.
        reconstructed, _ = self.decoder_attention(query=encoded_vectors, key=compressed_memory, value=compressed_memory)
        reconstructed_vectors = self.decoder_layernorm(reconstructed + encoded_vectors)
        reconstructed_vectors = self.decoder_ffn(reconstructed_vectors)
        return compressed_memory, reconstructed_vectors
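
# Illustrative shapes for VectorMemoryHead (a sketch only; the dimensions match the
# values GCVectorMemoryLayer passes below, but the batch/sequence sizes are arbitrary):
#   head = VectorMemoryHead(hidden_dim=64, num_memory_slots=8, num_heads=4, ff_dim=128)
#   seq = torch.randn(2, 16, 64)                      # (batch, seq_len, hidden_dim)
#   compressed, reconstructed = head(seq)
#   compressed.shape     -> torch.Size([2, 8, 64])    # one vector per memory slot
#   reconstructed.shape  -> torch.Size([2, 16, 64])   # same shape as the input sequence
# In GCVectorMemoryLayer below, seq_len is 2: one global and one local projection per token.
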
# --- BUILDING BLOCK 2: The Custom Layer (With Iterative Self-Correction) ---
class GCVectorMemoryLayer(nn.Module):
    def __init__(self, original_layer: nn.Linear, global_input_dim: int,
                 memory_dim: int, num_memory_slots: int, memory_num_heads: int,
                 global_state_storage: Dict):
        super().__init__()
        self.input_dim = original_layer.in_features
        self.output_dim = original_layer.out_features
        self.memory_dim = memory_dim
        self.global_state_storage = global_state_storage
        self.linear = original_layer
        device, dtype = self.linear.weight.device, self.linear.weight.dtype
        # This part is correct: initialize with the correct dtype
        self.local_state_proj = nn.Linear(self.input_dim, memory_dim, device=device, dtype=dtype)
        self.global_state_proj = nn.Linear(global_input_dim, memory_dim, device=device, dtype=dtype)
        self.memory_head = VectorMemoryHead(
            hidden_dim=memory_dim, num_memory_slots=num_memory_slots,
            num_heads=memory_num_heads, ff_dim=memory_dim * 2, device=device, dtype=dtype
        )
        self.correction_head = nn.Linear(memory_dim, 2 * self.output_dim, device=device, dtype=dtype)
        # --- NEW: Parameter for iterative self-correction ---
        # This can be changed at inference time to apply the correction multiple times.
        # Default is 1 to match training behavior.
        self.num_correction_passes: int = 1
        self.last_corrected_activation: Optional[torch.Tensor] = None
        self.last_additive_correction: Optional[torch.Tensor] = None
        self.last_memory_input: Optional[torch.Tensor] = None
        self.last_reconstructed_from_memory: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor):
        base_output = self.linear(x)
        # If no global state is available or correction is disabled, return base output.
        if 'embeds' not in self.global_state_storage or self.num_correction_passes < 1:
            return base_output
        global_embeds = self.global_state_storage['embeds']
        if global_embeds.shape[1] != x.shape[1]:
            global_embeds = global_embeds[:, -x.shape[1]:, :]
        B, S, _ = x.shape
        with torch.enable_grad():
            # --- 1. Calculate the correction signal ONCE ---
            proj_local = self.local_state_proj(x)
            proj_global = self.global_state_proj(global_embeds)
            memory_input = torch.stack([proj_global, proj_local], dim=2)
            memory_input_flat = memory_input.view(B * S, 2, self.memory_dim)
            compressed_mem_flat, recon_flat = self.memory_head(memory_input_flat)
            aggregated_thought_flat = compressed_mem_flat.mean(dim=1)
            aggregated_thought = aggregated_thought_flat.view(B, S, self.memory_dim)
            raw_correction = self.correction_head(aggregated_thought)
            gate, value = torch.chunk(raw_correction, 2, dim=-1)
            # --- 2. Iteratively apply the correction signal ---
            corrected_activation = base_output
            for _ in range(self.num_correction_passes):
                corrected_activation = corrected_activation * torch.sigmoid(gate.to(x.dtype)) + value.to(x.dtype)
        # During training, store the final activation and the original correction signal
        # for loss calculation.
        if self.training:
            self.last_corrected_activation = corrected_activation
            self.last_additive_correction = value  # The 'value' is the core additive signal
            self.last_memory_input = memory_input_flat
            self.last_reconstructed_from_memory = recon_flat
        return corrected_activation
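
# Illustrative training-side sketch (an assumption about how the cached tensors above
# could be consumed; the actual training script is not part of this file). The layer's
# last_* attributes could feed an auxiliary reconstruction loss, e.g.:
#   layer = model.get_submodule("model.layers.15.mlp.gate_up_proj")
#   out = model(**batch)                       # populates layer.last_* in training mode
#   aux = F.mse_loss(layer.last_reconstructed_from_memory, layer.last_memory_input.detach())
#   loss = out.loss + 0.1 * aux                # 0.1 is an assumed weighting, not from this file
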
# --- BUILDING BLOCK 3: The Full Custom Model ---
class Phi3WithVectorMemoryForCausalLM(Phi3ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.global_state_storage = {}
        self.target_layer_path = "model.layers.15.mlp.gate_up_proj"
        # Cache the token embeddings on every forward pass so the custom layer can
        # condition its correction on the global (prompt-level) state.
        self.model.embed_tokens.register_forward_hook(
            lambda module, input, output: self.global_state_storage.update({'embeds': output.detach()})
        )
        try:
            original_layer = self.get_submodule(self.target_layer_path)
            custom_layer = GCVectorMemoryLayer(
                original_layer=original_layer, global_input_dim=config.hidden_size,
                memory_dim=64, num_memory_slots=8, memory_num_heads=4,
                global_state_storage=self.global_state_storage
            )
            parent_path = ".".join(self.target_layer_path.split('.')[:-1])
            child_name = self.target_layer_path.split('.')[-1]
            setattr(self.get_submodule(parent_path), child_name, custom_layer)
            print(f"Successfully replaced '{self.target_layer_path}' with GCVectorMemoryLayer.")
        except AttributeError:
            print(f"Could not find target layer '{self.target_layer_path}'. Model remains unmodified.")
# --- END OF FILE architecture.py ---